package http

import (
	"context"
	"errors"
	"io"
	"net"
	stdhttp "net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"strings"
	"testing"
	"time"

	"galaxy/connector"
	"galaxy/model/report"
)

type (
	// checkConnectionCase is one row of the CheckConnection table test.
	checkConnectionCase struct {
		// name is the subtest name handed to t.Run.
		name string
		// setup builds the connector under test and the channel reporting
		// request paths; some cases return a nil channel.
		setup func(t *testing.T) (*httpConnector, <-chan string)
		// want is the expected CheckConnection result.
		want bool
		// wantPath is the expected request path; empty skips the path check.
		wantPath string
	}

	// checkVersionCase is one row of the CheckVersion table test.
	checkVersionCase struct {
		// name is the subtest name handed to t.Run.
		name string
		// setup builds the connector under test and the channel reporting
		// request paths; some cases return a nil channel.
		setup func(t *testing.T) (*httpConnector, <-chan string)
		// want is the expected versions list when no error is expected.
		want []connector.VersionInfo
		// wantErr tells whether CheckVersion must return a non-nil error.
		wantErr bool
		// wantPath is the expected request path; empty skips the path check.
		wantPath string
	}

	// downloadVersionCase is one row of the DownloadVersion table test.
	downloadVersionCase struct {
		// name is the subtest name handed to t.Run.
		name string
		// setup returns the connector, the argument for DownloadVersion, the
		// channel reporting request paths and a final verification hook.
		setup func(t *testing.T) (*httpConnector, string, <-chan string, func(t *testing.T))
		// want is the expected payload when no error is expected.
		want []byte
		// wantErr tells whether DownloadVersion must return a non-nil error.
		wantErr bool
		// wantPath is the expected request path.
		wantPath string
	}

	// fetchReportCase is one row of the FetchReport table test.
	fetchReportCase struct {
		// name is the subtest name handed to t.Run.
		name string
		// setup builds the connector under test and the channel reporting
		// received requests.
		setup func(t *testing.T) (*httpConnector, <-chan requestDetails)
		// turn is the turn number passed to FetchReport.
		turn uint
		// want is the expected report when no error is expected.
		want report.Report
		// wantErr tells whether the callback must receive a non-nil error.
		wantErr bool
		// wantPath is the expected request path.
		wantPath string
		// wantPlayer is the expected "player" query value.
		wantPlayer string
		// wantTurn is the expected "turn" query value.
		wantTurn string
		// wantRequest tells whether a request is expected to reach the server.
		wantRequest bool
	}

	// fetchReportResult holds the arguments of one FetchReport callback call.
	fetchReportResult struct {
		report report.Report
		err    error
	}

	// requestDetails holds the path and query parameters of one received request.
	requestDetails struct {
		path  string
		query url.Values
	}
)

