Files
2026-05-07 00:58:53 +03:00

218 lines
6.8 KiB
Go

package backendclient
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"galaxy/gateway/internal/session"
)
// HeaderUserID names the request header through which the gateway hands
// the authenticated caller's user id to backend; backend treats it as a
// trusted identity assertion from the gateway.
const (
	HeaderUserID = "X-User-Id"
)
// errSessionNotFound is the public error returned by LookupSession when
// backend reports HTTP 404 for a device session id. It wraps
// session.ErrNotFound so callers can keep using the existing typed
// equality check at the gateway hot path.
func errSessionNotFound() error {
return fmt.Errorf("backendclient: lookup session: %w", session.ErrNotFound)
}
// RESTClient owns the gateway's HTTP conversation with backend.
//
// All methods are safe for concurrent use.
//
// Construct it with NewRESTClient; a nil *RESTClient or one without an
// HTTP client is rejected by the request methods rather than panicking.
type RESTClient struct {
// baseURL is the backend base URL as normalized by NewRESTClient
// (whitespace-trimmed, no trailing slash). Request paths are appended
// to it by plain string concatenation.
baseURL string
// httpClient is used for every request; NewRESTClient equips it with a
// clone of http.DefaultTransport and the configured timeout.
httpClient *http.Client
}
// NewRESTClient constructs a RESTClient targeting the backend HTTP
// listener configured in cfg.
//
// cfg.HTTPBaseURL must be an absolute URL (scheme and host). Surrounding
// whitespace and trailing slashes are ignored. Request paths are appended
// to the base URL by string concatenation, so a base URL that carries a
// query string (including a bare "?") or a fragment is rejected: either
// would swallow the appended path and send requests to the wrong target.
//
// The client uses its own clone of http.DefaultTransport, so its
// connection pool (and Close) is independent of the default transport.
// cfg.HTTPTimeout becomes http.Client.Timeout; per net/http, zero means no
// client-level timeout.
func NewRESTClient(cfg Config) (*RESTClient, error) {
	transport, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		return nil, errors.New("backendclient: default HTTP transport is not *http.Transport")
	}
	parsed, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.HTTPBaseURL), "/"))
	if err != nil {
		return nil, fmt.Errorf("backendclient: parse HTTPBaseURL: %w", err)
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return nil, errors.New("backendclient: HTTPBaseURL must be absolute")
	}
	// String concatenation of paths only works on a plain base URL.
	if parsed.RawQuery != "" || parsed.ForceQuery || parsed.Fragment != "" {
		return nil, errors.New("backendclient: HTTPBaseURL must not contain a query or fragment")
	}
	return &RESTClient{
		baseURL: parsed.String(),
		httpClient: &http.Client{
			Transport: transport.Clone(),
			Timeout:   cfg.HTTPTimeout,
		},
	}, nil
}
// Close drops any idle keep-alive connections held by the client's
// transport, provided the transport supports CloseIdleConnections.
// It tolerates a nil receiver or a client without an HTTP client, and it
// always returns nil. The error result lets RESTClient satisfy io.Closer.
func (c *RESTClient) Close() error {
	if c != nil && c.httpClient != nil {
		if closer, ok := c.httpClient.Transport.(interface{ CloseIdleConnections() }); ok {
			closer.CloseIdleConnections()
		}
	}
	return nil
}
// LookupSession resolves deviceSessionID against
// `GET /api/v1/internal/sessions/{device_session_id}`.
//
// A blank deviceSessionID is rejected before any request is made. The id
// is path-escaped into the URL. Outcomes:
//   - HTTP 200: the body is decoded and validated by decodeDeviceSession;
//   - HTTP 404: an error wrapping session.ErrNotFound (match it with
//     errors.Is);
//   - any other status: an error naming the unexpected status code.
//
// Every error returned is prefixed with "backendclient: lookup session:".
// Transport and decode causes stay reachable through errors.Is/errors.As.
func (c *RESTClient) LookupSession(ctx context.Context, deviceSessionID string) (session.Record, error) {
	if c == nil || c.httpClient == nil {
		return session.Record{}, errors.New("backendclient: nil REST client")
	}
	if strings.TrimSpace(deviceSessionID) == "" {
		return session.Record{}, errors.New("backendclient: lookup session: device_session_id must not be empty")
	}
	target := c.baseURL + "/api/v1/internal/sessions/" + url.PathEscape(deviceSessionID)
	body, status, err := c.do(ctx, http.MethodGet, target, "", nil)
	if err != nil {
		return session.Record{}, fmt.Errorf("backendclient: lookup session: %w", err)
	}
	switch status {
	case http.StatusOK:
		record, err := decodeDeviceSession(deviceSessionID, body)
		if err != nil {
			// Keep decode failures consistent with the other error paths.
			return session.Record{}, fmt.Errorf("backendclient: lookup session: %w", err)
		}
		return record, nil
	case http.StatusNotFound:
		return session.Record{}, errSessionNotFound()
	default:
		return session.Record{}, fmt.Errorf("backendclient: lookup session: unexpected HTTP status %d", status)
	}
}
// do executes a JSON request and reads the response body. userID, when
// non-empty, is sent as the X-User-Id header (required for `/api/v1/user/*`).
func (c *RESTClient) do(ctx context.Context, method, target, userID string, body any) ([]byte, int, error) {
return c.doWithHeaders(ctx, method, target, userID, body, nil)
}
// doWithHeaders is the shared transport entry point. It builds the
// request, sends it with the client's http.Client, and returns the full
// response body together with the HTTP status code. Non-2xx statuses are
// not treated as errors here; callers inspect the status themselves.
//
// When body is non-nil it is JSON-encoded and sent with
// "Content-Type: application/json". A non-empty userID is sent as the
// X-User-Id header. extraHeaders are then applied with Header.Set, so they
// replace any value set above for the same key. Entries whose value is
// empty or whitespace-only are skipped rather than removed, which lets
// callers pass optional values such as language tags unconditionally.
//
// Errors carry no package prefix; callers such as LookupSession add their
// own operation context. Transport failures are wrapped with %w, so causes
// like context cancellation remain matchable with errors.Is.
func (c *RESTClient) doWithHeaders(ctx context.Context, method, target, userID string, body any, extraHeaders map[string]string) ([]byte, int, error) {
	if c == nil || c.httpClient == nil {
		return nil, 0, errors.New("nil REST client")
	}
	if ctx == nil {
		return nil, 0, errors.New("nil context")
	}
	var reader io.Reader
	if body != nil {
		buf, err := json.Marshal(body)
		if err != nil {
			return nil, 0, fmt.Errorf("marshal request body: %w", err)
		}
		reader = bytes.NewReader(buf)
	}
	req, err := http.NewRequestWithContext(ctx, method, target, reader)
	if err != nil {
		return nil, 0, fmt.Errorf("build request: %w", err)
	}
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}
	if userID != "" {
		req.Header.Set(HeaderUserID, userID)
	}
	for key, value := range extraHeaders {
		if strings.TrimSpace(value) == "" {
			continue
		}
		req.Header.Set(key, value)
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		// Give transport failures the same kind of context as the other
		// stages while keeping the underlying cause in the chain.
		return nil, 0, fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, resp.StatusCode, fmt.Errorf("read response body: %w", err)
	}
	return payload, resp.StatusCode, nil
}
// deviceSessionWire mirrors backend openapi `DeviceSession`.
//
// decodeDeviceSession reads it through decodeStrictJSON, which rejects
// unknown object keys. Every property the backend may send therefore has
// to be declared here even when the gateway ignores it. CreatedAt and
// LastSeenAt, for example, are parsed but never copied into session.Record.
// The omitempty options only influence JSON encoding; they have no effect
// when decoding.
type deviceSessionWire struct {
DeviceSessionID string `json:"device_session_id"`
UserID string `json:"user_id"`
// Status is trimmed and validated against session.Status by the decoder.
Status string `json:"status"`
// ClientPublicKey must be non-blank when Status is active.
ClientPublicKey string `json:"client_public_key,omitempty"`
CreatedAt time.Time `json:"created_at"`
// Pointer timestamps stay nil when the key is absent or JSON null.
RevokedAt *time.Time `json:"revoked_at,omitempty"`
LastSeenAt *time.Time `json:"last_seen_at,omitempty"`
}
// decodeDeviceSession parses a backend DeviceSession payload into a
// session.Record and checks it against the id that was requested.
//
// Checks run in this order, and the first failure is returned:
//   - payload is one JSON value decoded strictly (unknown keys rejected);
//   - device_session_id is non-blank and equals expectedDeviceSessionID;
//   - user_id is non-blank;
//   - status, after trimming whitespace, is a known session.Status;
//   - an active session carries a non-blank client_public_key.
//
// revoked_at, when present, is stored as Unix milliseconds. created_at and
// last_seen_at are accepted but not copied into the record.
func decodeDeviceSession(expectedDeviceSessionID string, payload []byte) (session.Record, error) {
	var wire deviceSessionWire
	if err := decodeStrictJSON(payload, &wire); err != nil {
		return session.Record{}, fmt.Errorf("decode device session: %w", err)
	}

	switch {
	case strings.TrimSpace(wire.DeviceSessionID) == "":
		return session.Record{}, errors.New("decode device session: device_session_id must not be empty")
	case wire.DeviceSessionID != expectedDeviceSessionID:
		return session.Record{}, fmt.Errorf("decode device session: device_session_id %q does not match requested %q", wire.DeviceSessionID, expectedDeviceSessionID)
	case strings.TrimSpace(wire.UserID) == "":
		return session.Record{}, errors.New("decode device session: user_id must not be empty")
	}

	status := session.Status(strings.TrimSpace(wire.Status))
	switch {
	case !status.IsKnown():
		return session.Record{}, fmt.Errorf("decode device session: status %q is unsupported", wire.Status)
	case status == session.StatusActive && strings.TrimSpace(wire.ClientPublicKey) == "":
		return session.Record{}, errors.New("decode device session: active record missing client_public_key")
	}

	record := session.Record{
		DeviceSessionID: wire.DeviceSessionID,
		UserID:          wire.UserID,
		ClientPublicKey: wire.ClientPublicKey,
		Status:          status,
	}
	if revokedAt := wire.RevokedAt; revokedAt != nil {
		revokedMS := revokedAt.UnixMilli()
		record.RevokedAtMS = &revokedMS
	}
	return record, nil
}
// decodeStrictJSON decodes payload into target, which must be a pointer
// accepted by json.Decoder.Decode. Decoding is strict in two ways:
//   - object keys that match no field of target are rejected
//     (DisallowUnknownFields);
//   - payload must hold exactly one JSON value, optionally surrounded by
//     whitespace. Anything after that value is reported as
//     "unexpected trailing JSON input", with the syntax error wrapped when
//     the trailing bytes are not valid JSON.
//
// Errors from decoding the value itself are returned unwrapped; for
// example, an empty payload yields io.EOF.
func decodeStrictJSON(payload []byte, target any) error {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()
	if err := decoder.Decode(target); err != nil {
		return err
	}
	// Probe for leftovers with Token rather than a second Decode. A second
	// Decode into an empty struct would report e.g. `{"x":1}` as an
	// unknown-field error instead of as trailing input.
	switch _, err := decoder.Token(); {
	case errors.Is(err, io.EOF):
		return nil
	case err != nil:
		return fmt.Errorf("unexpected trailing JSON input: %w", err)
	default:
		return errors.New("unexpected trailing JSON input")
	}
}