package geocounter_test

import (
	"context"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"galaxy/backend/internal/server/middleware/geocounter"
	"galaxy/backend/internal/server/middleware/userid"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)

// recordingSvc is a geocounter.Service fake that records every
// IncrementCounterAsync invocation so tests can inspect the arguments.
type recordingSvc struct {
	mu    sync.Mutex
	calls []recordedCall
}

// Compile-time check that the fake satisfies the middleware's dependency.
var _ geocounter.Service = (*recordingSvc)(nil)

// recordedCall captures the arguments of a single IncrementCounterAsync call.
type recordedCall struct {
	UserID   uuid.UUID
	SourceIP string
}

// IncrementCounterAsync appends the call to the recorded list under the lock.
// The context argument is ignored.
func (r *recordingSvc) IncrementCounterAsync(_ context.Context, userID uuid.UUID, sourceIP string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.calls = append(r.calls, recordedCall{UserID: userID, SourceIP: sourceIP})
}

// snapshot returns a copy of the calls recorded so far, so callers can read
// it without holding the mutex.
func (r *recordingSvc) snapshot() []recordedCall {
	r.mu.Lock()
	defer r.mu.Unlock()
	out := make([]recordedCall, len(r.calls))
	copy(out, r.calls)
	return out
}

// newEngine builds a gin engine in test mode that runs userid.Middleware and
// then the geocounter middleware under test, in front of a /probe route that
// always answers 200 "ok".
func newEngine(t *testing.T, svc geocounter.Service) *gin.Engine {
	t.Helper()
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(userid.Middleware())
	r.Use(geocounter.Middleware(svc))
	r.GET("/probe", func(c *gin.Context) { c.String(http.StatusOK, "ok") })
	return r
}

// serveOK dispatches req through h and stops the test unless the response
// status is 200. In these tests that means the middleware chain let the
// request reach the /probe handler.
func serveOK(t *testing.T, h http.Handler, req *http.Request) {
	t.Helper()
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("status: want 200, got %d", rec.Code)
	}
}

func TestMiddlewareInvokesIncrementOnAuthenticatedRequest(t *testing.T) {
	t.Parallel()
	svc := &recordingSvc{}
	r := newEngine(t, svc)
	userID := uuid.New()

	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	// The test expects the X-Forwarded-For address to be reported instead of
	// the direct peer address below.
	req.Header.Set("X-Forwarded-For", "203.0.113.5")
	req.RemoteAddr = "10.0.0.1:1000"
	serveOK(t, r, req)

	calls := svc.snapshot()
	if len(calls) != 1 {
		t.Fatalf("calls: want 1, got %+v", calls)
	}
	if calls[0].UserID != userID {
		t.Errorf("user id: want %s, got %s", userID, calls[0].UserID)
	}
	if calls[0].SourceIP != "203.0.113.5" {
		t.Errorf("source ip: want 203.0.113.5, got %q", calls[0].SourceIP)
	}
}

func TestMiddlewareFallsBackToRemoteAddr(t *testing.T) {
	t.Parallel()
	svc := &recordingSvc{}
	r := newEngine(t, svc)
	userID := uuid.New()

	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	// No forwarding header: the port must be stripped from RemoteAddr.
	req.RemoteAddr = "198.51.100.7:60000"
	serveOK(t, r, req)

	calls := svc.snapshot()
	if len(calls) != 1 {
		t.Fatalf("calls: want 1, got %+v", calls)
	}
	if calls[0].UserID != userID {
		t.Errorf("user id: want %s, got %s", userID, calls[0].UserID)
	}
	if calls[0].SourceIP != "198.51.100.7" {
		t.Errorf("source ip: want 198.51.100.7, got %q", calls[0].SourceIP)
	}
}

func TestMiddlewareSkipsWhenNoSourceIP(t *testing.T) {
	t.Parallel()
	svc := &recordingSvc{}
	r := newEngine(t, svc)
	userID := uuid.New()

	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	req.RemoteAddr = ""
	// Skipping the counter must not block the request itself.
	serveOK(t, r, req)

	if calls := svc.snapshot(); len(calls) != 0 {
		t.Fatalf("calls: want 0, got %+v", calls)
	}
}

func TestMiddlewareSkipsWithoutUserContext(t *testing.T) {
	t.Parallel()
	svc := &recordingSvc{}
	gin.SetMode(gin.TestMode)
	r := gin.New()
	// No userid.Middleware on this chain.
	r.Use(geocounter.Middleware(svc))
	r.GET("/probe", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.RemoteAddr = "203.0.113.5:1000"
	// Skipping the counter must not block the request itself.
	serveOK(t, r, req)

	if calls := svc.snapshot(); len(calls) != 0 {
		t.Fatalf("calls: want 0, got %+v", calls)
	}
}

func TestMiddlewareNilServiceIsPassThrough(t *testing.T) {
	t.Parallel()
	r := newEngine(t, nil)
	userID := uuid.New()

	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	req.RemoteAddr = "203.0.113.5:1000"
	serveOK(t, r, req)
}