// TestCheckConnection verifies backend reachability probe behavior.
func TestCheckConnection(t *testing.T) { tests := []checkConnectionCase{ { name: "status 200 returns true", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newServerConnector(t, context.Background(), stdhttp.StatusOK, "") }, want: true, wantPath: "/api/v1/status", }, { name: "non-2xx status returns true", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newServerConnector(t, context.Background(), stdhttp.StatusServiceUnavailable, "") }, want: true, wantPath: "/api/v1/status", }, { name: "canceled context returns false", setup: func(t *testing.T) (*httpConnector, <-chan string) { ctx, cancel := context.WithCancel(context.Background()) cancel() conn, err := NewHttpConnector(ctx, "http://127.0.0.1") if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } return conn, nil }, want: false, }, { name: "transport failure returns false", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newUnreachableConnector(t, context.Background()), nil }, want: false, }, { name: "backend path prefix is preserved", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newServerConnector(t, context.Background(), stdhttp.StatusOK, "/base") }, want: true, wantPath: "/base/api/v1/status", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { conn, pathCh := tt.setup(t) got := conn.CheckConnection() if got != tt.want { t.Fatalf("CheckConnection() = %v, want %v", got, tt.want) } if tt.wantPath == "" { return } select { case gotPath := <-pathCh: if gotPath != tt.wantPath { t.Fatalf("request path = %q, want %q", gotPath, tt.wantPath) } default: t.Fatalf("expected request path %q, got no request", tt.wantPath) } }) } } // TestCheckVersion verifies versions retrieval behavior. 
func TestCheckVersion(t *testing.T) { tests := []checkVersionCase{ { name: "status 200 with valid body returns versions", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newVersionServerConnector( t, context.Background(), stdhttp.StatusOK, `[{"os":"darwin","version":"1.2.3","url":"https://example.com/darwin"}]`, "", ) }, want: []connector.VersionInfo{ {OS: "darwin", Version: "1.2.3", URL: "https://example.com/darwin"}, }, wantPath: "/api/v1/versions", }, { name: "status 200 with invalid json returns error", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newVersionServerConnector( t, context.Background(), stdhttp.StatusOK, `{"versions":`, "", ) }, wantErr: true, wantPath: "/api/v1/versions", }, { name: "non-200 status returns error", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newVersionServerConnector( t, context.Background(), stdhttp.StatusServiceUnavailable, `[]`, "", ) }, wantErr: true, wantPath: "/api/v1/versions", }, { name: "canceled context returns error", setup: func(t *testing.T) (*httpConnector, <-chan string) { ctx, cancel := context.WithCancel(context.Background()) cancel() conn, err := NewHttpConnector(ctx, "http://127.0.0.1") if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } return conn, nil }, wantErr: true, }, { name: "transport failure returns error", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newUnreachableConnector(t, context.Background()), nil }, wantErr: true, }, { name: "backend path prefix is preserved", setup: func(t *testing.T) (*httpConnector, <-chan string) { return newVersionServerConnector( t, context.Background(), stdhttp.StatusOK, `[{"os":"linux","version":"2.0.0","url":"https://example.com/linux"}]`, "/base", ) }, want: []connector.VersionInfo{ {OS: "linux", Version: "2.0.0", URL: "https://example.com/linux"}, }, wantPath: "/base/api/v1/versions", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { conn, pathCh := 
tt.setup(t) got, err := conn.CheckVersion() if tt.wantErr { if err == nil { t.Fatal("CheckVersion() error = nil, want non-nil") } } else { if err != nil { t.Fatalf("CheckVersion() error = %v, want nil", err) } if !reflect.DeepEqual(got, tt.want) { t.Fatalf("CheckVersion() = %#v, want %#v", got, tt.want) } } if tt.wantPath == "" { return } select { case gotPath := <-pathCh: if gotPath != tt.wantPath { t.Fatalf("request path = %q, want %q", gotPath, tt.wantPath) } default: t.Fatalf("expected request path %q, got no request", tt.wantPath) } }) } } // TestFetchReport verifies asynchronous report retrieval behavior. func TestFetchReport(t *testing.T) { tests := []fetchReportCase{ { name: "status 200 with valid body returns report", setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { return newFetchReportServerConnector( t, context.Background(), stdhttp.StatusOK, `{"version":2,"turn":7,"mapWidth":120,"mapHeight":80,"race":"Race_01","votes":1.5}`, "", ) }, turn: 7, want: report.Report{ Version: 2, Turn: 7, Width: 120, Height: 80, Race: "Race_01", Votes: report.Float(1.5), }, wantPath: "/api/v1/report", wantPlayer: fetchReportPlayer, wantTurn: "7", wantRequest: true, }, { name: "status 200 with invalid json returns error", setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { return newFetchReportServerConnector( t, context.Background(), stdhttp.StatusOK, `{"turn":`, "", ) }, turn: 8, wantErr: true, wantPath: "/api/v1/report", wantPlayer: fetchReportPlayer, wantTurn: "8", wantRequest: true, }, { name: "non-200 status returns error", setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { return newFetchReportServerConnector( t, context.Background(), stdhttp.StatusBadGateway, `{}`, "", ) }, turn: 9, wantErr: true, wantPath: "/api/v1/report", wantPlayer: fetchReportPlayer, wantTurn: "9", wantRequest: true, }, { name: "canceled context returns error", setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { ctx, cancel := 
context.WithCancel(context.Background()) cancel() return newFetchReportServerConnector( t, ctx, stdhttp.StatusOK, `{"turn":1}`, "", ) }, turn: 10, wantErr: true, wantRequest: false, }, { name: "backend path prefix is preserved", setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { return newFetchReportServerConnector( t, context.Background(), stdhttp.StatusOK, `{"turn":11,"mapWidth":20,"mapHeight":30}`, "/base", ) }, turn: 11, want: report.Report{ Turn: 11, Width: 20, Height: 30, }, wantPath: "/base/api/v1/report", wantPlayer: fetchReportPlayer, wantTurn: "11", wantRequest: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { conn, requestCh := tt.setup(t) resultCh := make(chan fetchReportResult, 1) conn.FetchReport("", tt.turn, func(rep report.Report, err error) { resultCh <- fetchReportResult{report: rep, err: err} }) select { case result := <-resultCh: if tt.wantErr { if result.err == nil { t.Fatal("FetchReport() error = nil, want non-nil") } } else { if result.err != nil { t.Fatalf("FetchReport() error = %v, want nil", result.err) } if !reflect.DeepEqual(result.report, tt.want) { t.Fatalf("FetchReport() report = %#v, want %#v", result.report, tt.want) } } case <-time.After(time.Second): t.Fatal("FetchReport() callback was not called") } if !tt.wantRequest { select { case req := <-requestCh: t.Fatalf("unexpected request = %#v", req) case <-time.After(100 * time.Millisecond): } return } select { case req := <-requestCh: if req.path != tt.wantPath { t.Fatalf("request path = %q, want %q", req.path, tt.wantPath) } if got := req.query.Get("player"); got != tt.wantPlayer { t.Fatalf("request player = %q, want %q", got, tt.wantPlayer) } if got := req.query.Get("turn"); got != tt.wantTurn { t.Fatalf("request turn = %q, want %q", got, tt.wantTurn) } case <-time.After(time.Second): t.Fatal("expected request, got none") } }) } } // TestFetchReportAsync verifies FetchReport returns immediately and calls callback once after response is 
ready. func TestFetchReportAsync(t *testing.T) { requestCh := make(chan requestDetails, 1) releaseResponse := make(chan struct{}) server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { requestCh <- requestDetails{path: r.URL.Path, query: r.URL.Query()} <-releaseResponse w.WriteHeader(stdhttp.StatusOK) _, _ = w.Write([]byte(`{"turn":12,"mapWidth":1,"mapHeight":1}`)) })) t.Cleanup(server.Close) conn, err := NewHttpConnector(context.Background(), server.URL) if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } resultCh := make(chan fetchReportResult, 2) start := time.Now() conn.FetchReport("", 12, func(rep report.Report, err error) { resultCh <- fetchReportResult{report: rep, err: err} }) if elapsed := time.Since(start); elapsed > 50*time.Millisecond { t.Fatalf("FetchReport() elapsed = %v, want quick return", elapsed) } select { case req := <-requestCh: if req.path != "/api/v1/report" { t.Fatalf("request path = %q, want %q", req.path, "/api/v1/report") } if got := req.query.Get("player"); got != fetchReportPlayer { t.Fatalf("request player = %q, want %q", got, fetchReportPlayer) } if got := req.query.Get("turn"); got != "12" { t.Fatalf("request turn = %q, want %q", got, "12") } case <-time.After(time.Second): t.Fatal("expected request, got none") } select { case result := <-resultCh: t.Fatalf("unexpected early callback = %#v", result) case <-time.After(100 * time.Millisecond): } close(releaseResponse) select { case result := <-resultCh: if result.err != nil { t.Fatalf("FetchReport() error = %v, want nil", result.err) } if result.report.Turn != 12 { t.Fatalf("FetchReport() report turn = %d, want %d", result.report.Turn, 12) } case <-time.After(time.Second): t.Fatal("FetchReport() callback was not called") } select { case extra := <-resultCh: t.Fatalf("FetchReport() callback called more than once: %#v", extra) case <-time.After(100 * time.Millisecond): } } // TestDownloadVersion verifies artifact download 
behavior for relative and absolute URLs. func TestDownloadVersion(t *testing.T) { tests := []downloadVersionCase{ { name: "relative path uses backend URL", setup: func(t *testing.T) (*httpConnector, string, <-chan string, func(t *testing.T)) { pathCh := make(chan string, 1) server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { pathCh <- r.URL.Path w.WriteHeader(stdhttp.StatusOK) _, _ = w.Write([]byte("artifact")) })) t.Cleanup(server.Close) conn, err := NewHttpConnector(context.Background(), server.URL+"/base") if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } return conn, "downloads/client.bin", pathCh, func(t *testing.T) {} }, want: []byte("artifact"), wantPath: "/base/downloads/client.bin", }, { name: "fully qualified URL bypasses backend URL", setup: func(t *testing.T) (*httpConnector, string, <-chan string, func(t *testing.T)) { backendPathCh := make(chan string, 1) backendServer := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { backendPathCh <- r.URL.Path w.WriteHeader(stdhttp.StatusInternalServerError) })) t.Cleanup(backendServer.Close) downloadPathCh := make(chan string, 1) downloadServer := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { downloadPathCh <- r.URL.Path w.WriteHeader(stdhttp.StatusOK) _, _ = w.Write([]byte("external-artifact")) })) t.Cleanup(downloadServer.Close) conn, err := NewHttpConnector(context.Background(), backendServer.URL+"/base") if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } return conn, downloadServer.URL + "/artifact.bin", downloadPathCh, func(t *testing.T) { t.Helper() select { case gotPath := <-backendPathCh: t.Fatalf("unexpected backend request path = %q", gotPath) default: } } }, want: []byte("external-artifact"), wantPath: "/artifact.bin", }, { name: "non-200 status returns error", setup: func(t *testing.T) (*httpConnector, string, <-chan string, func(t 
*testing.T)) { pathCh := make(chan string, 1) server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { pathCh <- r.URL.Path w.WriteHeader(stdhttp.StatusBadGateway) })) t.Cleanup(server.Close) conn, err := NewHttpConnector(context.Background(), server.URL) if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } return conn, "downloads/client.bin", pathCh, func(t *testing.T) {} }, wantErr: true, wantPath: "/downloads/client.bin", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { conn, urlOrPath, pathCh, verify := tt.setup(t) got, err := conn.DownloadVersion(urlOrPath) if tt.wantErr { if err == nil { t.Fatal("DownloadVersion() error = nil, want non-nil") } } else { if err != nil { t.Fatalf("DownloadVersion() error = %v, want nil", err) } if !reflect.DeepEqual(got, tt.want) { t.Fatalf("DownloadVersion() = %q, want %q", got, tt.want) } } select { case gotPath := <-pathCh: if gotPath != tt.wantPath { t.Fatalf("request path = %q, want %q", gotPath, tt.wantPath) } default: t.Fatalf("expected request path %q, got no request", tt.wantPath) } verify(t) }) } } // TestDoRequestUsesPassedContext verifies request context is provided by caller. func TestDoRequestUsesPassedContext(t *testing.T) { conn, pathCh := newServerConnector(t, context.Background(), stdhttp.StatusOK, "") ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := conn.doRequest(ctx, checkConnectionPath) if err == nil { t.Fatal("doRequest() error = nil, want non-nil") } if !errors.Is(err, context.Canceled) { t.Fatalf("doRequest() error = %v, want context canceled", err) } select { case gotPath := <-pathCh: t.Fatalf("expected no request with canceled context, got %q", gotPath) default: } } // TestDoRequestMovedPermanentlyRedirectsRelative verifies 301 responses follow relative Location redirects. 
func TestDoRequestMovedPermanentlyRedirectsRelative(t *testing.T) { requestPaths := make(chan string, 2) server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { requestPaths <- r.URL.Path if r.URL.Path == "/base/api/v1/status" { w.Header().Set("Location", "/redirected") w.WriteHeader(stdhttp.StatusMovedPermanently) return } w.WriteHeader(stdhttp.StatusOK) })) t.Cleanup(server.Close) conn, err := NewHttpConnector(context.Background(), server.URL+"/base") if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } resp, err := conn.doRequest(context.Background(), checkConnectionPath) if err != nil { t.Fatalf("doRequest() error = %v, want nil", err) } defer resp.Body.Close() if resp.StatusCode != stdhttp.StatusOK { t.Fatalf("doRequest() status code = %d, want %d", resp.StatusCode, stdhttp.StatusOK) } gotFirstPath := <-requestPaths if gotFirstPath != "/base/api/v1/status" { t.Fatalf("first request path = %q, want %q", gotFirstPath, "/base/api/v1/status") } gotSecondPath := <-requestPaths if gotSecondPath != "/redirected" { t.Fatalf("redirect request path = %q, want %q", gotSecondPath, "/redirected") } } // TestDoRequestMovedPermanentlyRedirectsAbsolute verifies 301 responses follow absolute Location redirects. 
func TestDoRequestMovedPermanentlyRedirectsAbsolute(t *testing.T) { initialPaths := make(chan string, 1) redirectPaths := make(chan string, 1) redirectServer := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { redirectPaths <- r.URL.Path w.WriteHeader(stdhttp.StatusOK) })) t.Cleanup(redirectServer.Close) server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { initialPaths <- r.URL.Path w.Header().Set("Location", redirectServer.URL+"/redirected") w.WriteHeader(stdhttp.StatusMovedPermanently) })) t.Cleanup(server.Close) conn, err := NewHttpConnector(context.Background(), server.URL) if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } resp, err := conn.doRequest(context.Background(), checkConnectionPath) if err != nil { t.Fatalf("doRequest() error = %v, want nil", err) } defer resp.Body.Close() if resp.StatusCode != stdhttp.StatusOK { t.Fatalf("doRequest() status code = %d, want %d", resp.StatusCode, stdhttp.StatusOK) } select { case gotPath := <-initialPaths: if gotPath != "/api/v1/status" { t.Fatalf("initial request path = %q, want %q", gotPath, "/api/v1/status") } default: t.Fatal("expected initial request, got none") } select { case gotPath := <-redirectPaths: if gotPath != "/redirected" { t.Fatalf("redirect request path = %q, want %q", gotPath, "/redirected") } default: t.Fatal("expected redirect request, got none") } } // TestDoRequestMovedPermanentlyWithoutLocation verifies invalid 301 responses return an error. 
func TestDoRequestMovedPermanentlyWithoutLocation(t *testing.T) { conn, _ := newServerConnector(t, context.Background(), stdhttp.StatusMovedPermanently, "") _, err := conn.doRequest(context.Background(), checkConnectionPath) if err == nil { t.Fatal("doRequest() error = nil, want non-nil") } if !errors.Is(err, errMovedPermanentlyWithoutLocation) { t.Fatalf("doRequest() error = %v, want %v", err, errMovedPermanentlyWithoutLocation) } } // TestDoRequestMovedPermanentlyUsesRetryBudget verifies redirect follow-up requests consume jittered attempts. func TestDoRequestMovedPermanentlyUsesRetryBudget(t *testing.T) { attempts := 0 jitterCaps := make([]time.Duration, 0) sleepDurations := make([]time.Duration, 0) conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { attempts++ switch req.URL.String() { case "http://example.com/api/v1/status": return &stdhttp.Response{ StatusCode: stdhttp.StatusMovedPermanently, Header: stdhttp.Header{ "Location": []string{"/redirected"}, }, Body: io.NopCloser(strings.NewReader("")), }, nil case "http://example.com/redirected": return &stdhttp.Response{ StatusCode: stdhttp.StatusOK, Body: io.NopCloser(strings.NewReader("")), }, nil default: t.Fatalf("unexpected request URL = %q", req.URL.String()) return nil, nil } }) conn.jitterFn = func(cap time.Duration) time.Duration { jitterCaps = append(jitterCaps, cap) return cap } conn.sleepFn = func(ctx context.Context, d time.Duration) error { sleepDurations = append(sleepDurations, d) return nil } resp, err := conn.doRequest(context.Background(), checkConnectionPath) if err != nil { t.Fatalf("doRequest() error = %v, want nil", err) } defer resp.Body.Close() if attempts != 2 { t.Fatalf("attempts = %d, want 2", attempts) } wantCaps := []time.Duration{5 * time.Second} if !reflect.DeepEqual(jitterCaps, wantCaps) { t.Fatalf("jitter caps = %v, want %v", jitterCaps, wantCaps) } if !reflect.DeepEqual(sleepDurations, wantCaps) { t.Fatalf("sleep durations = %v, want %v", 
sleepDurations, wantCaps) } } // TestDoRequestMovedPermanentlyExhaustsRetryBudget verifies repeated redirects eventually fail. func TestDoRequestMovedPermanentlyExhaustsRetryBudget(t *testing.T) { attempts := 0 jitterCaps := make([]time.Duration, 0) sleepDurations := make([]time.Duration, 0) conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { attempts++ return &stdhttp.Response{ StatusCode: stdhttp.StatusMovedPermanently, Header: stdhttp.Header{ "Location": []string{"/redirected"}, }, Body: io.NopCloser(strings.NewReader("")), }, nil }) conn.jitterFn = func(cap time.Duration) time.Duration { jitterCaps = append(jitterCaps, cap) return cap } conn.sleepFn = func(ctx context.Context, d time.Duration) error { sleepDurations = append(sleepDurations, d) return nil } _, err := conn.doRequest(context.Background(), checkConnectionPath) if err == nil { t.Fatal("doRequest() error = nil, want non-nil") } wantCaps := append([]time.Duration(nil), defaultRetryCaps...) if attempts != len(wantCaps)+1 { t.Fatalf("attempts = %d, want %d", attempts, len(wantCaps)+1) } if !reflect.DeepEqual(jitterCaps, wantCaps) { t.Fatalf("jitter caps = %v, want %v", jitterCaps, wantCaps) } if !reflect.DeepEqual(sleepDurations, wantCaps) { t.Fatalf("sleep durations = %v, want %v", sleepDurations, wantCaps) } } // TestDoRequestResponseHeaderTimeout verifies client distinguishes response timeout. 
func TestDoRequestResponseHeaderTimeout(t *testing.T) { const ( dialTimeout = time.Second headerTimeout = 30 * time.Millisecond serverHeaderDelayTime = 150 * time.Millisecond ) server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { time.Sleep(serverHeaderDelayTime) w.WriteHeader(stdhttp.StatusOK) })) t.Cleanup(server.Close) backendURL, err := url.Parse(server.URL) if err != nil { t.Fatalf("parse backend URL error = %v", err) } conn := &httpConnector{ ctx: context.Background(), backendURL: backendURL, httpClient: newHTTPClient(dialTimeout, headerTimeout), } start := time.Now() _, err = conn.doRequest(context.Background(), checkConnectionPath) elapsed := time.Since(start) if err == nil { t.Fatal("doRequest() error = nil, want timeout") } if elapsed < headerTimeout { t.Fatalf("doRequest() elapsed = %v, want >= %v", elapsed, headerTimeout) } var netErr net.Error if !errors.As(err, &netErr) || !netErr.Timeout() { t.Fatalf("doRequest() error = %v, want timeout error", err) } } // TestDoRequestSuccessFirstAttemptNoRetry verifies successful call does not schedule retries. 
func TestDoRequestSuccessFirstAttemptNoRetry(t *testing.T) { attempts := 0 sleepCalls := 0 jitterCalls := 0 conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { attempts++ return &stdhttp.Response{ StatusCode: stdhttp.StatusOK, Body: io.NopCloser(strings.NewReader("")), }, nil }) conn.jitterFn = func(cap time.Duration) time.Duration { jitterCalls++ return cap } conn.sleepFn = func(ctx context.Context, d time.Duration) error { sleepCalls++ return nil } resp, err := conn.doRequest(context.Background(), checkConnectionPath) if err != nil { t.Fatalf("doRequest() error = %v, want nil", err) } defer resp.Body.Close() if attempts != 1 { t.Fatalf("attempts = %d, want 1", attempts) } if jitterCalls != 0 { t.Fatalf("jitter calls = %d, want 0", jitterCalls) } if sleepCalls != 0 { t.Fatalf("sleep calls = %d, want 0", sleepCalls) } } // TestDoRequestConnectTimeoutRetriesWithJitter verifies retries for connect timeout errors. func TestDoRequestConnectTimeoutRetriesWithJitter(t *testing.T) { attempts := 0 jitterCaps := make([]time.Duration, 0) sleepDurations := make([]time.Duration, 0) conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { attempts++ if attempts <= 2 { return nil, newDialTimeoutError() } return &stdhttp.Response{ StatusCode: stdhttp.StatusOK, Body: io.NopCloser(strings.NewReader("")), }, nil }) conn.jitterFn = func(cap time.Duration) time.Duration { jitterCaps = append(jitterCaps, cap) return cap } conn.sleepFn = func(ctx context.Context, d time.Duration) error { sleepDurations = append(sleepDurations, d) return nil } resp, err := conn.doRequest(context.Background(), checkConnectionPath) if err != nil { t.Fatalf("doRequest() error = %v, want nil", err) } defer resp.Body.Close() if attempts != 3 { t.Fatalf("attempts = %d, want 3", attempts) } wantCaps := []time.Duration{5 * time.Second, 15 * time.Second} if !reflect.DeepEqual(jitterCaps, wantCaps) { t.Fatalf("jitter caps = %v, want %v", 
jitterCaps, wantCaps) } if !reflect.DeepEqual(sleepDurations, wantCaps) { t.Fatalf("sleep durations = %v, want %v", sleepDurations, wantCaps) } } // TestDoRequestConnectTimeoutExhaustsRetries verifies retry count and final timeout error. func TestDoRequestConnectTimeoutExhaustsRetries(t *testing.T) { attempts := 0 jitterCaps := make([]time.Duration, 0) sleepDurations := make([]time.Duration, 0) conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { attempts++ return nil, newDialTimeoutError() }) conn.jitterFn = func(cap time.Duration) time.Duration { jitterCaps = append(jitterCaps, cap) return cap } conn.sleepFn = func(ctx context.Context, d time.Duration) error { sleepDurations = append(sleepDurations, d) return nil } _, err := conn.doRequest(context.Background(), checkConnectionPath) if err == nil { t.Fatal("doRequest() error = nil, want timeout") } if !isConnectTimeout(err) { t.Fatalf("doRequest() error = %v, want connect timeout", err) } wantCaps := append([]time.Duration(nil), defaultRetryCaps...) if attempts != len(wantCaps)+1 { t.Fatalf("attempts = %d, want %d", attempts, len(wantCaps)+1) } if !reflect.DeepEqual(jitterCaps, wantCaps) { t.Fatalf("jitter caps = %v, want %v", jitterCaps, wantCaps) } if !reflect.DeepEqual(sleepDurations, wantCaps) { t.Fatalf("sleep durations = %v, want %v", sleepDurations, wantCaps) } } // TestDoRequestResponseTimeoutNoRetry verifies response timeout does not trigger retries. 
func TestDoRequestResponseTimeoutNoRetry(t *testing.T) { attempts := 0 sleepCalls := 0 jitterCalls := 0 conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { attempts++ return nil, newResponseHeaderTimeoutError() }) conn.jitterFn = func(cap time.Duration) time.Duration { jitterCalls++ return cap } conn.sleepFn = func(ctx context.Context, d time.Duration) error { sleepCalls++ return nil } _, err := conn.doRequest(context.Background(), checkConnectionPath) if err == nil { t.Fatal("doRequest() error = nil, want timeout") } var netErr net.Error if !errors.As(err, &netErr) || !netErr.Timeout() { t.Fatalf("doRequest() error = %v, want timeout error", err) } if attempts != 1 { t.Fatalf("attempts = %d, want 1", attempts) } if jitterCalls != 0 { t.Fatalf("jitter calls = %d, want 0", jitterCalls) } if sleepCalls != 0 { t.Fatalf("sleep calls = %d, want 0", sleepCalls) } } // TestDoRequestContextCanceledDuringBackoff verifies cancellation interrupts retry wait. func TestDoRequestContextCanceledDuringBackoff(t *testing.T) { attempts := 0 sleepCalls := 0 conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { attempts++ return nil, newDialTimeoutError() }) ctx, cancel := context.WithCancel(context.Background()) conn.jitterFn = func(cap time.Duration) time.Duration { return cap } conn.sleepFn = func(ctx context.Context, d time.Duration) error { sleepCalls++ cancel() return ctx.Err() } _, err := conn.doRequest(ctx, checkConnectionPath) if !errors.Is(err, context.Canceled) { t.Fatalf("doRequest() error = %v, want context canceled", err) } if attempts != 1 { t.Fatalf("attempts = %d, want 1", attempts) } if sleepCalls != 1 { t.Fatalf("sleep calls = %d, want 1", sleepCalls) } } type roundTripperFunc func(req *stdhttp.Request) (*stdhttp.Response, error) // RoundTrip implements [http.RoundTripper]. 
// It forwards the request to the wrapped function and returns its result.
func (f roundTripperFunc) RoundTrip(req *stdhttp.Request) (*stdhttp.Response, error) {
	resp, err := f(req)
	return resp, err
}

