Files
galaxy-game/backend/internal/server/middleware/geocounter/geocounter_test.go
T
2026-05-06 10:14:55 +03:00

165 lines
3.7 KiB
Go

package geocounter_test
import (
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"galaxy/backend/internal/server/middleware/geocounter"
"galaxy/backend/internal/server/middleware/userid"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// recordedCall captures the arguments of a single IncrementCounterAsync call.
type recordedCall struct {
	UserID   uuid.UUID
	SourceIP string
}

// recordingSvc is an in-memory stand-in for geocounter.Service that records
// every IncrementCounterAsync call so tests can inspect what the middleware
// forwarded. All access to the recorded calls goes through mu.
type recordingSvc struct {
	mu    sync.Mutex     // guards calls
	calls []recordedCall // appended in call order
}

// IncrementCounterAsync records the (userID, sourceIP) pair. The context is
// ignored, and the call is recorded synchronously before returning.
func (s *recordingSvc) IncrementCounterAsync(_ context.Context, userID uuid.UUID, sourceIP string) {
	entry := recordedCall{UserID: userID, SourceIP: sourceIP}
	s.mu.Lock()
	s.calls = append(s.calls, entry)
	s.mu.Unlock()
}

// snapshot returns a copy of the calls recorded so far, so callers can read
// it without holding the lock or racing later appends. The result is never nil.
func (s *recordingSvc) snapshot() []recordedCall {
	s.mu.Lock()
	defer s.mu.Unlock()
	return append(make([]recordedCall, 0, len(s.calls)), s.calls...)
}
// newEngine builds a gin engine in test mode whose chain is
// userid.Middleware -> geocounter.Middleware(svc) -> a GET /probe handler
// that always answers 200 "ok".
func newEngine(t *testing.T, svc geocounter.Service) *gin.Engine {
	t.Helper()
	gin.SetMode(gin.TestMode)

	probe := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	// Registration order matters: geocounter reads what userid populates.
	engine.Use(userid.Middleware(), geocounter.Middleware(svc))
	engine.GET("/probe", probe)
	return engine
}
// TestMiddlewareInvokesIncrementOnAuthenticatedRequest checks that an
// authenticated request is counted exactly once, attributed to the user from
// the userid header, and that the X-Forwarded-For address is recorded
// instead of RemoteAddr.
func TestMiddlewareInvokesIncrementOnAuthenticatedRequest(t *testing.T) {
	t.Parallel()

	svc := &recordingSvc{}
	engine := newEngine(t, svc)
	user := uuid.New()

	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.RemoteAddr = "10.0.0.1:1000"
	req.Header.Set(userid.Header, user.String())
	// The forwarded address is expected to win over RemoteAddr.
	req.Header.Set("X-Forwarded-For", "203.0.113.5")

	resp := httptest.NewRecorder()
	engine.ServeHTTP(resp, req)
	if got := resp.Code; got != http.StatusOK {
		t.Fatalf("status: want 200, got %d", got)
	}

	recorded := svc.snapshot()
	if len(recorded) != 1 {
		t.Fatalf("calls: want 1, got %+v", recorded)
	}
	call := recorded[0]
	if call.UserID != user {
		t.Errorf("user id: want %s, got %s", user, call.UserID)
	}
	if call.SourceIP != "203.0.113.5" {
		t.Errorf("source ip: want 203.0.113.5, got %q", call.SourceIP)
	}
}
// TestMiddlewareFallsBackToRemoteAddr checks that, when no X-Forwarded-For
// header is present, the recorded source IP is the host part of RemoteAddr
// (port stripped), the call is attributed to the authenticated user, and the
// request still completes normally.
func TestMiddlewareFallsBackToRemoteAddr(t *testing.T) {
	t.Parallel()
	svc := &recordingSvc{}
	r := newEngine(t, svc)
	userID := uuid.New()
	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	req.RemoteAddr = "198.51.100.7:60000"
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)
	// Counting must not interfere with the response itself.
	if rec.Code != http.StatusOK {
		t.Fatalf("status: want 200, got %d", rec.Code)
	}
	calls := svc.snapshot()
	if len(calls) != 1 {
		t.Fatalf("calls: want 1, got %+v", calls)
	}
	if calls[0].UserID != userID {
		t.Errorf("user id: want %s, got %s", userID, calls[0].UserID)
	}
	if calls[0].SourceIP != "198.51.100.7" {
		t.Errorf("source ip: want 198.51.100.7, got %q", calls[0].SourceIP)
	}
}
// TestMiddlewareSkipsWhenNoSourceIP checks that an authenticated request
// with no resolvable client address (empty RemoteAddr, no forwarding
// headers) is not counted, but is still passed through to the handler.
func TestMiddlewareSkipsWhenNoSourceIP(t *testing.T) {
	t.Parallel()
	svc := &recordingSvc{}
	r := newEngine(t, svc)
	userID := uuid.New()
	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	req.RemoteAddr = ""
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)
	// "Skip" means skip counting, not abort the request.
	if rec.Code != http.StatusOK {
		t.Fatalf("status: want 200, got %d", rec.Code)
	}
	if calls := svc.snapshot(); len(calls) != 0 {
		t.Fatalf("calls: want 0, got %+v", calls)
	}
}
// TestMiddlewareSkipsWithoutUserContext checks that, when userid.Middleware
// is absent from the chain (so no user is in the request context), nothing
// is counted and the request still reaches the handler.
func TestMiddlewareSkipsWithoutUserContext(t *testing.T) {
	t.Parallel()
	svc := &recordingSvc{}
	gin.SetMode(gin.TestMode)
	r := gin.New()
	// No userid.Middleware on this chain.
	r.Use(geocounter.Middleware(svc))
	r.GET("/probe", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})
	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.RemoteAddr = "203.0.113.5:1000"
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)
	// "Skip" means skip counting, not abort the request.
	if rec.Code != http.StatusOK {
		t.Fatalf("status: want 200, got %d", rec.Code)
	}
	if calls := svc.snapshot(); len(calls) != 0 {
		t.Fatalf("calls: want 0, got %+v", calls)
	}
}
// TestMiddlewareNilServiceIsPassThrough checks that building the middleware
// with a nil Service does not break the chain: an authenticated request still
// reaches the handler and gets 200.
func TestMiddlewareNilServiceIsPassThrough(t *testing.T) {
	t.Parallel()
	engine := newEngine(t, nil)

	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.RemoteAddr = "203.0.113.5:1000"
	req.Header.Set(userid.Header, uuid.New().String())

	resp := httptest.NewRecorder()
	engine.ServeHTTP(resp, req)

	if code := resp.Code; code != http.StatusOK {
		t.Fatalf("status with nil service: want 200, got %d", code)
	}
}