// Package challengestore implements ports.ChallengeStore with Redis-backed
// strict JSON challenge records.
package challengestore

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"reflect"
	"strings"
	"time"

	"galaxy/authsession/internal/domain/challenge"
	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/ports"

	"github.com/redis/go-redis/v9"
)

// expirationGracePeriod is added by redisTTL on top of the time remaining
// until a challenge's ExpiresAt when the Redis key TTL is computed.
const expirationGracePeriod = 5 * time.Minute

// defaultPreferredLanguage is the value normalizeStoredPreferredLanguage
// substitutes when a stored record carries a blank preferred_language.
const defaultPreferredLanguage = "en"

// Config configures one Redis-backed challenge store instance. The store does
// not own its Redis client; the runtime supplies a shared client constructed
// via `pkg/redisconn`.
type Config struct {
	// KeyPrefix is the namespace prefix applied to every challenge key.
	KeyPrefix string

	// OperationTimeout bounds each Redis round trip performed by the adapter.
	OperationTimeout time.Duration
}

// Store persists challenges as one strict JSON value per Redis key.
//
// client is the caller-supplied Redis client, keyPrefix is prepended to the
// encoded challenge ID to form each key, and operationTimeout is the deadline
// applied to the context of every store operation.
type Store struct {
	client           *redis.Client
	keyPrefix        string
	operationTimeout time.Duration
}

// redisRecord is the JSON shape written to and read from Redis for one
// challenge.
//
// CodeHashBase64 holds the standard base64 encoding of the challenge code
// hash. CreatedAt, ExpiresAt, LastAttemptAt, and ConfirmedAt hold timestamps
// formatted as UTC RFC3339Nano strings; decoding rejects any other spelling.
// PreferredLanguage may be omitted, in which case decoding substitutes
// defaultPreferredLanguage. ConfirmedSessionID, ConfirmedClientPublicKey, and
// ConfirmedAt describe the optional confirmation and must be either all
// present or all absent when a record is decoded.
//
// Field order and JSON tags define the stored payload and must stay stable.
type redisRecord struct {
	ChallengeID              string                  `json:"challenge_id"`
	Email                    string                  `json:"email"`
	CodeHashBase64           string                  `json:"code_hash_base64"`
	PreferredLanguage        string                  `json:"preferred_language,omitempty"`
	Status                   challenge.Status        `json:"status"`
	DeliveryState            challenge.DeliveryState `json:"delivery_state"`
	CreatedAt                string                  `json:"created_at"`
	ExpiresAt                string                  `json:"expires_at"`
	SendAttemptCount         int                     `json:"send_attempt_count"`
	ConfirmAttemptCount      int                     `json:"confirm_attempt_count"`
	LastAttemptAt            *string                 `json:"last_attempt_at,omitempty"`
	ConfirmedSessionID       string                  `json:"confirmed_session_id,omitempty"`
	ConfirmedClientPublicKey string                  `json:"confirmed_client_public_key,omitempty"`
	ConfirmedAt              *string                 `json:"confirmed_at,omitempty"`
}

