218 lines
6.8 KiB
Go
218 lines
6.8 KiB
Go
package backendclient
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"galaxy/gateway/internal/session"
|
|
)
|
|
|
|
const (
	// HeaderUserID is the trusted gateway → backend identity header: the
	// gateway stamps the already-authenticated user id into it on backend
	// calls made on behalf of a user (the `/api/v1/user/*` routes).
	HeaderUserID = "X-User-Id"
)
|
|
|
|
// errSessionNotFound is the public error returned by LookupSession when
|
|
// backend reports HTTP 404 for a device session id. It wraps
|
|
// session.ErrNotFound so callers can keep using the existing typed
|
|
// equality check at the gateway hot path.
|
|
func errSessionNotFound() error {
|
|
return fmt.Errorf("backendclient: lookup session: %w", session.ErrNotFound)
|
|
}
|
|
|
|
// RESTClient owns the gateway's HTTP conversation with backend.
//
// Instances are built by NewRESTClient. All methods are safe for concurrent
// use: the client's state is a base URL and an *http.Client, neither of
// which is reassigned after construction.
type RESTClient struct {
	// baseURL is the backend origin (trailing slashes trimmed) that request
	// paths are appended to.
	baseURL string
	// httpClient performs every request; methods treat a nil value as an
	// unusable client.
	httpClient *http.Client
}
|
|
|
|
// NewRESTClient constructs a RESTClient targeting the backend HTTP
|
|
// listener configured in cfg.
|
|
func NewRESTClient(cfg Config) (*RESTClient, error) {
|
|
transport, ok := http.DefaultTransport.(*http.Transport)
|
|
if !ok {
|
|
return nil, errors.New("backendclient: default HTTP transport is not *http.Transport")
|
|
}
|
|
parsed, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.HTTPBaseURL), "/"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("backendclient: parse HTTPBaseURL: %w", err)
|
|
}
|
|
if parsed.Scheme == "" || parsed.Host == "" {
|
|
return nil, errors.New("backendclient: HTTPBaseURL must be absolute")
|
|
}
|
|
return &RESTClient{
|
|
baseURL: parsed.String(),
|
|
httpClient: &http.Client{
|
|
Transport: transport.Clone(),
|
|
Timeout: cfg.HTTPTimeout,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Close releases idle HTTP connections owned by the client transport.
|
|
func (c *RESTClient) Close() error {
|
|
if c == nil || c.httpClient == nil {
|
|
return nil
|
|
}
|
|
type idleCloser interface {
|
|
CloseIdleConnections()
|
|
}
|
|
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
|
|
transport.CloseIdleConnections()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LookupSession resolves deviceSessionID against
|
|
// `GET /api/v1/internal/sessions/{device_session_id}`.
|
|
// Returns session.ErrNotFound (wrapped) when backend reports 404.
|
|
func (c *RESTClient) LookupSession(ctx context.Context, deviceSessionID string) (session.Record, error) {
|
|
if c == nil || c.httpClient == nil {
|
|
return session.Record{}, errors.New("backendclient: nil REST client")
|
|
}
|
|
if strings.TrimSpace(deviceSessionID) == "" {
|
|
return session.Record{}, errors.New("backendclient: lookup session: device_session_id must not be empty")
|
|
}
|
|
|
|
target := c.baseURL + "/api/v1/internal/sessions/" + url.PathEscape(deviceSessionID)
|
|
body, status, err := c.do(ctx, http.MethodGet, target, "", nil)
|
|
if err != nil {
|
|
return session.Record{}, fmt.Errorf("backendclient: lookup session: %w", err)
|
|
}
|
|
|
|
switch {
|
|
case status == http.StatusOK:
|
|
return decodeDeviceSession(deviceSessionID, body)
|
|
case status == http.StatusNotFound:
|
|
return session.Record{}, errSessionNotFound()
|
|
default:
|
|
return session.Record{}, fmt.Errorf("backendclient: lookup session: unexpected HTTP status %d", status)
|
|
}
|
|
}
|
|
|
|
// do executes a JSON request and reads the response body. userID, when
|
|
// non-empty, is sent as the X-User-Id header (required for `/api/v1/user/*`).
|
|
func (c *RESTClient) do(ctx context.Context, method, target, userID string, body any) ([]byte, int, error) {
|
|
return c.doWithHeaders(ctx, method, target, userID, body, nil)
|
|
}
|
|
|
|
// doWithHeaders is the shared transport entry point. extraHeaders are
|
|
// applied verbatim after Content-Type/X-User-Id; an empty value drops
|
|
// the header so callers can pass optional language tags etc.
|
|
func (c *RESTClient) doWithHeaders(ctx context.Context, method, target, userID string, body any, extraHeaders map[string]string) ([]byte, int, error) {
|
|
if c == nil || c.httpClient == nil {
|
|
return nil, 0, errors.New("nil REST client")
|
|
}
|
|
if ctx == nil {
|
|
return nil, 0, errors.New("nil context")
|
|
}
|
|
|
|
var reader io.Reader
|
|
if body != nil {
|
|
buf, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("marshal request body: %w", err)
|
|
}
|
|
reader = bytes.NewReader(buf)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, method, target, reader)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("build request: %w", err)
|
|
}
|
|
if body != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
if userID != "" {
|
|
req.Header.Set(HeaderUserID, userID)
|
|
}
|
|
for key, value := range extraHeaders {
|
|
if strings.TrimSpace(value) == "" {
|
|
continue
|
|
}
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
defer resp.Body.Close()
|
|
payload, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, resp.StatusCode, fmt.Errorf("read response body: %w", err)
|
|
}
|
|
return payload, resp.StatusCode, nil
|
|
}
|
|
|
|
// deviceSessionWire mirrors backend openapi `DeviceSession`.
//
// decodeDeviceSession decodes it with unknown fields disallowed, so the set
// of JSON tags below must stay in step with the backend schema: an
// unexpected field in a response makes the lookup fail.
type deviceSessionWire struct {
	DeviceSessionID string `json:"device_session_id"`
	UserID          string `json:"user_id"`
	Status          string `json:"status"`
	// ClientPublicKey is optional on the wire; decodeDeviceSession requires
	// it only for active sessions.
	ClientPublicKey string `json:"client_public_key,omitempty"`
	// Timestamps use encoding/json's time.Time format (RFC 3339). The pointer
	// fields stay nil when the key is absent or null.
	CreatedAt  time.Time  `json:"created_at"`
	RevokedAt  *time.Time `json:"revoked_at,omitempty"`
	LastSeenAt *time.Time `json:"last_seen_at,omitempty"`
}
|
|
|
|
func decodeDeviceSession(expectedDeviceSessionID string, payload []byte) (session.Record, error) {
|
|
var wire deviceSessionWire
|
|
if err := decodeStrictJSON(payload, &wire); err != nil {
|
|
return session.Record{}, fmt.Errorf("decode device session: %w", err)
|
|
}
|
|
|
|
if strings.TrimSpace(wire.DeviceSessionID) == "" {
|
|
return session.Record{}, errors.New("decode device session: device_session_id must not be empty")
|
|
}
|
|
if wire.DeviceSessionID != expectedDeviceSessionID {
|
|
return session.Record{}, fmt.Errorf("decode device session: device_session_id %q does not match requested %q", wire.DeviceSessionID, expectedDeviceSessionID)
|
|
}
|
|
if strings.TrimSpace(wire.UserID) == "" {
|
|
return session.Record{}, errors.New("decode device session: user_id must not be empty")
|
|
}
|
|
|
|
status := session.Status(strings.TrimSpace(wire.Status))
|
|
if !status.IsKnown() {
|
|
return session.Record{}, fmt.Errorf("decode device session: status %q is unsupported", wire.Status)
|
|
}
|
|
if status == session.StatusActive && strings.TrimSpace(wire.ClientPublicKey) == "" {
|
|
return session.Record{}, errors.New("decode device session: active record missing client_public_key")
|
|
}
|
|
|
|
record := session.Record{
|
|
DeviceSessionID: wire.DeviceSessionID,
|
|
UserID: wire.UserID,
|
|
ClientPublicKey: wire.ClientPublicKey,
|
|
Status: status,
|
|
}
|
|
if wire.RevokedAt != nil {
|
|
ms := wire.RevokedAt.UnixMilli()
|
|
record.RevokedAtMS = &ms
|
|
}
|
|
return record, nil
|
|
}
|
|
|
|
// decodeStrictJSON decodes exactly one JSON value from payload into target.
//
// Object keys that target does not declare are rejected. Trailing
// whitespace is allowed. A second JSON value after the first yields
// "unexpected trailing JSON input", and trailing garbage yields the
// decoder's syntax error. An empty payload yields the decoder's io.EOF
// unchanged.
func decodeStrictJSON(payload []byte, target any) error {
	dec := json.NewDecoder(bytes.NewReader(payload))
	dec.DisallowUnknownFields()
	if err := dec.Decode(target); err != nil {
		return err
	}

	// Probe for more input: only a clean end of stream is acceptable.
	switch err := dec.Decode(&struct{}{}); {
	case err == nil:
		return errors.New("unexpected trailing JSON input")
	case errors.Is(err, io.EOF):
		return nil
	default:
		return err
	}
}
|