feat: backend service
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
// Package geocounter exposes the gin middleware that records
|
||||
// `(user_id, country)` counters for every authenticated user-surface
|
||||
// request. The middleware sits one layer below `userid.Middleware` in
|
||||
// the route chain: it relies on the parsed user id already being on
|
||||
// the request context.
|
||||
//
|
||||
// The middleware never blocks: the underlying counter implementation
|
||||
// looks up the country synchronously (mmap read) and dispatches the
|
||||
// database upsert to a fire-and-forget goroutine. Errors from the
|
||||
// asynchronous path are logged inside the geo service, never surfaced
|
||||
// to the response.
|
||||
package geocounter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"galaxy/backend/internal/server/clientip"
|
||||
"galaxy/backend/internal/server/middleware/userid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Service is the narrow contract the middleware needs from the geo
|
||||
// package. It is satisfied by `*geo.Service` directly; tests inject a
|
||||
// recording stub. A nil Service is allowed and disables the
|
||||
// middleware's side effect.
|
||||
type Service interface {
|
||||
IncrementCounterAsync(ctx context.Context, userID uuid.UUID, sourceIP string)
|
||||
}
|
||||
|
||||
// Middleware returns a gin handler that, after the wrapped handler
|
||||
// chain has run, dispatches an `IncrementCounterAsync` call for the
|
||||
// authenticated user and the originating IP. svc may be nil, in which
|
||||
// case the middleware is a no-op pass-through.
|
||||
//
|
||||
// The middleware reads the user id from the request context populated
|
||||
// by `userid.Middleware`; routes that mount this middleware without
|
||||
// `userid.Middleware` ahead of it will silently skip the increment
|
||||
// because the user id is absent.
|
||||
func Middleware(svc Service) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
if svc == nil {
|
||||
return
|
||||
}
|
||||
userID, ok := userid.FromContext(c.Request.Context())
|
||||
if !ok || userID == uuid.Nil {
|
||||
return
|
||||
}
|
||||
ip := clientip.ExtractSourceIP(c)
|
||||
if ip == "" {
|
||||
return
|
||||
}
|
||||
svc.IncrementCounterAsync(c.Request.Context(), userID, ip)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
package geocounter_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"galaxy/backend/internal/server/middleware/geocounter"
|
||||
"galaxy/backend/internal/server/middleware/userid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type recordingSvc struct {
|
||||
mu sync.Mutex
|
||||
calls []recordedCall
|
||||
}
|
||||
|
||||
type recordedCall struct {
|
||||
UserID uuid.UUID
|
||||
SourceIP string
|
||||
}
|
||||
|
||||
func (r *recordingSvc) IncrementCounterAsync(_ context.Context, userID uuid.UUID, sourceIP string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.calls = append(r.calls, recordedCall{UserID: userID, SourceIP: sourceIP})
|
||||
}
|
||||
|
||||
func (r *recordingSvc) snapshot() []recordedCall {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([]recordedCall, len(r.calls))
|
||||
copy(out, r.calls)
|
||||
return out
|
||||
}
|
||||
|
||||
func newEngine(t *testing.T, svc geocounter.Service) *gin.Engine {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(userid.Middleware())
|
||||
r.Use(geocounter.Middleware(svc))
|
||||
r.GET("/probe", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestMiddlewareInvokesIncrementOnAuthenticatedRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &recordingSvc{}
|
||||
r := newEngine(t, svc)
|
||||
|
||||
userID := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/probe", nil)
|
||||
req.Header.Set(userid.Header, userID.String())
|
||||
req.Header.Set("X-Forwarded-For", "203.0.113.5")
|
||||
req.RemoteAddr = "10.0.0.1:1000"
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status: want 200, got %d", rec.Code)
|
||||
}
|
||||
calls := svc.snapshot()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls: want 1, got %+v", calls)
|
||||
}
|
||||
if calls[0].UserID != userID {
|
||||
t.Errorf("user id: want %s, got %s", userID, calls[0].UserID)
|
||||
}
|
||||
if calls[0].SourceIP != "203.0.113.5" {
|
||||
t.Errorf("source ip: want 203.0.113.5, got %q", calls[0].SourceIP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareFallsBackToRemoteAddr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &recordingSvc{}
|
||||
r := newEngine(t, svc)
|
||||
|
||||
userID := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/probe", nil)
|
||||
req.Header.Set(userid.Header, userID.String())
|
||||
req.RemoteAddr = "198.51.100.7:60000"
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
calls := svc.snapshot()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls: want 1, got %+v", calls)
|
||||
}
|
||||
if calls[0].SourceIP != "198.51.100.7" {
|
||||
t.Errorf("source ip: want 198.51.100.7, got %q", calls[0].SourceIP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareSkipsWhenNoSourceIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &recordingSvc{}
|
||||
r := newEngine(t, svc)
|
||||
|
||||
userID := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/probe", nil)
|
||||
req.Header.Set(userid.Header, userID.String())
|
||||
req.RemoteAddr = ""
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if calls := svc.snapshot(); len(calls) != 0 {
|
||||
t.Fatalf("calls: want 0, got %+v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareSkipsWithoutUserContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &recordingSvc{}
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
// No userid.Middleware on this chain.
|
||||
r.Use(geocounter.Middleware(svc))
|
||||
r.GET("/probe", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/probe", nil)
|
||||
req.RemoteAddr = "203.0.113.5:1000"
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if calls := svc.snapshot(); len(calls) != 0 {
|
||||
t.Fatalf("calls: want 0, got %+v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareNilServiceIsPassThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newEngine(t, nil)
|
||||
|
||||
userID := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/probe", nil)
|
||||
req.Header.Set(userid.Header, userID.String())
|
||||
req.RemoteAddr = "203.0.113.5:1000"
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status with nil service: want 200, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user