package backendclient

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"galaxy/gateway/internal/session"
)

// HeaderUserID is the trusted gateway → backend identity header.
const HeaderUserID = "X-User-Id"

// errSessionNotFound builds the error LookupSession returns when backend
// reports HTTP 404 for a device session id. It wraps session.ErrNotFound,
// so callers on the gateway hot path can keep matching it with
// errors.Is(err, session.ErrNotFound).
func errSessionNotFound() error {
	return fmt.Errorf("backendclient: lookup session: %w", session.ErrNotFound)
}

// RESTClient owns the gateway's HTTP conversation with backend.
//
// All methods are safe for concurrent use.
type RESTClient struct {
	// baseURL is the backend base URL as normalized by NewRESTClient;
	// request paths are appended to it by string concatenation.
	baseURL    string
	httpClient *http.Client
}

// NewRESTClient constructs a RESTClient targeting the backend HTTP
// listener configured in cfg.
//
// cfg.HTTPBaseURL is trimmed of surrounding whitespace and trailing
// slashes. It must then be an absolute http or https URL with no query
// string or fragment. Request paths are appended to it verbatim, so a
// query or fragment would swallow them. cfg.HTTPTimeout becomes the
// http.Client timeout (per net/http, zero means no timeout).
func NewRESTClient(cfg Config) (*RESTClient, error) {
	transport, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		return nil, errors.New("backendclient: default HTTP transport is not *http.Transport")
	}

	parsed, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.HTTPBaseURL), "/"))
	if err != nil {
		return nil, fmt.Errorf("backendclient: parse HTTPBaseURL: %w", err)
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return nil, errors.New("backendclient: HTTPBaseURL must be absolute")
	}
	// url.Parse lower-cases the scheme, so a plain comparison is enough.
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return nil, fmt.Errorf("backendclient: HTTPBaseURL scheme %q must be http or https", parsed.Scheme)
	}
	// Paths are concatenated onto baseURL. With a query (including a bare
	// trailing "?") or a fragment they would land inside it and every
	// request would silently hit the wrong path.
	if parsed.RawQuery != "" || parsed.ForceQuery || parsed.Fragment != "" {
		return nil, errors.New("backendclient: HTTPBaseURL must not contain a query or fragment")
	}

	return &RESTClient{
		baseURL: parsed.String(),
		httpClient: &http.Client{
			// A private clone lets Close drop idle connections without
			// touching http.DefaultTransport shared by other code.
			Transport: transport.Clone(),
			Timeout:   cfg.HTTPTimeout,
		},
	}, nil
}

// Close releases idle HTTP connections owned by the client transport.
// It is a no-op on a nil client or one without an HTTP client, and it
// always returns nil.
func (c *RESTClient) Close() error {
	if c == nil || c.httpClient == nil {
		return nil
	}
	type idleCloser interface {
		CloseIdleConnections()
	}
	if transport, ok := c.httpClient.Transport.(idleCloser); ok {
		transport.CloseIdleConnections()
	}
	return nil
}