// New constructs a Redis-backed challenge store that uses client and applies
// the namespace and timeout settings from cfg.
func New(client *redis.Client, cfg Config) (*Store, error) { if client == nil { return nil, errors.New("new redis challenge store: nil redis client") } if strings.TrimSpace(cfg.KeyPrefix) == "" { return nil, errors.New("new redis challenge store: redis key prefix must not be empty") } if cfg.OperationTimeout <= 0 { return nil, errors.New("new redis challenge store: operation timeout must be positive") } return &Store{ client: client, keyPrefix: cfg.KeyPrefix, operationTimeout: cfg.OperationTimeout, }, nil } // Get returns the stored challenge for challengeID. func (s *Store) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) { if err := challengeID.Validate(); err != nil { return challenge.Challenge{}, fmt.Errorf("get challenge from redis: %w", err) } operationCtx, cancel, err := s.operationContext(ctx, "get challenge from redis") if err != nil { return challenge.Challenge{}, err } defer cancel() payload, err := s.client.Get(operationCtx, s.lookupKey(challengeID)).Bytes() switch { case errors.Is(err, redis.Nil): return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, ports.ErrNotFound) case err != nil: return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err) } record, err := decodeChallengeRecord(challengeID, payload) if err != nil { return challenge.Challenge{}, fmt.Errorf("get challenge %q from redis: %w", challengeID, err) } return record, nil } // Create persists record as a new challenge. 
func (s *Store) Create(ctx context.Context, record challenge.Challenge) error { if err := record.Validate(); err != nil { return fmt.Errorf("create challenge in redis: %w", err) } payload, err := marshalChallengeRecord(record) if err != nil { return fmt.Errorf("create challenge in redis: %w", err) } operationCtx, cancel, err := s.operationContext(ctx, "create challenge in redis") if err != nil { return err } defer cancel() created, err := s.client.SetNX(operationCtx, s.lookupKey(record.ID), payload, redisTTL(record.ExpiresAt)).Result() if err != nil { return fmt.Errorf("create challenge %q in redis: %w", record.ID, err) } if !created { return fmt.Errorf("create challenge %q in redis: %w", record.ID, ports.ErrConflict) } return nil } // CompareAndSwap replaces previous with next when the currently stored // challenge matches previous exactly in canonical Redis representation. func (s *Store) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error { if err := ports.ValidateComparableChallenges(previous, next); err != nil { return fmt.Errorf("compare and swap challenge in redis: %w", err) } nextPayload, err := marshalChallengeRecord(next) if err != nil { return fmt.Errorf("compare and swap challenge in redis: %w", err) } operationCtx, cancel, err := s.operationContext(ctx, "compare and swap challenge in redis") if err != nil { return err } defer cancel() key := s.lookupKey(previous.ID) watchErr := s.client.Watch(operationCtx, func(tx *redis.Tx) error { payload, err := tx.Get(operationCtx, key).Bytes() switch { case errors.Is(err, redis.Nil): return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrNotFound) case err != nil: return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err) } current, err := decodeChallengeRecord(previous.ID, payload) if err != nil { return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err) } matches, err := 
equalStoredChallenges(current, previous) if err != nil { return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err) } if !matches { return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict) } _, err = tx.TxPipelined(operationCtx, func(pipe redis.Pipeliner) error { pipe.Set(operationCtx, key, nextPayload, redisTTL(next.ExpiresAt)) return nil }) if err != nil { return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, err) } return nil }, key) switch { case errors.Is(watchErr, redis.TxFailedErr): return fmt.Errorf("compare and swap challenge %q in redis: %w", previous.ID, ports.ErrConflict) case watchErr != nil: return watchErr default: return nil } } func (s *Store) operationContext(ctx context.Context, operation string) (context.Context, context.CancelFunc, error) { if s == nil || s.client == nil { return nil, nil, fmt.Errorf("%s: nil store", operation) } if ctx == nil { return nil, nil, fmt.Errorf("%s: nil context", operation) } operationCtx, cancel := context.WithTimeout(ctx, s.operationTimeout) return operationCtx, cancel, nil } func (s *Store) lookupKey(challengeID common.ChallengeID) string { return s.keyPrefix + encodeKeyComponent(challengeID.String()) } func encodeKeyComponent(value string) string { return base64.RawURLEncoding.EncodeToString([]byte(value)) } func marshalChallengeRecord(record challenge.Challenge) ([]byte, error) { stored, err := redisRecordFromChallenge(record) if err != nil { return nil, err } payload, err := json.Marshal(stored) if err != nil { return nil, fmt.Errorf("encode redis challenge record: %w", err) } return payload, nil } func redisRecordFromChallenge(record challenge.Challenge) (redisRecord, error) { if err := record.Validate(); err != nil { return redisRecord{}, fmt.Errorf("encode redis challenge record: %w", err) } stored := redisRecord{ ChallengeID: record.ID.String(), Email: record.Email.String(), CodeHashBase64: 
base64.StdEncoding.EncodeToString(record.CodeHash), PreferredLanguage: record.PreferredLanguage, Status: record.Status, DeliveryState: record.DeliveryState, CreatedAt: formatTimestamp(record.CreatedAt), ExpiresAt: formatTimestamp(record.ExpiresAt), SendAttemptCount: record.Attempts.Send, ConfirmAttemptCount: record.Attempts.Confirm, LastAttemptAt: formatOptionalTimestamp(record.Abuse.LastAttemptAt), } if record.Confirmation != nil { stored.ConfirmedSessionID = record.Confirmation.SessionID.String() stored.ConfirmedClientPublicKey = record.Confirmation.ClientPublicKey.String() stored.ConfirmedAt = formatOptionalTimestamp(&record.Confirmation.ConfirmedAt) } return stored, nil } func decodeChallengeRecord(expectedChallengeID common.ChallengeID, payload []byte) (challenge.Challenge, error) { decoder := json.NewDecoder(bytes.NewReader(payload)) decoder.DisallowUnknownFields() var stored redisRecord if err := decoder.Decode(&stored); err != nil { return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err) } if err := decoder.Decode(&struct{}{}); err != io.EOF { if err == nil { return challenge.Challenge{}, errors.New("decode redis challenge record: unexpected trailing JSON input") } return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err) } record, err := challengeFromRedisRecord(stored) if err != nil { return challenge.Challenge{}, err } if record.ID != expectedChallengeID { return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: challenge_id %q does not match requested %q", record.ID, expectedChallengeID) } return record, nil } func challengeFromRedisRecord(stored redisRecord) (challenge.Challenge, error) { createdAt, err := parseTimestamp("created_at", stored.CreatedAt) if err != nil { return challenge.Challenge{}, err } expiresAt, err := parseTimestamp("expires_at", stored.ExpiresAt) if err != nil { return challenge.Challenge{}, err } lastAttemptAt, err := parseOptionalTimestamp("last_attempt_at", 
stored.LastAttemptAt) if err != nil { return challenge.Challenge{}, err } codeHash, err := base64.StdEncoding.Strict().DecodeString(stored.CodeHashBase64) if err != nil { return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: code_hash_base64: %w", err) } record := challenge.Challenge{ ID: common.ChallengeID(stored.ChallengeID), Email: common.Email(stored.Email), CodeHash: codeHash, PreferredLanguage: normalizeStoredPreferredLanguage(stored.PreferredLanguage), Status: stored.Status, DeliveryState: stored.DeliveryState, CreatedAt: createdAt, ExpiresAt: expiresAt, Attempts: challenge.AttemptCounters{ Send: stored.SendAttemptCount, Confirm: stored.ConfirmAttemptCount, }, Abuse: challenge.AbuseMetadata{ LastAttemptAt: lastAttemptAt, }, } confirmation, err := parseConfirmation(stored) if err != nil { return challenge.Challenge{}, err } record.Confirmation = confirmation if err := record.Validate(); err != nil { return challenge.Challenge{}, fmt.Errorf("decode redis challenge record: %w", err) } return record, nil } func parseConfirmation(stored redisRecord) (*challenge.Confirmation, error) { hasSessionID := strings.TrimSpace(stored.ConfirmedSessionID) != "" hasClientPublicKey := strings.TrimSpace(stored.ConfirmedClientPublicKey) != "" hasConfirmedAt := stored.ConfirmedAt != nil if !hasSessionID && !hasClientPublicKey && !hasConfirmedAt { return nil, nil } if !hasSessionID || !hasClientPublicKey || !hasConfirmedAt { return nil, errors.New("decode redis challenge record: confirmation metadata must be either fully present or fully absent") } confirmedAt, err := parseTimestamp("confirmed_at", *stored.ConfirmedAt) if err != nil { return nil, err } rawClientPublicKey, err := base64.StdEncoding.Strict().DecodeString(stored.ConfirmedClientPublicKey) if err != nil { return nil, fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err) } clientPublicKey, err := common.NewClientPublicKey(rawClientPublicKey) if err != nil { return nil, 
fmt.Errorf("decode redis challenge record: confirmed_client_public_key: %w", err) } return &challenge.Confirmation{ SessionID: common.DeviceSessionID(stored.ConfirmedSessionID), ClientPublicKey: clientPublicKey, ConfirmedAt: confirmedAt, }, nil } func parseOptionalTimestamp(fieldName string, value *string) (*time.Time, error) { if value == nil { return nil, nil } parsed, err := parseTimestamp(fieldName, *value) if err != nil { return nil, err } return &parsed, nil } func parseTimestamp(fieldName string, value string) (time.Time, error) { if strings.TrimSpace(value) == "" { return time.Time{}, fmt.Errorf("decode redis challenge record: %s must not be empty", fieldName) } parsed, err := time.Parse(time.RFC3339Nano, value) if err != nil { return time.Time{}, fmt.Errorf("decode redis challenge record: %s: %w", fieldName, err) } canonical := parsed.UTC().Format(time.RFC3339Nano) if value != canonical { return time.Time{}, fmt.Errorf("decode redis challenge record: %s must be a canonical UTC RFC3339Nano timestamp", fieldName) } return parsed.UTC(), nil } func formatTimestamp(value time.Time) string { return value.UTC().Format(time.RFC3339Nano) } func formatOptionalTimestamp(value *time.Time) *string { if value == nil { return nil } formatted := formatTimestamp(*value) return &formatted } func normalizeStoredPreferredLanguage(value string) string { preferredLanguage := strings.TrimSpace(value) if preferredLanguage == "" { return defaultPreferredLanguage } return preferredLanguage } func redisTTL(expiresAt time.Time) time.Duration { ttl := time.Until(expiresAt.UTC()) if ttl < 0 { ttl = 0 } return ttl + expirationGracePeriod } func equalStoredChallenges(left challenge.Challenge, right challenge.Challenge) (bool, error) { leftRecord, err := redisRecordFromChallenge(left) if err != nil { return false, err } rightRecord, err := redisRecordFromChallenge(right) if err != nil { return false, err } return reflect.DeepEqual(leftRecord, rightRecord), nil } var _ ports.ChallengeStore 
= (*Store)(nil)