59 lines
1.9 KiB
Go
59 lines
1.9 KiB
Go
// Package geocounter exposes the gin middleware that records
// `(user_id, country)` counters for every authenticated user-surface
// request. The middleware must run after `userid.Middleware` in the
// route chain: it relies on the parsed user id already being on the
// request context.
//
// The middleware never blocks: the underlying counter implementation
// looks up the country synchronously (mmap read) and dispatches the
// database upsert to a fire-and-forget goroutine. Errors from the
// asynchronous path are logged inside the geo service and are never
// surfaced to the response.
package geocounter
|
|
|
|
import (
|
|
"context"
|
|
|
|
"galaxy/backend/internal/server/clientip"
|
|
"galaxy/backend/internal/server/middleware/userid"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Service is the narrow contract the middleware needs from the geo
|
|
// package. It is satisfied by `*geo.Service` directly; tests inject a
|
|
// recording stub. A nil Service is allowed and disables the
|
|
// middleware's side effect.
|
|
type Service interface {
|
|
IncrementCounterAsync(ctx context.Context, userID uuid.UUID, sourceIP string)
|
|
}
|
|
|
|
// Middleware returns a gin handler that, once the wrapped handler chain
// has run, dispatches an `IncrementCounterAsync` call for the
// authenticated user and the originating IP. svc may be nil, in which
// case the returned handler is a plain pass-through that only calls
// c.Next(). A typed nil pointer stored in svc is not a nil interface
// and does not take the pass-through path.
//
// The user id is read from the request context populated by
// `userid.Middleware`; routes that mount this middleware without
// `userid.Middleware` ahead of it silently skip the increment because
// the user id is absent. Requests whose user id is uuid.Nil or whose
// source IP cannot be extracted are skipped as well.
//
// The context passed to IncrementCounterAsync keeps the request's
// values but is detached from its cancellation. net/http cancels the
// request context as soon as ServeHTTP returns, which happens right
// after this handler finishes. An asynchronous upsert bound to that
// context could therefore be cancelled before it ever runs.
func Middleware(svc Service) gin.HandlerFunc {
	if svc == nil {
		// Decide once at construction time instead of on every request.
		return func(c *gin.Context) {
			c.Next()
		}
	}

	return func(c *gin.Context) {
		c.Next()

		ctx := c.Request.Context()
		userID, ok := userid.FromContext(ctx)
		if !ok || userID == uuid.Nil {
			return
		}
		ip := clientip.ExtractSourceIP(c)
		if ip == "" {
			return
		}
		// The call is fire-and-forget and outlives the request, so drop
		// the request's cancellation/deadline (Go 1.21+) while keeping
		// its values (e.g. tracing/logging metadata).
		svc.IncrementCounterAsync(context.WithoutCancel(ctx), userID, ip)
	}
}
|