// timeoutError is a stand-in for a timeout error coming from the transport:
// it always reports itself as both a timeout and temporary.
type timeoutError struct {
	message string
}

// Error returns the configured message.
func (te timeoutError) Error() string { return te.message }

// Timeout always reports true.
func (te timeoutError) Timeout() bool { return true }

// Temporary always reports true.
func (te timeoutError) Temporary() bool { return true }

// newDialTimeoutError returns a timeoutError wrapped in a *net.OpError whose
// operation is "dial" on "tcp".
func newDialTimeoutError() error {
	cause := timeoutError{message: "i/o timeout"}
	return &net.OpError{Op: "dial", Net: "tcp", Err: cause}
}

// newResponseHeaderTimeoutError returns a bare timeoutError carrying the
// response-header timeout message.
func newResponseHeaderTimeoutError() error {
	const msg = "net/http: timeout awaiting response headers"
	return timeoutError{message: msg}
}

// newTransportConnector returns a connector for "http://example.com" whose
// HTTP client uses transport and whose retry caps are a copy of
// defaultRetryCaps.
func newTransportConnector(t *testing.T, transport roundTripperFunc) *httpConnector {
	t.Helper()

	parsed, err := url.Parse("http://example.com")
	if err != nil {
		t.Fatalf("parse backend URL error = %v", err)
	}

	// Copy the defaults so the connector does not share their backing array.
	caps := append([]time.Duration(nil), defaultRetryCaps...)
	client := &stdhttp.Client{Transport: transport}

	return &httpConnector{
		ctx:        context.Background(),
		backendURL: parsed,
		httpClient: client,
		retryCaps:  caps,
	}
}

