package publichttp

import (
	"context"
	"encoding/json"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestConfigValidate checks that a fully populated Config passes Validate and
// that clearing each required field on its own yields an error whose message
// mentions that field.
func TestConfigValidate(t *testing.T) {
	t.Parallel()

	base := Config{
		Addr:              ":0",
		ReadHeaderTimeout: time.Second,
		ReadTimeout:       time.Second,
		IdleTimeout:       time.Second,
	}
	require.NoError(t, base.Validate())

	tests := []struct {
		name    string
		mutate  func(*Config)
		wantErr string
	}{
		{name: "empty addr", mutate: func(cfg *Config) { cfg.Addr = "" }, wantErr: "addr must not be empty"},
		{name: "zero header", mutate: func(cfg *Config) { cfg.ReadHeaderTimeout = 0 }, wantErr: "read header timeout"},
		{name: "zero read", mutate: func(cfg *Config) { cfg.ReadTimeout = 0 }, wantErr: "read timeout"},
		{name: "zero idle", mutate: func(cfg *Config) { cfg.IdleTimeout = 0 }, wantErr: "idle timeout"},
	}

	for _, tt := range tests {
		tt := tt // capture the range variable for the parallel subtest (required before Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Work on a copy so one subtest's mutation never leaks into base
			// or into a sibling subtest.
			cfg := base
			tt.mutate(&cfg)

			err := cfg.Validate()
			require.Error(t, err)
			require.Contains(t, err.Error(), tt.wantErr)
		})
	}
}

// TestHandlerRoutes drives the handler returned by newHandler through a real
// HTTP round trip. The health and readiness endpoints must answer 200 with a
// JSON status payload. An unknown path must answer 404, and a POST to the
// health endpoint must answer 405.
func TestHandlerRoutes(t *testing.T) {
	t.Parallel()

	handler := newHandler(Dependencies{}, nil)
	server := httptest.NewServer(handler)
	// Parent cleanups run only after all parallel subtests finish, so the
	// server outlives every request below.
	t.Cleanup(server.Close)

	tests := []struct {
		name           string
		method         string
		path           string
		wantStatus     int
		wantStatusBody string
	}{
		{name: "healthz", method: http.MethodGet, path: HealthzPath, wantStatus: http.StatusOK, wantStatusBody: "ok"},
		{name: "readyz", method: http.MethodGet, path: ReadyzPath, wantStatus: http.StatusOK, wantStatusBody: "ready"},
		{name: "not found", method: http.MethodGet, path: "/nope", wantStatus: http.StatusNotFound},
		{name: "method not allowed", method: http.MethodPost, path: HealthzPath, wantStatus: http.StatusMethodNotAllowed},
	}

	for _, tt := range tests {
		tt := tt // capture the range variable for the parallel subtest (required before Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// The httptest client has no timeout of its own. Bound the request
			// so a stuck handler fails this subtest instead of hanging the run.
			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			defer cancel()

			req, err := http.NewRequestWithContext(ctx, tt.method, server.URL+tt.path, nil)
			require.NoError(t, err)

			resp, err := server.Client().Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()

			assert.Equal(t, tt.wantStatus, resp.StatusCode)
			if tt.wantStatusBody == "" {
				// Error routes: only the status code is part of the contract here.
				return
			}

			body, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))

			var payload statusResponse
			require.NoError(t, json.Unmarshal(body, &payload))
			assert.Equal(t, tt.wantStatusBody, payload.Status)
		})
	}
}

// TestShutdownBeforeRunIsNoop checks that calling Shutdown on a server that
// was never started returns no error.
func TestShutdownBeforeRunIsNoop(t *testing.T) {
	t.Parallel()

	server, err := NewServer(Config{
		Addr:              "127.0.0.1:0",
		ReadHeaderTimeout: time.Second,
		ReadTimeout:       time.Second,
		IdleTimeout:       time.Second,
	}, Dependencies{})
	require.NoError(t, err)

	require.NoError(t, server.Shutdown(context.Background()))
}

// TestServerRunAndShutdown starts the server on a free loopback port and checks
// that it serves the health endpoint. It then verifies that Shutdown causes Run
// to return with a nil error within two seconds.
func TestServerRunAndShutdown(t *testing.T) {
	// Reserve a free loopback port, then release it so the server can bind it.
	// Another process could take the port in the gap between Close and Run.
	// That is the accepted trade-off of handing the server an address string
	// rather than a listener.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := listener.Addr().String()
	require.NoError(t, listener.Close())

	server, err := NewServer(Config{
		Addr:              addr,
		ReadHeaderTimeout: time.Second,
		ReadTimeout:       time.Second,
		IdleTimeout:       time.Second,
	}, Dependencies{})
	require.NoError(t, err)

	ctx, cancel := context.WithCancel(context.Background())
	// Cancels Run's context if the test bails out before reaching Shutdown.
	t.Cleanup(cancel)

	// Buffered so the goroutine can always deliver Run's result and exit,
	// even if the test has already stopped listening on the channel.
	runErr := make(chan error, 1)
	go func() { runErr <- server.Run(ctx) }()

	require.Eventually(t, func() bool { return server.Addr() != "" },
		2*time.Second, 10*time.Millisecond, "server never reported a listen address")

	// http.DefaultClient (used by http.Get) has no timeout. Use a bounded
	// client so an unresponsive server fails the test instead of hanging it.
	client := &http.Client{Timeout: 2 * time.Second}
	resp, err := client.Get("http://" + server.Addr() + HealthzPath)
	require.NoError(t, err)
	// The body is not inspected. A close error would not affect what this
	// test verifies, so it is deliberately ignored.
	_ = resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)

	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 2*time.Second)
	t.Cleanup(shutdownCancel)
	require.NoError(t, server.Shutdown(shutdownCtx))

	select {
	case err := <-runErr:
		require.NoError(t, err)
	case <-time.After(2 * time.Second):
		t.Fatal("server did not stop after shutdown")
	}
}