123 lines
3.0 KiB
Go
123 lines
3.0 KiB
Go
package testkit
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"sync"
|
|
|
|
"galaxy/authsession/internal/domain/challenge"
|
|
"galaxy/authsession/internal/domain/common"
|
|
"galaxy/authsession/internal/ports"
|
|
)
|
|
|
|
// InMemoryChallengeStore is a deterministic map-backed ChallengeStore double
|
|
// suitable for service tests.
|
|
type InMemoryChallengeStore struct {
|
|
mu sync.Mutex
|
|
records map[common.ChallengeID]challenge.Challenge
|
|
}
|
|
|
|
// Get returns the stored challenge for challengeID.
|
|
func (s *InMemoryChallengeStore) Get(ctx context.Context, challengeID common.ChallengeID) (challenge.Challenge, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return challenge.Challenge{}, err
|
|
}
|
|
if err := challengeID.Validate(); err != nil {
|
|
return challenge.Challenge{}, fmt.Errorf("get challenge: %w", err)
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
record, ok := s.records[challengeID]
|
|
if !ok {
|
|
return challenge.Challenge{}, fmt.Errorf("get challenge %q: %w", challengeID, ports.ErrNotFound)
|
|
}
|
|
|
|
cloned, err := cloneChallenge(record)
|
|
if err != nil {
|
|
return challenge.Challenge{}, err
|
|
}
|
|
|
|
return cloned, nil
|
|
}
|
|
|
|
// Create stores record as a new challenge.
|
|
func (s *InMemoryChallengeStore) Create(ctx context.Context, record challenge.Challenge) error {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("create challenge: %w", err)
|
|
}
|
|
|
|
cloned, err := cloneChallenge(record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.records == nil {
|
|
s.records = make(map[common.ChallengeID]challenge.Challenge)
|
|
}
|
|
if _, exists := s.records[record.ID]; exists {
|
|
return fmt.Errorf("create challenge %q: %w", record.ID, ports.ErrConflict)
|
|
}
|
|
|
|
s.records[record.ID] = cloned
|
|
return nil
|
|
}
|
|
|
|
// CompareAndSwap replaces previous with next when the currently stored
|
|
// challenge matches previous exactly.
|
|
func (s *InMemoryChallengeStore) CompareAndSwap(ctx context.Context, previous challenge.Challenge, next challenge.Challenge) error {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
if err := ports.ValidateComparableChallenges(previous, next); err != nil {
|
|
return fmt.Errorf("compare and swap challenge: %w", err)
|
|
}
|
|
|
|
clonedPrevious, err := cloneChallenge(previous)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
clonedNext, err := cloneChallenge(next)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
current, ok := s.records[previous.ID]
|
|
if !ok {
|
|
return fmt.Errorf("compare and swap challenge %q: %w", previous.ID, ports.ErrNotFound)
|
|
}
|
|
if !reflect.DeepEqual(current, clonedPrevious) {
|
|
return fmt.Errorf("compare and swap challenge %q: %w", previous.ID, ports.ErrConflict)
|
|
}
|
|
|
|
s.records[next.ID] = clonedNext
|
|
return nil
|
|
}
|
|
|
|
// Compile-time check that *InMemoryChallengeStore satisfies ports.ChallengeStore;
// the build fails if the method set drifts from the interface.
var _ ports.ChallengeStore = (*InMemoryChallengeStore)(nil)
|
|
|
|
func mustGetChallenge(store *InMemoryChallengeStore, challengeID common.ChallengeID) challenge.Challenge {
|
|
record, err := store.Get(context.Background(), challengeID)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return record
|
|
}
|
|
|
|
func isNotFound(err error) bool {
|
|
return errors.Is(err, ports.ErrNotFound)
|
|
}
|