230 lines
5.8 KiB
Go
230 lines
5.8 KiB
Go
package testkit
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"slices"
|
|
"sync"
|
|
|
|
"galaxy/authsession/internal/domain/common"
|
|
"galaxy/authsession/internal/domain/devicesession"
|
|
"galaxy/authsession/internal/ports"
|
|
)
|
|
|
|
// InMemorySessionStore is a deterministic map-backed SessionStore double
|
|
// suitable for service tests.
|
|
type InMemorySessionStore struct {
|
|
mu sync.Mutex
|
|
records map[common.DeviceSessionID]devicesession.Session
|
|
}
|
|
|
|
// Get returns the stored device session for deviceSessionID.
|
|
func (s *InMemorySessionStore) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return devicesession.Session{}, err
|
|
}
|
|
if err := deviceSessionID.Validate(); err != nil {
|
|
return devicesession.Session{}, fmt.Errorf("get session: %w", err)
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
record, ok := s.records[deviceSessionID]
|
|
if !ok {
|
|
return devicesession.Session{}, fmt.Errorf("get session %q: %w", deviceSessionID, ports.ErrNotFound)
|
|
}
|
|
|
|
cloned, err := cloneSession(record)
|
|
if err != nil {
|
|
return devicesession.Session{}, err
|
|
}
|
|
|
|
return cloned, nil
|
|
}
|
|
|
|
// ListByUserID returns every stored session for userID in newest-first order.
|
|
func (s *InMemorySessionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := userID.Validate(); err != nil {
|
|
return nil, fmt.Errorf("list sessions by user id: %w", err)
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var records []devicesession.Session
|
|
for _, record := range s.records {
|
|
if record.UserID == userID {
|
|
cloned, err := cloneSession(record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
records = append(records, cloned)
|
|
}
|
|
}
|
|
sortSessionsNewestFirst(records)
|
|
|
|
return records, nil
|
|
}
|
|
|
|
// CountActiveByUserID returns the number of active sessions currently stored
|
|
// for userID.
|
|
func (s *InMemorySessionStore) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return 0, err
|
|
}
|
|
if err := userID.Validate(); err != nil {
|
|
return 0, fmt.Errorf("count active sessions by user id: %w", err)
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
count := 0
|
|
for _, record := range s.records {
|
|
if record.UserID == userID && record.Status == devicesession.StatusActive {
|
|
count++
|
|
}
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// Create stores record as a new device session.
|
|
func (s *InMemorySessionStore) Create(ctx context.Context, record devicesession.Session) error {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("create session: %w", err)
|
|
}
|
|
|
|
cloned, err := cloneSession(record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.records == nil {
|
|
s.records = make(map[common.DeviceSessionID]devicesession.Session)
|
|
}
|
|
if _, exists := s.records[record.ID]; exists {
|
|
return fmt.Errorf("create session %q: %w", record.ID, ports.ErrConflict)
|
|
}
|
|
|
|
s.records[record.ID] = cloned
|
|
return nil
|
|
}
|
|
|
|
// Revoke stores a revoked view of one target session.
|
|
func (s *InMemorySessionStore) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return ports.RevokeSessionResult{}, err
|
|
}
|
|
if err := input.Validate(); err != nil {
|
|
return ports.RevokeSessionResult{}, fmt.Errorf("revoke session: %w", err)
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
record, ok := s.records[input.DeviceSessionID]
|
|
if !ok {
|
|
return ports.RevokeSessionResult{}, fmt.Errorf("revoke session %q: %w", input.DeviceSessionID, ports.ErrNotFound)
|
|
}
|
|
|
|
if record.Status == devicesession.StatusRevoked {
|
|
cloned, err := cloneSession(record)
|
|
if err != nil {
|
|
return ports.RevokeSessionResult{}, err
|
|
}
|
|
|
|
result := ports.RevokeSessionResult{
|
|
Outcome: ports.RevokeSessionOutcomeAlreadyRevoked,
|
|
Session: cloned,
|
|
}
|
|
if err := result.Validate(); err != nil {
|
|
return ports.RevokeSessionResult{}, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
record.Status = devicesession.StatusRevoked
|
|
revocation := input.Revocation
|
|
record.Revocation = &revocation
|
|
|
|
cloned, err := cloneSession(record)
|
|
if err != nil {
|
|
return ports.RevokeSessionResult{}, err
|
|
}
|
|
s.records[input.DeviceSessionID] = cloned
|
|
|
|
result := ports.RevokeSessionResult{
|
|
Outcome: ports.RevokeSessionOutcomeRevoked,
|
|
Session: cloned,
|
|
}
|
|
if err := result.Validate(); err != nil {
|
|
return ports.RevokeSessionResult{}, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// RevokeAllByUserID stores revoked views for all currently active sessions
|
|
// owned by input.UserID.
|
|
func (s *InMemorySessionStore) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return ports.RevokeUserSessionsResult{}, err
|
|
}
|
|
if err := input.Validate(); err != nil {
|
|
return ports.RevokeUserSessionsResult{}, fmt.Errorf("revoke user sessions: %w", err)
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var affected []devicesession.Session
|
|
for id, record := range s.records {
|
|
if record.UserID != input.UserID || record.Status != devicesession.StatusActive {
|
|
continue
|
|
}
|
|
|
|
record.Status = devicesession.StatusRevoked
|
|
revocation := input.Revocation
|
|
record.Revocation = &revocation
|
|
|
|
cloned, err := cloneSession(record)
|
|
if err != nil {
|
|
return ports.RevokeUserSessionsResult{}, err
|
|
}
|
|
s.records[id] = cloned
|
|
affected = append(affected, cloned)
|
|
}
|
|
|
|
sortSessionsNewestFirst(affected)
|
|
|
|
outcome := ports.RevokeUserSessionsOutcomeNoActiveSessions
|
|
if len(affected) > 0 {
|
|
outcome = ports.RevokeUserSessionsOutcomeRevoked
|
|
}
|
|
|
|
result := ports.RevokeUserSessionsResult{
|
|
Outcome: outcome,
|
|
UserID: input.UserID,
|
|
Sessions: slices.Clone(affected),
|
|
}
|
|
if err := result.Validate(); err != nil {
|
|
return ports.RevokeUserSessionsResult{}, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Compile-time check that *InMemorySessionStore satisfies ports.SessionStore.
var _ ports.SessionStore = (*InMemorySessionStore)(nil)
|