// newServerConnector starts a test server that records each request path on
// the returned channel and replies with status, then returns a connector
// pointed at the server URL with backendPath appended.
func newServerConnector(t *testing.T, ctx context.Context, status int, backendPath string) (*httpConnector, <-chan string) {
	t.Helper()

	paths := make(chan string, 1)
	record := func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
		paths <- r.URL.Path
		w.WriteHeader(status)
	}
	srv := httptest.NewServer(stdhttp.HandlerFunc(record))
	t.Cleanup(srv.Close)

	c, err := NewHttpConnector(ctx, srv.URL+backendPath)
	if err != nil {
		t.Fatalf("NewHttpConnector() error = %v", err)
	}
	return c, paths
}

// newVersionServerConnector creates connector with configurable response body for versions endpoint tests.
func newVersionServerConnector(t *testing.T, ctx context.Context, status int, body, backendPath string) (*httpConnector, <-chan string) { t.Helper() pathCh := make(chan string, 1) server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { pathCh <- r.URL.Path w.WriteHeader(status) _, _ = w.Write([]byte(body)) })) t.Cleanup(server.Close) conn, err := NewHttpConnector(ctx, server.URL+backendPath) if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } return conn, pathCh } // newFetchReportServerConnector creates connector with configurable response body for report endpoint tests. func newFetchReportServerConnector(t *testing.T, ctx context.Context, status int, body, backendPath string) (*httpConnector, <-chan requestDetails) { t.Helper() requestCh := make(chan requestDetails, 1) server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { requestCh <- requestDetails{path: r.URL.Path, query: r.URL.Query()} w.WriteHeader(status) _, _ = w.Write([]byte(body)) })) t.Cleanup(server.Close) conn, err := NewHttpConnector(ctx, server.URL+backendPath) if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } return conn, requestCh } // newUnreachableConnector creates connector pointing to a closed localhost TCP address. func newUnreachableConnector(t *testing.T, ctx context.Context) *httpConnector { t.Helper() ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen error = %v", err) } addr := ln.Addr().String() if err := ln.Close(); err != nil { t.Fatalf("close listener error = %v", err) } conn, err := NewHttpConnector(ctx, "http://"+addr) if err != nil { t.Fatalf("NewHttpConnector() error = %v", err) } return conn }