// LookupSession resolves deviceSessionID against
// `GET /api/v1/internal/sessions/{device_session_id}`.
// Returns session.ErrNotFound (wrapped) when backend reports 404. func (c *RESTClient) LookupSession(ctx context.Context, deviceSessionID string) (session.Record, error) { if c == nil || c.httpClient == nil { return session.Record{}, errors.New("backendclient: nil REST client") } if strings.TrimSpace(deviceSessionID) == "" { return session.Record{}, errors.New("backendclient: lookup session: device_session_id must not be empty") } target := c.baseURL + "/api/v1/internal/sessions/" + url.PathEscape(deviceSessionID) body, status, err := c.do(ctx, http.MethodGet, target, "", nil) if err != nil { return session.Record{}, fmt.Errorf("backendclient: lookup session: %w", err) } switch { case status == http.StatusOK: return decodeDeviceSession(deviceSessionID, body) case status == http.StatusNotFound: return session.Record{}, errSessionNotFound() default: return session.Record{}, fmt.Errorf("backendclient: lookup session: unexpected HTTP status %d", status) } } // RevokeSession asks backend to revoke a single device session by id. func (c *RESTClient) RevokeSession(ctx context.Context, deviceSessionID string) error { if strings.TrimSpace(deviceSessionID) == "" { return errors.New("backendclient: revoke session: device_session_id must not be empty") } target := c.baseURL + "/api/v1/internal/sessions/" + url.PathEscape(deviceSessionID) + "/revoke" _, status, err := c.do(ctx, http.MethodPost, target, "", nil) if err != nil { return fmt.Errorf("backendclient: revoke session: %w", err) } if status == http.StatusOK || status == http.StatusNoContent { return nil } if status == http.StatusNotFound { return errSessionNotFound() } return fmt.Errorf("backendclient: revoke session: unexpected HTTP status %d", status) } // RevokeAllSessionsForUser asks backend to revoke every active device // session belonging to userID. 
func (c *RESTClient) RevokeAllSessionsForUser(ctx context.Context, userID string) error { if strings.TrimSpace(userID) == "" { return errors.New("backendclient: revoke-all sessions: user_id must not be empty") } target := c.baseURL + "/api/v1/internal/sessions/users/" + url.PathEscape(userID) + "/revoke-all" _, status, err := c.do(ctx, http.MethodPost, target, "", nil) if err != nil { return fmt.Errorf("backendclient: revoke-all sessions: %w", err) } if status == http.StatusOK || status == http.StatusNoContent { return nil } if status == http.StatusNotFound { return errSessionNotFound() } return fmt.Errorf("backendclient: revoke-all sessions: unexpected HTTP status %d", status) } // do executes a JSON request and reads the response body. userID, when // non-empty, is sent as the X-User-Id header (required for `/api/v1/user/*`). func (c *RESTClient) do(ctx context.Context, method, target, userID string, body any) ([]byte, int, error) { return c.doWithHeaders(ctx, method, target, userID, body, nil) } // doWithHeaders is the shared transport entry point. extraHeaders are // applied verbatim after Content-Type/X-User-Id; an empty value drops // the header so callers can pass optional language tags etc. 
func (c *RESTClient) doWithHeaders(ctx context.Context, method, target, userID string, body any, extraHeaders map[string]string) ([]byte, int, error) { if c == nil || c.httpClient == nil { return nil, 0, errors.New("nil REST client") } if ctx == nil { return nil, 0, errors.New("nil context") } var reader io.Reader if body != nil { buf, err := json.Marshal(body) if err != nil { return nil, 0, fmt.Errorf("marshal request body: %w", err) } reader = bytes.NewReader(buf) } req, err := http.NewRequestWithContext(ctx, method, target, reader) if err != nil { return nil, 0, fmt.Errorf("build request: %w", err) } if body != nil { req.Header.Set("Content-Type", "application/json") } if userID != "" { req.Header.Set(HeaderUserID, userID) } for key, value := range extraHeaders { if strings.TrimSpace(value) == "" { continue } req.Header.Set(key, value) } resp, err := c.httpClient.Do(req) if err != nil { return nil, 0, err } defer resp.Body.Close() payload, err := io.ReadAll(resp.Body) if err != nil { return nil, resp.StatusCode, fmt.Errorf("read response body: %w", err) } return payload, resp.StatusCode, nil } // deviceSessionWire mirrors backend openapi `DeviceSession`. 
type deviceSessionWire struct { DeviceSessionID string `json:"device_session_id"` UserID string `json:"user_id"` Status string `json:"status"` ClientPublicKey string `json:"client_public_key,omitempty"` CreatedAt time.Time `json:"created_at"` RevokedAt *time.Time `json:"revoked_at,omitempty"` LastSeenAt *time.Time `json:"last_seen_at,omitempty"` } func decodeDeviceSession(expectedDeviceSessionID string, payload []byte) (session.Record, error) { var wire deviceSessionWire if err := decodeStrictJSON(payload, &wire); err != nil { return session.Record{}, fmt.Errorf("decode device session: %w", err) } if strings.TrimSpace(wire.DeviceSessionID) == "" { return session.Record{}, errors.New("decode device session: device_session_id must not be empty") } if wire.DeviceSessionID != expectedDeviceSessionID { return session.Record{}, fmt.Errorf("decode device session: device_session_id %q does not match requested %q", wire.DeviceSessionID, expectedDeviceSessionID) } if strings.TrimSpace(wire.UserID) == "" { return session.Record{}, errors.New("decode device session: user_id must not be empty") } status := session.Status(strings.TrimSpace(wire.Status)) if !status.IsKnown() { return session.Record{}, fmt.Errorf("decode device session: status %q is unsupported", wire.Status) } if status == session.StatusActive && strings.TrimSpace(wire.ClientPublicKey) == "" { return session.Record{}, errors.New("decode device session: active record missing client_public_key") } record := session.Record{ DeviceSessionID: wire.DeviceSessionID, UserID: wire.UserID, ClientPublicKey: wire.ClientPublicKey, Status: status, } if wire.RevokedAt != nil { ms := wire.RevokedAt.UnixMilli() record.RevokedAtMS = &ms } return record, nil } func decodeStrictJSON(payload []byte, target any) error { decoder := json.NewDecoder(bytes.NewReader(payload)) decoder.DisallowUnknownFields() if err := decoder.Decode(target); err != nil { return err } if err := decoder.Decode(&struct{}{}); err != io.EOF { if err == nil { 
return errors.New("unexpected trailing JSON input") } return err } return nil }