89 lines
2.6 KiB
Go
89 lines
2.6 KiB
Go
package httpcommon_test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"galaxy/lobby/internal/api/httpcommon"
|
|
"galaxy/lobby/internal/logging"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestRequestIDPropagatesIncomingHeader(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var observed string
|
|
handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
observed = logging.RequestIDFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
request := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
|
request.Header.Set(httpcommon.RequestIDHeader, "rid-test-1")
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
assert.Equal(t, "rid-test-1", observed)
|
|
assert.Equal(t, "rid-test-1", recorder.Header().Get(httpcommon.RequestIDHeader))
|
|
}
|
|
|
|
func TestRequestIDGeneratesWhenMissing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var observed string
|
|
handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
observed = logging.RequestIDFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
recorder := httptest.NewRecorder()
|
|
handler.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/foo", nil))
|
|
|
|
require.NotEmpty(t, observed)
|
|
assert.True(t, strings.HasPrefix(observed, "rid-"), "got %q", observed)
|
|
assert.Equal(t, observed, recorder.Header().Get(httpcommon.RequestIDHeader))
|
|
}
|
|
|
|
func TestRequestIDRejectsControlCharacters(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var observed string
|
|
handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
observed = logging.RequestIDFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
request := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
|
request.Header.Set(httpcommon.RequestIDHeader, "bad\x00id")
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
require.NotEqual(t, "bad\x00id", observed)
|
|
assert.True(t, strings.HasPrefix(observed, "rid-"))
|
|
}
|
|
|
|
func TestRequestIDRejectsOverlongValues(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var observed string
|
|
handler := httpcommon.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
observed = logging.RequestIDFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
request := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
|
request.Header.Set(httpcommon.RequestIDHeader, strings.Repeat("a", 200))
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
require.NotEqual(t, strings.Repeat("a", 200), observed)
|
|
assert.True(t, strings.HasPrefix(observed, "rid-"))
|
|
}
|