package testkit import ( "context" "fmt" "slices" "sync" "galaxy/authsession/internal/domain/common" "galaxy/authsession/internal/domain/devicesession" "galaxy/authsession/internal/ports" ) // InMemorySessionStore is a deterministic map-backed SessionStore double // suitable for service tests. type InMemorySessionStore struct { mu sync.Mutex records map[common.DeviceSessionID]devicesession.Session } // Get returns the stored device session for deviceSessionID. func (s *InMemorySessionStore) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) { if err := ctx.Err(); err != nil { return devicesession.Session{}, err } if err := deviceSessionID.Validate(); err != nil { return devicesession.Session{}, fmt.Errorf("get session: %w", err) } s.mu.Lock() defer s.mu.Unlock() record, ok := s.records[deviceSessionID] if !ok { return devicesession.Session{}, fmt.Errorf("get session %q: %w", deviceSessionID, ports.ErrNotFound) } cloned, err := cloneSession(record) if err != nil { return devicesession.Session{}, err } return cloned, nil } // ListByUserID returns every stored session for userID in newest-first order. func (s *InMemorySessionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) { if err := ctx.Err(); err != nil { return nil, err } if err := userID.Validate(); err != nil { return nil, fmt.Errorf("list sessions by user id: %w", err) } s.mu.Lock() defer s.mu.Unlock() var records []devicesession.Session for _, record := range s.records { if record.UserID == userID { cloned, err := cloneSession(record) if err != nil { return nil, err } records = append(records, cloned) } } sortSessionsNewestFirst(records) return records, nil } // CountActiveByUserID returns the number of active sessions currently stored // for userID. 
func (s *InMemorySessionStore) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) { if err := ctx.Err(); err != nil { return 0, err } if err := userID.Validate(); err != nil { return 0, fmt.Errorf("count active sessions by user id: %w", err) } s.mu.Lock() defer s.mu.Unlock() count := 0 for _, record := range s.records { if record.UserID == userID && record.Status == devicesession.StatusActive { count++ } } return count, nil } // Create stores record as a new device session. func (s *InMemorySessionStore) Create(ctx context.Context, record devicesession.Session) error { if err := ctx.Err(); err != nil { return err } if err := record.Validate(); err != nil { return fmt.Errorf("create session: %w", err) } cloned, err := cloneSession(record) if err != nil { return err } s.mu.Lock() defer s.mu.Unlock() if s.records == nil { s.records = make(map[common.DeviceSessionID]devicesession.Session) } if _, exists := s.records[record.ID]; exists { return fmt.Errorf("create session %q: %w", record.ID, ports.ErrConflict) } s.records[record.ID] = cloned return nil } // Revoke stores a revoked view of one target session. 
func (s *InMemorySessionStore) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
	if err := ctx.Err(); err != nil {
		return ports.RevokeSessionResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.RevokeSessionResult{}, fmt.Errorf("revoke session: %w", err)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	record, ok := s.records[input.DeviceSessionID]
	if !ok {
		return ports.RevokeSessionResult{}, fmt.Errorf("revoke session %q: %w", input.DeviceSessionID, ports.ErrNotFound)
	}

	if record.Status == devicesession.StatusRevoked {
		cloned, err := cloneSession(record)
		if err != nil {
			return ports.RevokeSessionResult{}, err
		}
		result := ports.RevokeSessionResult{
			Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
			Session: cloned,
		}
		if err := result.Validate(); err != nil {
			return ports.RevokeSessionResult{}, err
		}
		return result, nil
	}

	// record is a copy of the map value (Session is a struct), so these
	// writes do not reach the store until the commit at the end.
	record.Status = devicesession.StatusRevoked
	revocation := input.Revocation
	record.Revocation = &revocation

	stored, err := cloneSession(record)
	if err != nil {
		return ports.RevokeSessionResult{}, err
	}
	// The caller gets its own clone instead of the stored value, mirroring
	// Get, so it does not share pointer fields such as Revocation with the
	// store (assuming cloneSession deep-copies them — NOTE(review): confirm).
	returned, err := cloneSession(record)
	if err != nil {
		return ports.RevokeSessionResult{}, err
	}

	result := ports.RevokeSessionResult{
		Outcome: ports.RevokeSessionOutcomeRevoked,
		Session: returned,
	}
	if err := result.Validate(); err != nil {
		return ports.RevokeSessionResult{}, err
	}

	// Commit only after every fallible step succeeded, so an error return
	// leaves the stored session untouched.
	s.records[input.DeviceSessionID] = stored

	return result, nil
}

// RevokeAllByUserID marks every currently active session owned by
// input.UserID as revoked with input.Revocation and reports the affected
// sessions, ordered by sortSessionsNewestFirst.
func (s *InMemorySessionStore) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
	if err := ctx.Err(); err != nil {
		return ports.RevokeUserSessionsResult{}, err
	}
	if err := input.Validate(); err != nil {
		return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions: %w", err)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Every revoked record is staged first and written back only after all
	// fallible steps succeed, so an error return leaves the store unchanged
	// instead of partially revoked.
	var staged []devicesession.Session
	var affected []devicesession.Session
	for _, record := range s.records {
		if record.UserID != input.UserID || record.Status != devicesession.StatusActive {
			continue
		}

		// record is a copy of the map value; edits stay local until commit.
		record.Status = devicesession.StatusRevoked
		revocation := input.Revocation
		record.Revocation = &revocation

		stored, err := cloneSession(record)
		if err != nil {
			return ports.RevokeUserSessionsResult{}, err
		}
		// A separate clone for the caller keeps the returned sessions from
		// sharing pointer fields such as Revocation with stored records
		// (assuming cloneSession deep-copies them — NOTE(review): confirm).
		returned, err := cloneSession(record)
		if err != nil {
			return ports.RevokeUserSessionsResult{}, err
		}

		staged = append(staged, stored)
		affected = append(affected, returned)
	}

	sortSessionsNewestFirst(affected)

	outcome := ports.RevokeUserSessionsOutcomeNoActiveSessions
	if len(affected) > 0 {
		outcome = ports.RevokeUserSessionsOutcomeRevoked
	}

	result := ports.RevokeUserSessionsResult{
		Outcome:  outcome,
		UserID:   input.UserID,
		Sessions: slices.Clone(affected),
	}
	if err := result.Validate(); err != nil {
		return ports.RevokeUserSessionsResult{}, err
	}

	// Records are keyed by their own ID (Create stores s.records[record.ID]
	// and the revocation paths never change ID), so ID is the map key here.
	for _, stored := range staged {
		s.records[stored.ID] = stored
	}

	return result, nil
}

var _ ports.SessionStore = (*InMemorySessionStore)(nil)