165 lines
3.7 KiB
Go
165 lines
3.7 KiB
Go
package geocounter_test
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
|
|
"galaxy/backend/internal/server/middleware/geocounter"
|
|
"galaxy/backend/internal/server/middleware/userid"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type recordingSvc struct {
|
|
mu sync.Mutex
|
|
calls []recordedCall
|
|
}
|
|
|
|
type recordedCall struct {
|
|
UserID uuid.UUID
|
|
SourceIP string
|
|
}
|
|
|
|
// IncrementCounterAsync appends the given user ID and source IP to the
// recorded calls. The context is ignored. The append happens synchronously
// under r.mu, so the call is visible to snapshot as soon as this returns.
func (r *recordingSvc) IncrementCounterAsync(_ context.Context, userID uuid.UUID, sourceIP string) {
	call := recordedCall{UserID: userID, SourceIP: sourceIP}

	r.mu.Lock()
	r.calls = append(r.calls, call)
	r.mu.Unlock()
}
|
|
|
|
// snapshot returns a fresh copy of the calls recorded so far. The copy is
// taken while holding r.mu, so callers may use it without further locking.
func (r *recordingSvc) snapshot() []recordedCall {
	r.mu.Lock()
	defer r.mu.Unlock()
	return append(make([]recordedCall, 0, len(r.calls)), r.calls...)
}
|
|
|
|
// newEngine builds a gin engine in test mode. Its middleware chain is
// userid.Middleware followed by geocounter.Middleware(svc), and it has a
// single GET /probe route that responds 200 "ok".
func newEngine(t *testing.T, svc geocounter.Service) *gin.Engine {
	t.Helper()

	gin.SetMode(gin.TestMode)

	probe := func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	}

	engine := gin.New()
	// userid runs first: without it on the chain, geocounter records nothing.
	engine.Use(userid.Middleware(), geocounter.Middleware(svc))
	engine.GET("/probe", probe)
	return engine
}
|
|
|
|
func TestMiddlewareInvokesIncrementOnAuthenticatedRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
svc := &recordingSvc{}
|
|
r := newEngine(t, svc)
|
|
|
|
userID := uuid.New()
|
|
req := httptest.NewRequest(http.MethodGet, "/probe", nil)
|
|
req.Header.Set(userid.Header, userID.String())
|
|
req.Header.Set("X-Forwarded-For", "203.0.113.5")
|
|
req.RemoteAddr = "10.0.0.1:1000"
|
|
rec := httptest.NewRecorder()
|
|
|
|
r.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status: want 200, got %d", rec.Code)
|
|
}
|
|
calls := svc.snapshot()
|
|
if len(calls) != 1 {
|
|
t.Fatalf("calls: want 1, got %+v", calls)
|
|
}
|
|
if calls[0].UserID != userID {
|
|
t.Errorf("user id: want %s, got %s", userID, calls[0].UserID)
|
|
}
|
|
if calls[0].SourceIP != "203.0.113.5" {
|
|
t.Errorf("source ip: want 203.0.113.5, got %q", calls[0].SourceIP)
|
|
}
|
|
}
|
|
|
|
// TestMiddlewareFallsBackToRemoteAddr checks that, without an
// X-Forwarded-For header, the middleware reports the host part of
// req.RemoteAddr as the source IP. It also checks that the request still
// succeeds and that the user ID is forwarded unchanged.
func TestMiddlewareFallsBackToRemoteAddr(t *testing.T) {
	t.Parallel()

	svc := &recordingSvc{}
	r := newEngine(t, svc)

	userID := uuid.New()
	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	req.RemoteAddr = "198.51.100.7:60000"
	rec := httptest.NewRecorder()

	r.ServeHTTP(rec, req)

	// The fallback path must not change the response.
	if rec.Code != http.StatusOK {
		t.Fatalf("status: want 200, got %d", rec.Code)
	}

	calls := svc.snapshot()
	if len(calls) != 1 {
		t.Fatalf("calls: want 1, got %+v", calls)
	}
	if calls[0].UserID != userID {
		t.Errorf("user id: want %s, got %s", userID, calls[0].UserID)
	}
	// Only the host is expected; the ":60000" port must be stripped.
	if calls[0].SourceIP != "198.51.100.7" {
		t.Errorf("source ip: want 198.51.100.7, got %q", calls[0].SourceIP)
	}
}
|
|
|
|
// TestMiddlewareSkipsWhenNoSourceIP checks two things for a request with
// neither an X-Forwarded-For header nor a RemoteAddr. First, it is not
// counted. Second, it still passes through to the /probe handler.
func TestMiddlewareSkipsWhenNoSourceIP(t *testing.T) {
	t.Parallel()

	svc := &recordingSvc{}
	r := newEngine(t, svc)

	userID := uuid.New()
	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	req.RemoteAddr = ""
	rec := httptest.NewRecorder()

	r.ServeHTTP(rec, req)

	if calls := svc.snapshot(); len(calls) != 0 {
		t.Fatalf("calls: want 0, got %+v", calls)
	}
	// "Skip" must mean pass-through, not abort. ResponseRecorder's Code
	// defaults to 200, so the handler's body is the evidence that it ran.
	if body := rec.Body.String(); rec.Code != http.StatusOK || body != "ok" {
		t.Errorf("response: want 200 %q, got %d %q", "ok", rec.Code, body)
	}
}
|
|
|
|
// TestMiddlewareSkipsWithoutUserContext checks behavior when
// userid.Middleware is absent from the chain. geocounter must record
// nothing, and the request must still reach the /probe handler.
func TestMiddlewareSkipsWithoutUserContext(t *testing.T) {
	t.Parallel()

	svc := &recordingSvc{}

	gin.SetMode(gin.TestMode)
	r := gin.New()
	// No userid.Middleware on this chain.
	r.Use(geocounter.Middleware(svc))
	r.GET("/probe", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.RemoteAddr = "203.0.113.5:1000"
	rec := httptest.NewRecorder()

	r.ServeHTTP(rec, req)

	if calls := svc.snapshot(); len(calls) != 0 {
		t.Fatalf("calls: want 0, got %+v", calls)
	}
	// ResponseRecorder's Code defaults to 200, so check the body too. It
	// proves the handler ran rather than the chain being aborted silently.
	if body := rec.Body.String(); rec.Code != http.StatusOK || body != "ok" {
		t.Errorf("response: want 200 %q, got %d %q", "ok", rec.Code, body)
	}
}
|
|
|
|
// TestMiddlewareNilServiceIsPassThrough checks that geocounter.Middleware(nil)
// leaves the chain intact: an authenticated request still reaches the /probe
// handler and gets its response.
func TestMiddlewareNilServiceIsPassThrough(t *testing.T) {
	t.Parallel()

	r := newEngine(t, nil)

	userID := uuid.New()
	req := httptest.NewRequest(http.MethodGet, "/probe", nil)
	req.Header.Set(userid.Header, userID.String())
	req.RemoteAddr = "203.0.113.5:1000"
	rec := httptest.NewRecorder()

	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("status with nil service: want 200, got %d", rec.Code)
	}
	// A 200 alone does not prove pass-through. ResponseRecorder's Code starts
	// at 200, so an aborted chain that wrote nothing would also report it.
	// The handler's body is the real evidence.
	if body := rec.Body.String(); body != "ok" {
		t.Fatalf("body with nil service: want %q, got %q", "ok", body)
	}
}
|