package httpcommon_test

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"galaxy/lobby/internal/api/httpcommon"
	"galaxy/lobby/internal/logging"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// serveWithRequestID sends a GET /foo request through httpcommon.RequestID
// wrapping a handler that records the request ID it finds in its context.
// When incoming is non-empty it is sent as the httpcommon.RequestIDHeader
// value; when empty the header is left unset. It returns the request ID the
// wrapped handler observed (via logging.RequestIDFromContext) and the
// recorder holding the response.
func serveWithRequestID(t *testing.T, incoming string) (string, *httptest.ResponseRecorder) {
	t.Helper()

	var observed string
	handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		observed = logging.RequestIDFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	request := httptest.NewRequest(http.MethodGet, "/foo", nil)
	if incoming != "" {
		request.Header.Set(httpcommon.RequestIDHeader, incoming)
	}
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, request)

	return observed, recorder
}

// TestRequestIDPropagatesIncomingHeader checks that a well-formed incoming
// request ID is exposed to the handler's context and echoed in the response.
func TestRequestIDPropagatesIncomingHeader(t *testing.T) {
	t.Parallel()

	observed, recorder := serveWithRequestID(t, "rid-test-1")

	assert.Equal(t, "rid-test-1", observed)
	assert.Equal(t, "rid-test-1", recorder.Header().Get(httpcommon.RequestIDHeader))
}

// TestRequestIDGeneratesWhenMissing checks that a "rid-"-prefixed ID is
// generated when the request carries no ID, and that the same ID is echoed
// in the response header.
func TestRequestIDGeneratesWhenMissing(t *testing.T) {
	t.Parallel()

	observed, recorder := serveWithRequestID(t, "")

	require.NotEmpty(t, observed)
	assert.True(t, strings.HasPrefix(observed, "rid-"), "got %q", observed)
	assert.Equal(t, observed, recorder.Header().Get(httpcommon.RequestIDHeader))
}

// TestRequestIDRejectsControlCharacters checks that an incoming ID containing
// a control character is replaced by a generated "rid-" ID.
func TestRequestIDRejectsControlCharacters(t *testing.T) {
	t.Parallel()

	const rejected = "bad\x00id"
	observed, recorder := serveWithRequestID(t, rejected)

	require.NotEqual(t, rejected, observed)
	assert.True(t, strings.HasPrefix(observed, "rid-"), "got %q", observed)
	// The rejected value must not be reflected back to the client either.
	assert.NotEqual(t, rejected, recorder.Header().Get(httpcommon.RequestIDHeader))
}

// TestRequestIDRejectsOverlongValues checks that an overlong incoming ID
// (200 bytes here) is replaced by a generated "rid-" ID.
func TestRequestIDRejectsOverlongValues(t *testing.T) {
	t.Parallel()

	rejected := strings.Repeat("a", 200)
	observed, recorder := serveWithRequestID(t, rejected)

	require.NotEqual(t, rejected, observed)
	assert.True(t, strings.HasPrefix(observed, "rid-"), "got %q", observed)
	// The rejected value must not be reflected back to the client either.
	assert.NotEqual(t, rejected, recorder.Header().Get(httpcommon.RequestIDHeader))
}