package server

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)

// TestRequireUserID checks that the middleware accepts a valid X-User-ID,
// exposes it through the request context, and rejects missing or malformed
// headers with 401 without invoking the downstream handler.
func TestRequireUserID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// State captured by the handler. It is reset at the start of every
	// subtest so one case cannot observe values left over from another.
	var (
		reached bool
		seen    uuid.UUID
		ok      bool
	)
	reset := func() { reached, seen, ok = false, uuid.Nil, false }

	r := gin.New()
	r.Use(RequireUserID())
	r.GET("/x", func(c *gin.Context) {
		reached = true
		seen, ok = UserIDFromContext(c.Request.Context())
		c.String(http.StatusOK, "ok")
	})

	t.Run("valid", func(t *testing.T) {
		reset()
		id := uuid.New()
		req := httptest.NewRequest(http.MethodGet, "/x", nil)
		req.Header.Set("X-User-ID", id.String())
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)

		if rec.Code != http.StatusOK {
			t.Fatalf("status = %d, want 200", rec.Code)
		}
		if !ok || seen != id {
			t.Fatalf("context id = %s (ok=%v), want %s", seen, ok, id)
		}
	})

	// Rejected requests must get a 401 and must not reach the handler.
	// Checking only the status would miss a middleware that writes 401 but
	// forgets to abort the chain.
	for _, tc := range []struct {
		name  string
		value string // X-User-ID value; empty means the header is not set
	}{
		{name: "missing"},
		{name: "malformed", value: "not-a-uuid"},
	} {
		t.Run(tc.name, func(t *testing.T) {
			reset()
			req := httptest.NewRequest(http.MethodGet, "/x", nil)
			if tc.value != "" {
				req.Header.Set("X-User-ID", tc.value)
			}
			rec := httptest.NewRecorder()
			r.ServeHTTP(rec, req)

			if rec.Code != http.StatusUnauthorized {
				t.Fatalf("status = %d, want 401", rec.Code)
			}
			if reached {
				t.Fatal("handler ran for a rejected request; middleware must abort the chain")
			}
		})
	}
}