Files
galaxy-game/lobby/internal/api/httpcommon/requestid_test.go
T
2026-04-25 23:20:55 +02:00

89 lines
2.6 KiB
Go

package httpcommon_test
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"galaxy/lobby/internal/api/httpcommon"
"galaxy/lobby/internal/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRequestIDPropagatesIncomingHeader(t *testing.T) {
t.Parallel()
var observed string
handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
observed = logging.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
}))
request := httptest.NewRequest(http.MethodGet, "/foo", nil)
request.Header.Set(httpcommon.RequestIDHeader, "rid-test-1")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, request)
assert.Equal(t, "rid-test-1", observed)
assert.Equal(t, "rid-test-1", recorder.Header().Get(httpcommon.RequestIDHeader))
}
func TestRequestIDGeneratesWhenMissing(t *testing.T) {
t.Parallel()
var observed string
handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
observed = logging.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
}))
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/foo", nil))
require.NotEmpty(t, observed)
assert.True(t, strings.HasPrefix(observed, "rid-"), "got %q", observed)
assert.Equal(t, observed, recorder.Header().Get(httpcommon.RequestIDHeader))
}
// TestRequestIDRejectsControlCharacters checks that an incoming request ID
// containing a NUL byte is not handed to the wrapped handler: the ID read from
// the handler's context must differ from the supplied value and must start
// with "rid-".
func TestRequestIDRejectsControlCharacters(t *testing.T) {
	t.Parallel()

	// Malformed ID with an embedded NUL byte.
	const malformedID = "bad\x00id"

	var observed string
	handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		observed = logging.RequestIDFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	request := httptest.NewRequest(http.MethodGet, "/foo", nil)
	request.Header.Set(httpcommon.RequestIDHeader, malformedID)
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, request)

	require.NotEqual(t, malformedID, observed)
	// Print the offending value on failure; %q keeps control bytes readable.
	assert.True(t, strings.HasPrefix(observed, "rid-"), "got %q", observed)
}
// TestRequestIDRejectsOverlongValues checks that a 200-character incoming
// request ID is not handed to the wrapped handler: the ID read from the
// handler's context must differ from the supplied value and must start with
// "rid-".
func TestRequestIDRejectsOverlongValues(t *testing.T) {
	t.Parallel()

	// Build the oversized ID once so the header and the assertion share it.
	overlongID := strings.Repeat("a", 200)

	var observed string
	handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		observed = logging.RequestIDFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	request := httptest.NewRequest(http.MethodGet, "/foo", nil)
	request.Header.Set(httpcommon.RequestIDHeader, overlongID)
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, request)

	require.NotEqual(t, overlongID, observed)
	// Print the offending value on failure instead of a bare "Should be true".
	assert.True(t, strings.HasPrefix(observed, "rid-"), "got %q", observed)
}