From ac3ed31a23fb4f979f21bf0cf1d2ea29c90b12d0 Mon Sep 17 00:00:00 2001 From: Ilia Denisov Date: Sat, 14 Mar 2026 21:11:51 +0200 Subject: [PATCH] connector impl --- game/internal/controller/controller.go | 20 + game/internal/controller/generate_game.go | 6 + game/internal/router/handler/handler.go | 6 + game/internal/router/handler/report.go | 27 + game/internal/router/report_test.go | 60 +++ game/internal/router/router.go | 1 + game/internal/router/router_helper_test.go | 5 + pkg/connector/http/http.go | 147 ++++- pkg/connector/http/http_test.go | 597 +++++++++++++++++++++ 9 files changed, 863 insertions(+), 6 deletions(-) create mode 100644 game/internal/router/handler/report.go create mode 100644 game/internal/router/report_test.go diff --git a/game/internal/controller/controller.go b/game/internal/controller/controller.go index 77173a3..d8b872e 100644 --- a/game/internal/controller/controller.go +++ b/game/internal/controller/controller.go @@ -110,6 +110,14 @@ func GenerateTurn(configure func(*Param)) (err error) { return } +func LoadReport(configure func(*Param), actor string, turn uint) (*report.Report, error) { + ec, err := NewRepoController(configure) + if err != nil { + return nil, err + } + return ec.loadReport(actor, turn) +} + func ExecuteCommand(configure func(*Param), consumer func(c Ctrl) error) (err error) { ec, err := NewRepoController(configure) if err != nil { @@ -197,6 +205,18 @@ func (ec *RepoController) validateOrder(actor string, cmd ...order.DecodableComm }) } +func (ec *RepoController) loadReport(actor string, turn uint) (r *report.Report, err error) { + execErr := ec.executeSafe(func(t uint, c *Controller) (exErr error) { + id, exErr := c.RaceID(actor) + if exErr == nil { + r, exErr = ec.Repo.LoadReport(turn, id) + } + return + }) + err = errors.Join(err, execErr) + return +} + func (ec *RepoController) executeCommand(consumer func(*Controller) error) (err error) { return ec.executeLocked(func(c *Controller) error { err = consumer(c) diff --git a/game/internal/controller/generate_game.go b/game/internal/controller/generate_game.go index 8c6fbde..a564308 100644 --- a/game/internal/controller/generate_game.go +++ b/game/internal/controller/generate_game.go @@ -29,6 +29,12 @@ func newGameOnMap(r Repo, races []string, m generator.Map) (uuid.UUID, error) { if err := r.SaveNewTurn(0, g); err != nil { return uuid.Nil, err } + c := NewCache(g) + for rep := range c.Report(c.g.Turn, nil, nil) { + if err := r.SaveReport(c.g.Turn, rep); err != nil { + return uuid.Nil, err + } + } return g.ID, nil } diff --git a/game/internal/router/handler/handler.go b/game/internal/router/handler/handler.go index 1f5cfb2..094ab3a 100644 --- a/game/internal/router/handler/handler.go +++ b/game/internal/router/handler/handler.go @@ -6,6 +6,7 @@ import ( "os" "galaxy/model/order" + "galaxy/model/report" "galaxy/model/rest" e "galaxy/error" @@ -21,6 +22,7 @@ type CommandExecutor interface { GenerateGame([]string) (rest.StateResponse, error) GenerateTurn() (rest.StateResponse, error) GameState() (rest.StateResponse, error) + LoadReport(actor string, turn uint) (*report.Report, error) Execute(cmd ...Command) error ValidateOrder(actor string, cmd ...order.DecodableCommand) error } @@ -84,6 +86,10 @@ func (e *executor) GameState() (rest.StateResponse, error) { return stateResponse(s), nil } +func (e *executor) LoadReport(actor string, turn uint) (*report.Report, error) { + return controller.LoadReport(e.cfg, actor, turn) +} + func stateResponse(s game.State) rest.StateResponse { result := &rest.StateResponse{ ID: s.ID, diff --git a/game/internal/router/handler/report.go b/game/internal/router/handler/report.go new file mode 100644 index 0000000..f15c0c8 --- /dev/null +++ b/game/internal/router/handler/report.go @@ -0,0 +1,27 @@ +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +type reportParam struct { + Player string `form:"player" binding:"required,notblank"` + Turn int `form:"turn" binding:"gte=0"` +} + +func ReportHandler(c *gin.Context, executor CommandExecutor) { + p := &reportParam{} + err := c.ShouldBindQuery(p) + if errorResponse(c, err) { + return + } + + r, err := executor.LoadReport(p.Player, uint(p.Turn)) + if errorResponse(c, err) { + return + } + + c.JSON(http.StatusOK, r) +} diff --git a/game/internal/router/report_test.go b/game/internal/router/report_test.go new file mode 100644 index 0000000..a87756e --- /dev/null +++ b/game/internal/router/report_test.go @@ -0,0 +1,60 @@ +package router_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "galaxy/model/rest" + + "galaxy/game/internal/controller" + "galaxy/game/internal/router" + "galaxy/game/internal/router/handler" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestGetReport(t *testing.T) { + root := t.ArtifactDir() + + r := router.SetupRouter(handler.NewDefaultConfigExecutor(func(p *controller.Param) { p.StoragePath = root })) + + payload := generateInitRequest(10) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/init", asBody(payload)) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusCreated, w.Code, w.Body) + var initResponse rest.StateResponse + assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &initResponse)) + assert.NoError(t, uuid.Validate(initResponse.ID.String())) + assert.NotEqual(t, uuid.Nil, uuid.MustParse(initResponse.ID.String())) + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/api/v1/report", nil) + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code, w.Body) + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/api/v1/report?player=", nil) + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code, w.Body) + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/api/v1/report?player=&turn=0", nil) + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code, w.Body) + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/api/v1/report?player=Race_01&turn=-1", nil) + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code, w.Body) + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/api/v1/report?player=Race_01&turn=0", nil) + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, w.Body) +} diff --git a/game/internal/router/router.go b/game/internal/router/router.go index ca7d246..4215a45 100644 --- a/game/internal/router/router.go +++ b/game/internal/router/router.go @@ -67,6 +67,7 @@ func setupRouter(executor handler.CommandExecutor) *gin.Engine { groupV1.GET("/status", func(ctx *gin.Context) { handler.StatusHandler(ctx, executor) }) groupV1.POST("/init", func(ctx *gin.Context) { handler.InitHandler(ctx, executor) }) + groupV1.GET("/report", func(ctx *gin.Context) { handler.ReportHandler(ctx, executor) }) groupV1.PUT("/command", LimitMiddleware(1), func(ctx *gin.Context) { handler.CommandHandler(ctx, executor) }) groupV1.PUT("/order", func(ctx *gin.Context) { handler.OrderHandler(ctx, executor) }) groupV1.PUT("/turn", func(ctx *gin.Context) { handler.TurnHandler(ctx, executor) }) diff --git a/game/internal/router/router_helper_test.go b/game/internal/router/router_helper_test.go index 804e5b9..e8db00f 100644 --- a/game/internal/router/router_helper_test.go +++ b/game/internal/router/router_helper_test.go @@ -5,6 +5,7 @@ import ( "net/http" "galaxy/model/order" + "galaxy/model/report" "galaxy/model/rest" "galaxy/game/internal/router" @@ -55,6 +56,10 @@ func (e *dummyExecutor) GameState() (rest.StateResponse, error) { return rest.StateResponse{}, nil } +func (e *dummyExecutor) LoadReport(actor string, turn uint) (*report.Report, error) { + return &report.Report{}, nil +} + func setupRouter() *gin.Engine { return setupRouterExecutor(newExecutor()) } diff --git a/pkg/connector/http/http.go b/pkg/connector/http/http.go index 85b7084..4f08455 100644 --- a/pkg/connector/http/http.go +++ b/pkg/connector/http/http.go @@ -7,11 +7,15 @@ import ( "errors" "fmt" "galaxy/connector" + "galaxy/model/client" + "galaxy/model/report" + "io" "math/rand/v2" "net" "net/http" "net/url" "path" + "strconv" "strings" "time" ) @@ -21,6 +25,10 @@ const ( checkConnectionPath = "api/v1/status" // checkVersionPath is backend endpoint path used to load available app versions. checkVersionPath = "api/v1/versions" + // fetchReportPath is backend endpoint path used to load game report for a specific turn number. + fetchReportPath = "api/v1/report" + // fetchReportPlayer is a temporary player identifier until UI passes actor identity explicitly. + fetchReportPlayer = "Race_01" // connectTimeout is max time for establishing TCP connection. connectTimeout = 3 * time.Second @@ -36,6 +44,9 @@ var defaultRetryCaps = []time.Duration{ 60 * time.Second, } +// errMovedPermanentlyWithoutLocation reports an invalid redirect response. +var errMovedPermanentlyWithoutLocation = errors.New("server returned 301 response without Location header") + type httpConnector struct { ctx context.Context backendURL *url.URL // HTTP REST API Server URL @@ -77,6 +88,12 @@ func newHTTPClient(connectTimeout, responseTimeout time.Duration) *http.Client { } } +// doNotFollowRedirect keeps redirect handling inside doRequest so retry budget +// and jitter stay under connector control. +func doNotFollowRedirect(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse +} + func (h *httpConnector) requestContext() context.Context { if h.ctx != nil { return h.ctx @@ -176,10 +193,110 @@ func (h *httpConnector) CheckVersion() ([]connector.VersionInfo, error) { return versions, nil } -// doRequest performs GET request for a backend relative endpoint with passed context. -func (h *httpConnector) doRequest(ctx context.Context, relativePath string) (*http.Response, error) { - requestURL := *h.backendURL - requestURL.Path = path.Join(requestURL.Path, relativePath) +// DownloadVersion retrieves a version artifact from backend storage. +// urlOrPath may be either a backend-relative path or a fully qualified URL. +func (h *httpConnector) DownloadVersion(urlOrPath string) ([]byte, error) { + resp, err := h.doRequest(h.requestContext(), urlOrPath) + if err != nil { + return nil, fmt.Errorf("download version artifact: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download version artifact: unexpected status code %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read version artifact body: %w", err) + } + + return body, nil +} + +// FetchReport asynchronously loads a report for turn from backend and invokes callback once with the result. +func (h *httpConnector) FetchReport(_ client.GameID, turn uint, callback func(report.Report, error)) { + go func() { + rep, err := h.fetchReport(turn) + if callback != nil { + callback(rep, err) + } + }() +} + +// fetchReport loads a report for turn from backend using the temporary player identifier. +func (h *httpConnector) fetchReport(turn uint) (report.Report, error) { + resp, err := h.doRequest(h.requestContext(), fetchReportRequestPath(turn)) + if err != nil { + return report.Report{}, fmt.Errorf("request report from backend: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return report.Report{}, fmt.Errorf("request report from backend: unexpected status code %d", resp.StatusCode) + } + + var rep report.Report + if err := json.NewDecoder(resp.Body).Decode(&rep); err != nil { + return report.Report{}, fmt.Errorf("decode report response: %w", err) + } + + return rep, nil +} + +// fetchReportRequestPath builds the report endpoint with required query parameters. +func fetchReportRequestPath(turn uint) string { + values := url.Values{} + values.Set("player", fetchReportPlayer) + values.Set("turn", strconv.FormatUint(uint64(turn), 10)) + + return fetchReportPath + "?" + values.Encode() +} + +// resolveRequestURL returns either the fully qualified request URL as-is or +// composes a backend-relative path with connector backendURL. +func (h *httpConnector) resolveRequestURL(urlOrPath string) (*url.URL, error) { + requestURL, err := url.Parse(urlOrPath) + if err != nil { + return nil, fmt.Errorf("parse request URL %q: %w", urlOrPath, err) + } + + if requestURL.IsAbs() { + return requestURL, nil + } + + resolvedURL := *h.backendURL + resolvedURL.Path = path.Join(resolvedURL.Path, requestURL.Path) + if requestURL.RawQuery != "" { + resolvedURL.RawQuery = requestURL.RawQuery + } + if requestURL.Fragment != "" { + resolvedURL.Fragment = requestURL.Fragment + } + + return &resolvedURL, nil +} + +// doHTTP executes a single HTTP exchange without the standard client redirect handling. +func (h *httpConnector) doHTTP(req *http.Request) (*http.Response, error) { + client := h.httpClient + if client == nil { + client = newHTTPClient(connectTimeout, responseTimeout) + } + + noRedirectClient := *client + noRedirectClient.CheckRedirect = doNotFollowRedirect + + return noRedirectClient.Do(req) +} + +// doRequest performs a GET request for either a backend-relative endpoint or a +// fully qualified URL with the passed context. +func (h *httpConnector) doRequest(ctx context.Context, urlOrPath string) (*http.Response, error) { + requestURL, err := h.resolveRequestURL(urlOrPath) + if err != nil { + return nil, err + } retryCaps := h.retryCaps if retryCaps == nil { @@ -211,9 +328,27 @@ func (h *httpConnector) doRequest(ctx context.Context, relativePath string) (*ht return nil, fmt.Errorf("create request: %w", err) } - resp, err := h.httpClient.Do(req) + resp, err := h.doHTTP(req) if err == nil { - return resp, nil + if resp.StatusCode != http.StatusMovedPermanently { + return resp, nil + } + + location := resp.Header.Get("Location") + resp.Body.Close() + if location == "" { + return nil, fmt.Errorf("request %q: %w", requestURL.Redacted(), errMovedPermanentlyWithoutLocation) + } + if attempt == len(retryCaps) { + return nil, fmt.Errorf("request %q: exhausted attempts following redirect to %q", requestURL.Redacted(), location) + } + + redirectURL, err := requestURL.Parse(location) + if err != nil { + return nil, fmt.Errorf("resolve redirect location %q for request %q: %w", location, requestURL.Redacted(), err) + } + requestURL = redirectURL + continue } if !isConnectTimeout(err) { return nil, err diff --git a/pkg/connector/http/http_test.go b/pkg/connector/http/http_test.go index 0bf25f1..e903a61 100644 --- a/pkg/connector/http/http_test.go +++ b/pkg/connector/http/http_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "galaxy/connector" + "galaxy/model/report" "io" "net" stdhttp "net/http" @@ -32,6 +33,40 @@ type checkVersionCase struct { wantPath string } +// downloadVersionCase describes one DownloadVersion behavior scenario. +type downloadVersionCase struct { + name string + setup func(t *testing.T) (*httpConnector, string, <-chan string, func(t *testing.T)) + want []byte + wantErr bool + wantPath string +} + +// fetchReportCase describes one FetchReport behavior scenario. +type fetchReportCase struct { + name string + setup func(t *testing.T) (*httpConnector, <-chan requestDetails) + turn uint + want report.Report + wantErr bool + wantPath string + wantPlayer string + wantTurn string + wantRequest bool +} + +// fetchReportResult captures one asynchronous FetchReport callback result. +type fetchReportResult struct { + report report.Report + err error +} + +// requestDetails captures one received request path and query parameters. +type requestDetails struct { + path string + query url.Values +} + // TestCheckConnection verifies backend reachability probe behavior. func TestCheckConnection(t *testing.T) { tests := []checkConnectionCase{ @@ -225,6 +260,346 @@ func TestCheckVersion(t *testing.T) { } } +// TestFetchReport verifies asynchronous report retrieval behavior. +func TestFetchReport(t *testing.T) { + tests := []fetchReportCase{ + { + name: "status 200 with valid body returns report", + setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { + return newFetchReportServerConnector( + t, + context.Background(), + stdhttp.StatusOK, + `{"version":2,"turn":7,"mapWidth":120,"mapHeight":80,"race":"Race_01","votes":1.5}`, + "", + ) + }, + turn: 7, + want: report.Report{ + Version: 2, + Turn: 7, + Width: 120, + Height: 80, + Race: "Race_01", + Votes: report.Float(1.5), + }, + wantPath: "/api/v1/report", + wantPlayer: fetchReportPlayer, + wantTurn: "7", + wantRequest: true, + }, + { + name: "status 200 with invalid json returns error", + setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { + return newFetchReportServerConnector( + t, + context.Background(), + stdhttp.StatusOK, + `{"turn":`, + "", + ) + }, + turn: 8, + wantErr: true, + wantPath: "/api/v1/report", + wantPlayer: fetchReportPlayer, + wantTurn: "8", + wantRequest: true, + }, + { + name: "non-200 status returns error", + setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { + return newFetchReportServerConnector( + t, + context.Background(), + stdhttp.StatusBadGateway, + `{}`, + "", + ) + }, + turn: 9, + wantErr: true, + wantPath: "/api/v1/report", + wantPlayer: fetchReportPlayer, + wantTurn: "9", + wantRequest: true, + }, + { + name: "canceled context returns error", + setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return newFetchReportServerConnector( + t, + ctx, + stdhttp.StatusOK, + `{"turn":1}`, + "", + ) + }, + turn: 10, + wantErr: true, + wantRequest: false, + }, + { + name: "backend path prefix is preserved", + setup: func(t *testing.T) (*httpConnector, <-chan requestDetails) { + return newFetchReportServerConnector( + t, + context.Background(), + stdhttp.StatusOK, + `{"turn":11,"mapWidth":20,"mapHeight":30}`, + "/base", + ) + }, + turn: 11, + want: report.Report{ + Turn: 11, + Width: 20, + Height: 30, + }, + wantPath: "/base/api/v1/report", + wantPlayer: fetchReportPlayer, + wantTurn: "11", + wantRequest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conn, requestCh := tt.setup(t) + resultCh := make(chan fetchReportResult, 1) + + conn.FetchReport("", tt.turn, func(rep report.Report, err error) { + resultCh <- fetchReportResult{report: rep, err: err} + }) + + select { + case result := <-resultCh: + if tt.wantErr { + if result.err == nil { + t.Fatal("FetchReport() error = nil, want non-nil") + } + } else { + if result.err != nil { + t.Fatalf("FetchReport() error = %v, want nil", result.err) + } + if !reflect.DeepEqual(result.report, tt.want) { + t.Fatalf("FetchReport() report = %#v, want %#v", result.report, tt.want) + } + } + case <-time.After(time.Second): + t.Fatal("FetchReport() callback was not called") + } + + if !tt.wantRequest { + select { + case req := <-requestCh: + t.Fatalf("unexpected request = %#v", req) + case <-time.After(100 * time.Millisecond): + } + return + } + + select { + case req := <-requestCh: + if req.path != tt.wantPath { + t.Fatalf("request path = %q, want %q", req.path, tt.wantPath) + } + if got := req.query.Get("player"); got != tt.wantPlayer { + t.Fatalf("request player = %q, want %q", got, tt.wantPlayer) + } + if got := req.query.Get("turn"); got != tt.wantTurn { + t.Fatalf("request turn = %q, want %q", got, tt.wantTurn) + } + case <-time.After(time.Second): + t.Fatal("expected request, got none") + } + }) + } +} + +// TestFetchReportAsync verifies FetchReport returns immediately and calls callback once after response is ready. +func TestFetchReportAsync(t *testing.T) { + requestCh := make(chan requestDetails, 1) + releaseResponse := make(chan struct{}) + server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + requestCh <- requestDetails{path: r.URL.Path, query: r.URL.Query()} + <-releaseResponse + w.WriteHeader(stdhttp.StatusOK) + _, _ = w.Write([]byte(`{"turn":12,"mapWidth":1,"mapHeight":1}`)) + })) + t.Cleanup(server.Close) + + conn, err := NewHttpConnector(context.Background(), server.URL) + if err != nil { + t.Fatalf("NewHttpConnector() error = %v", err) + } + + resultCh := make(chan fetchReportResult, 2) + start := time.Now() + conn.FetchReport("", 12, func(rep report.Report, err error) { + resultCh <- fetchReportResult{report: rep, err: err} + }) + if elapsed := time.Since(start); elapsed > 50*time.Millisecond { + t.Fatalf("FetchReport() elapsed = %v, want quick return", elapsed) + } + + select { + case req := <-requestCh: + if req.path != "/api/v1/report" { + t.Fatalf("request path = %q, want %q", req.path, "/api/v1/report") + } + if got := req.query.Get("player"); got != fetchReportPlayer { + t.Fatalf("request player = %q, want %q", got, fetchReportPlayer) + } + if got := req.query.Get("turn"); got != "12" { + t.Fatalf("request turn = %q, want %q", got, "12") + } + case <-time.After(time.Second): + t.Fatal("expected request, got none") + } + + select { + case result := <-resultCh: + t.Fatalf("unexpected early callback = %#v", result) + case <-time.After(100 * time.Millisecond): + } + + close(releaseResponse) + + select { + case result := <-resultCh: + if result.err != nil { + t.Fatalf("FetchReport() error = %v, want nil", result.err) + } + if result.report.Turn != 12 { + t.Fatalf("FetchReport() report turn = %d, want %d", result.report.Turn, 12) + } + case <-time.After(time.Second): + t.Fatal("FetchReport() callback was not called") + } + + select { + case extra := <-resultCh: + t.Fatalf("FetchReport() callback called more than once: %#v", extra) + case <-time.After(100 * time.Millisecond): + } +} + +// TestDownloadVersion verifies artifact download behavior for relative and absolute URLs. +func TestDownloadVersion(t *testing.T) { + tests := []downloadVersionCase{ + { + name: "relative path uses backend URL", + setup: func(t *testing.T) (*httpConnector, string, <-chan string, func(t *testing.T)) { + pathCh := make(chan string, 1) + server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + pathCh <- r.URL.Path + w.WriteHeader(stdhttp.StatusOK) + _, _ = w.Write([]byte("artifact")) + })) + t.Cleanup(server.Close) + + conn, err := NewHttpConnector(context.Background(), server.URL+"/base") + if err != nil { + t.Fatalf("NewHttpConnector() error = %v", err) + } + + return conn, "downloads/client.bin", pathCh, func(t *testing.T) {} + }, + want: []byte("artifact"), + wantPath: "/base/downloads/client.bin", + }, + { + name: "fully qualified URL bypasses backend URL", + setup: func(t *testing.T) (*httpConnector, string, <-chan string, func(t *testing.T)) { + backendPathCh := make(chan string, 1) + backendServer := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + backendPathCh <- r.URL.Path + w.WriteHeader(stdhttp.StatusInternalServerError) + })) + t.Cleanup(backendServer.Close) + + downloadPathCh := make(chan string, 1) + downloadServer := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + downloadPathCh <- r.URL.Path + w.WriteHeader(stdhttp.StatusOK) + _, _ = w.Write([]byte("external-artifact")) + })) + t.Cleanup(downloadServer.Close) + + conn, err := NewHttpConnector(context.Background(), backendServer.URL+"/base") + if err != nil { + t.Fatalf("NewHttpConnector() error = %v", err) + } + + return conn, downloadServer.URL + "/artifact.bin", downloadPathCh, func(t *testing.T) { + t.Helper() + + select { + case gotPath := <-backendPathCh: + t.Fatalf("unexpected backend request path = %q", gotPath) + default: + } + } + }, + want: []byte("external-artifact"), + wantPath: "/artifact.bin", + }, + { + name: "non-200 status returns error", + setup: func(t *testing.T) (*httpConnector, string, <-chan string, func(t *testing.T)) { + pathCh := make(chan string, 1) + server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + pathCh <- r.URL.Path + w.WriteHeader(stdhttp.StatusBadGateway) + })) + t.Cleanup(server.Close) + + conn, err := NewHttpConnector(context.Background(), server.URL) + if err != nil { + t.Fatalf("NewHttpConnector() error = %v", err) + } + + return conn, "downloads/client.bin", pathCh, func(t *testing.T) {} + }, + wantErr: true, + wantPath: "/downloads/client.bin", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conn, urlOrPath, pathCh, verify := tt.setup(t) + got, err := conn.DownloadVersion(urlOrPath) + if tt.wantErr { + if err == nil { + t.Fatal("DownloadVersion() error = nil, want non-nil") + } + } else { + if err != nil { + t.Fatalf("DownloadVersion() error = %v, want nil", err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("DownloadVersion() = %q, want %q", got, tt.want) + } + } + + select { + case gotPath := <-pathCh: + if gotPath != tt.wantPath { + t.Fatalf("request path = %q, want %q", gotPath, tt.wantPath) + } + default: + t.Fatalf("expected request path %q, got no request", tt.wantPath) + } + + verify(t) + }) + } +} + // TestDoRequestUsesPassedContext verifies request context is provided by caller. func TestDoRequestUsesPassedContext(t *testing.T) { conn, pathCh := newServerConnector(t, context.Background(), stdhttp.StatusOK, "") @@ -247,6 +622,208 @@ func TestDoRequestUsesPassedContext(t *testing.T) { } } +// TestDoRequestMovedPermanentlyRedirectsRelative verifies 301 responses follow relative Location redirects. +func TestDoRequestMovedPermanentlyRedirectsRelative(t *testing.T) { + requestPaths := make(chan string, 2) + server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + requestPaths <- r.URL.Path + if r.URL.Path == "/base/api/v1/status" { + w.Header().Set("Location", "/redirected") + w.WriteHeader(stdhttp.StatusMovedPermanently) + return + } + + w.WriteHeader(stdhttp.StatusOK) + })) + t.Cleanup(server.Close) + + conn, err := NewHttpConnector(context.Background(), server.URL+"/base") + if err != nil { + t.Fatalf("NewHttpConnector() error = %v", err) + } + + resp, err := conn.doRequest(context.Background(), checkConnectionPath) + if err != nil { + t.Fatalf("doRequest() error = %v, want nil", err) + } + defer resp.Body.Close() + + if resp.StatusCode != stdhttp.StatusOK { + t.Fatalf("doRequest() status code = %d, want %d", resp.StatusCode, stdhttp.StatusOK) + } + + gotFirstPath := <-requestPaths + if gotFirstPath != "/base/api/v1/status" { + t.Fatalf("first request path = %q, want %q", gotFirstPath, "/base/api/v1/status") + } + + gotSecondPath := <-requestPaths + if gotSecondPath != "/redirected" { + t.Fatalf("redirect request path = %q, want %q", gotSecondPath, "/redirected") + } +} + +// TestDoRequestMovedPermanentlyRedirectsAbsolute verifies 301 responses follow absolute Location redirects. +func TestDoRequestMovedPermanentlyRedirectsAbsolute(t *testing.T) { + initialPaths := make(chan string, 1) + redirectPaths := make(chan string, 1) + redirectServer := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + redirectPaths <- r.URL.Path + w.WriteHeader(stdhttp.StatusOK) + })) + t.Cleanup(redirectServer.Close) + + server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + initialPaths <- r.URL.Path + w.Header().Set("Location", redirectServer.URL+"/redirected") + w.WriteHeader(stdhttp.StatusMovedPermanently) + })) + t.Cleanup(server.Close) + + conn, err := NewHttpConnector(context.Background(), server.URL) + if err != nil { + t.Fatalf("NewHttpConnector() error = %v", err) + } + + resp, err := conn.doRequest(context.Background(), checkConnectionPath) + if err != nil { + t.Fatalf("doRequest() error = %v, want nil", err) + } + defer resp.Body.Close() + + if resp.StatusCode != stdhttp.StatusOK { + t.Fatalf("doRequest() status code = %d, want %d", resp.StatusCode, stdhttp.StatusOK) + } + + select { + case gotPath := <-initialPaths: + if gotPath != "/api/v1/status" { + t.Fatalf("initial request path = %q, want %q", gotPath, "/api/v1/status") + } + default: + t.Fatal("expected initial request, got none") + } + + select { + case gotPath := <-redirectPaths: + if gotPath != "/redirected" { + t.Fatalf("redirect request path = %q, want %q", gotPath, "/redirected") + } + default: + t.Fatal("expected redirect request, got none") + } +} + +// TestDoRequestMovedPermanentlyWithoutLocation verifies invalid 301 responses return an error. +func TestDoRequestMovedPermanentlyWithoutLocation(t *testing.T) { + conn, _ := newServerConnector(t, context.Background(), stdhttp.StatusMovedPermanently, "") + + _, err := conn.doRequest(context.Background(), checkConnectionPath) + if err == nil { + t.Fatal("doRequest() error = nil, want non-nil") + } + if !errors.Is(err, errMovedPermanentlyWithoutLocation) { + t.Fatalf("doRequest() error = %v, want %v", err, errMovedPermanentlyWithoutLocation) + } +} + +// TestDoRequestMovedPermanentlyUsesRetryBudget verifies redirect follow-up requests consume jittered attempts. +func TestDoRequestMovedPermanentlyUsesRetryBudget(t *testing.T) { + attempts := 0 + jitterCaps := make([]time.Duration, 0) + sleepDurations := make([]time.Duration, 0) + + conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { + attempts++ + switch req.URL.String() { + case "http://example.com/api/v1/status": + return &stdhttp.Response{ + StatusCode: stdhttp.StatusMovedPermanently, + Header: stdhttp.Header{ + "Location": []string{"/redirected"}, + }, + Body: io.NopCloser(strings.NewReader("")), + }, nil + case "http://example.com/redirected": + return &stdhttp.Response{ + StatusCode: stdhttp.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + }, nil + default: + t.Fatalf("unexpected request URL = %q", req.URL.String()) + return nil, nil + } + }) + conn.jitterFn = func(cap time.Duration) time.Duration { + jitterCaps = append(jitterCaps, cap) + return cap + } + conn.sleepFn = func(ctx context.Context, d time.Duration) error { + sleepDurations = append(sleepDurations, d) + return nil + } + + resp, err := conn.doRequest(context.Background(), checkConnectionPath) + if err != nil { + t.Fatalf("doRequest() error = %v, want nil", err) + } + defer resp.Body.Close() + + if attempts != 2 { + t.Fatalf("attempts = %d, want 2", attempts) + } + + wantCaps := []time.Duration{5 * time.Second} + if !reflect.DeepEqual(jitterCaps, wantCaps) { + t.Fatalf("jitter caps = %v, want %v", jitterCaps, wantCaps) + } + if !reflect.DeepEqual(sleepDurations, wantCaps) { + t.Fatalf("sleep durations = %v, want %v", sleepDurations, wantCaps) + } +} + +// TestDoRequestMovedPermanentlyExhaustsRetryBudget verifies repeated redirects eventually fail. +func TestDoRequestMovedPermanentlyExhaustsRetryBudget(t *testing.T) { + attempts := 0 + jitterCaps := make([]time.Duration, 0) + sleepDurations := make([]time.Duration, 0) + + conn := newTransportConnector(t, func(req *stdhttp.Request) (*stdhttp.Response, error) { + attempts++ + return &stdhttp.Response{ + StatusCode: stdhttp.StatusMovedPermanently, + Header: stdhttp.Header{ + "Location": []string{"/redirected"}, + }, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }) + conn.jitterFn = func(cap time.Duration) time.Duration { + jitterCaps = append(jitterCaps, cap) + return cap + } + conn.sleepFn = func(ctx context.Context, d time.Duration) error { + sleepDurations = append(sleepDurations, d) + return nil + } + + _, err := conn.doRequest(context.Background(), checkConnectionPath) + if err == nil { + t.Fatal("doRequest() error = nil, want non-nil") + } + + wantCaps := append([]time.Duration(nil), defaultRetryCaps...) + if attempts != len(wantCaps)+1 { + t.Fatalf("attempts = %d, want %d", attempts, len(wantCaps)+1) + } + if !reflect.DeepEqual(jitterCaps, wantCaps) { + t.Fatalf("jitter caps = %v, want %v", jitterCaps, wantCaps) + } + if !reflect.DeepEqual(sleepDurations, wantCaps) { + t.Fatalf("sleep durations = %v, want %v", sleepDurations, wantCaps) + } +} + // TestDoRequestResponseHeaderTimeout verifies client distinguishes response timeout. func TestDoRequestResponseHeaderTimeout(t *testing.T) { const ( @@ -574,6 +1151,26 @@ func newVersionServerConnector(t *testing.T, ctx context.Context, status int, bo return conn, pathCh } +// newFetchReportServerConnector creates connector with configurable response body for report endpoint tests. +func newFetchReportServerConnector(t *testing.T, ctx context.Context, status int, body, backendPath string) (*httpConnector, <-chan requestDetails) { + t.Helper() + + requestCh := make(chan requestDetails, 1) + server := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + requestCh <- requestDetails{path: r.URL.Path, query: r.URL.Query()} + w.WriteHeader(status) + _, _ = w.Write([]byte(body)) + })) + t.Cleanup(server.Close) + + conn, err := NewHttpConnector(ctx, server.URL+backendPath) + if err != nil { + t.Fatalf("NewHttpConnector() error = %v", err) + } + + return conn, requestCh +} + // newUnreachableConnector creates connector pointing to a closed localhost TCP address. func newUnreachableConnector(t *testing.T, ctx context.Context) *httpConnector { t.Helper()