package authsession

import (
	"context"
	"encoding/json"
	"net/http"
	"sync"
	"testing"
	"time"

	"galaxy/authsession/internal/domain/common"
	"galaxy/authsession/internal/domain/devicesession"
	"galaxy/authsession/internal/ports"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// gatedCreateSessionStore wraps a ports.SessionStore and parks the first
// target successful Create calls after the delegate has already persisted the
// session. Concurrency tests use it to hold a confirm flow open while a
// competing revoke or block flow runs against the stored session.
type gatedCreateSessionStore struct {
	delegate ports.SessionStore
	target   int

	// arrived reports the ID of each gated session. Its buffer equals target
	// and Create sends at most target times, so those sends never block.
	arrived chan common.DeviceSessionID
	// release is closed (once, by Release) to unblock every parked Create.
	release chan struct{}

	mu          sync.Mutex
	seenCreates int // guarded by mu: gate slots claimed so far

	releaseOnce sync.Once
}

// newGatedCreateSessionStore returns a store that forwards to delegate and
// gates the first target successful Create calls until Release is invoked.
func newGatedCreateSessionStore(delegate ports.SessionStore, target int) *gatedCreateSessionStore {
	store := &gatedCreateSessionStore{delegate: delegate, target: target}
	store.arrived = make(chan common.DeviceSessionID, target)
	store.release = make(chan struct{})
	return store
}

// Create persists record through the delegate. For the first target
// successful calls it then publishes record.ID on arrived and waits until
// either Release is called (returning nil) or ctx is done (returning
// ctx.Err(), even though the session has already been stored). Delegate
// errors are returned unchanged and are never gated.
func (s *gatedCreateSessionStore) Create(ctx context.Context, record devicesession.Session) error {
	err := s.delegate.Create(ctx, record)
	if err != nil {
		return err
	}

	if !s.claimGateSlot() {
		return nil
	}

	s.arrived <- record.ID

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-s.release:
		return nil
	}
}

// claimGateSlot reports whether the caller falls within the first target
// successful creations, reserving one gate slot when it does.
func (s *gatedCreateSessionStore) claimGateSlot() bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.seenCreates >= s.target {
		return false
	}
	s.seenCreates++
	return true
}

// WaitForCreates blocks until count gated Create calls have arrived and
// returns their device session identifiers in arrival order, failing the test
// if they do not all arrive within a fixed timeout.
func (s *gatedCreateSessionStore) WaitForCreates(t *testing.T, count int) []common.DeviceSessionID {
	t.Helper()

	// Unblock parked Create calls when the test finishes, even if it fails
	// (here on timeout, or later) before reaching its own Release call.
	// Otherwise those calls, and the requests behind them, stay parked until
	// their context is cancelled. Release is idempotent, so the explicit
	// Release in passing tests is unaffected.
	t.Cleanup(s.Release)

	ids := make([]common.DeviceSessionID, 0, count)
	timeout := time.After(5 * time.Second)
	for len(ids) < count {
		select {
		case id := <-s.arrived:
			ids = append(ids, id)
		case <-timeout:
			require.FailNowf(t, "test failed", "timed out waiting for %d gated session creations", count)
		}
	}
	return ids
}

// Release closes the release channel, unblocking every parked Create call.
// It is safe to call more than once.
func (s *gatedCreateSessionStore) Release() {
	s.releaseOnce.Do(func() { close(s.release) })
}

// Get forwards to the wrapped session store.
func (s *gatedCreateSessionStore) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
	return s.delegate.Get(ctx, deviceSessionID)
}

// ListByUserID forwards to the wrapped session store.
func (s *gatedCreateSessionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
	return s.delegate.ListByUserID(ctx, userID)
}

// CountActiveByUserID forwards to the wrapped session store.
func (s *gatedCreateSessionStore) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
	return s.delegate.CountActiveByUserID(ctx, userID)
}

// Revoke forwards to the wrapped session store.
func (s *gatedCreateSessionStore) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
	return s.delegate.Revoke(ctx, input)
}

// RevokeAllByUserID delegates to the wrapped session store.
func (s *gatedCreateSessionStore) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) { return s.delegate.RevokeAllByUserID(ctx, input) } var _ ports.SessionStore = (*gatedCreateSessionStore)(nil) func TestProductionHardeningConcurrentIdenticalConfirmsConvergeToOneActiveSession(t *testing.T) { t.Parallel() env := newHardeningEnvironment(t) var gate *gatedCreateSessionStore app := newHardeningApp(t, env, hardeningAppOptions{ SeedExistingUser: true, WrapSessionStore: func(delegate ports.SessionStore) ports.SessionStore { gate = newGatedCreateSessionStore(delegate, 2) return gate }, }) challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail) requestBody := gatewayCompatibilityConfirmRequest(challengeID, code, gatewayCompatibilityClientPublicKey) responses := make([]gatewayCompatibilityHTTPResponse, 2) start := make(chan struct{}) var requests sync.WaitGroup requests.Add(2) for index := range responses { go func(index int) { defer requests.Done() <-start responses[index] = gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", requestBody) }(index) } close(start) createdIDs := gate.WaitForCreates(t, 2) require.Len(t, createdIDs, 2) assert.NotEqual(t, createdIDs[0], createdIDs[1]) gate.Release() requests.Wait() var deviceSessionIDs []string for _, response := range responses { assert.Equal(t, http.StatusOK, response.StatusCode) var body struct { DeviceSessionID string `json:"device_session_id"` } require.NoError(t, json.Unmarshal([]byte(response.Body), &body)) deviceSessionIDs = append(deviceSessionIDs, body.DeviceSessionID) } require.Len(t, deviceSessionIDs, 2) assert.Equal(t, deviceSessionIDs[0], deviceSessionIDs[1]) records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1")) require.NoError(t, err) require.Len(t, records, 2) activeCount := 0 revokedCount := 0 for _, record := range records { switch record.Status { case 
devicesession.StatusActive: activeCount++ assert.Equal(t, common.DeviceSessionID(deviceSessionIDs[0]), record.ID) case devicesession.StatusRevoked: revokedCount++ require.NotNil(t, record.Revocation) assert.Equal(t, common.RevokeReasonCode("confirm_race_repair"), record.Revocation.ReasonCode) default: require.Failf(t, "test failed", "unexpected final session status %q", record.Status) } } assert.Equal(t, 1, activeCount) assert.Equal(t, 1, revokedCount) cacheRecord := env.MustReadGatewayCacheRecord(t, deviceSessionIDs[0]) assert.Equal(t, "active", cacheRecord.Status) } func TestProductionHardeningConcurrentConfirmAndRevokeAllKeepProjectionConsistent(t *testing.T) { t.Parallel() env := newHardeningEnvironment(t) var gate *gatedCreateSessionStore app := newHardeningApp(t, env, hardeningAppOptions{ SeedExistingUser: true, WrapSessionStore: func(delegate ports.SessionStore) ports.SessionStore { gate = newGatedCreateSessionStore(delegate, 1) return gate }, }) challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail) confirmResponseCh := make(chan gatewayCompatibilityHTTPResponse, 1) go func() { confirmResponseCh <- gatewayCompatibilityPostJSONValue( t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", gatewayCompatibilityConfirmRequest(challengeID, code, gatewayCompatibilityClientPublicKey), ) }() createdIDs := gate.WaitForCreates(t, 1) sessionID := createdIDs[0].String() revokeAllResponse := gatewayCompatibilityPostJSON( t, app.internalBaseURL+"/api/v1/internal/users/user-1/sessions/revoke-all", `{"reason_code":"logout_all","actor":{"type":"system"}}`, ) assert.Equal(t, http.StatusOK, revokeAllResponse.StatusCode) assert.JSONEq(t, `{"outcome":"revoked","user_id":"user-1","affected_session_count":1,"affected_device_session_ids":["`+sessionID+`"]}`, revokeAllResponse.Body) gate.Release() confirmResponse := <-confirmResponseCh assert.Equal(t, http.StatusOK, confirmResponse.StatusCode) var confirmBody struct { DeviceSessionID string 
`json:"device_session_id"` } require.NoError(t, json.Unmarshal([]byte(confirmResponse.Body), &confirmBody)) assert.Equal(t, sessionID, confirmBody.DeviceSessionID) records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1")) require.NoError(t, err) require.Len(t, records, 1) assert.Equal(t, devicesession.StatusRevoked, records[0].Status) require.NotNil(t, records[0].Revocation) assert.Equal(t, devicesession.RevokeReasonLogoutAll, records[0].Revocation.ReasonCode) cacheRecord := env.MustReadGatewayCacheRecord(t, sessionID) assert.Equal(t, "revoked", cacheRecord.Status) require.NotNil(t, cacheRecord.RevokedAtMS) } func TestProductionHardeningConcurrentBlockUserAndConfirmDoNotLeakActiveSession(t *testing.T) { t.Parallel() env := newHardeningEnvironment(t) var gate *gatedCreateSessionStore app := newHardeningApp(t, env, hardeningAppOptions{ SeedExistingUser: true, WrapSessionStore: func(delegate ports.SessionStore) ports.SessionStore { gate = newGatedCreateSessionStore(delegate, 1) return gate }, }) challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail) initialAttempts := app.mailSender.RecordedAttempts() require.Len(t, initialAttempts, 1) confirmResponseCh := make(chan gatewayCompatibilityHTTPResponse, 1) go func() { confirmResponseCh <- gatewayCompatibilityPostJSONValue( t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", gatewayCompatibilityConfirmRequest(challengeID, code, gatewayCompatibilityClientPublicKey), ) }() createdIDs := gate.WaitForCreates(t, 1) sessionID := createdIDs[0].String() blockResponse := gatewayCompatibilityPostJSON( t, app.internalBaseURL+"/api/v1/internal/user-blocks", `{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`, ) assert.Equal(t, http.StatusOK, blockResponse.StatusCode) assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":1,"affected_device_session_ids":["`+sessionID+`"]}`, 
blockResponse.Body) gate.Release() confirmResponse := <-confirmResponseCh assert.Contains(t, []int{http.StatusOK, http.StatusForbidden}, confirmResponse.StatusCode) records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1")) require.NoError(t, err) require.Len(t, records, 1) assert.Equal(t, devicesession.StatusRevoked, records[0].Status) require.NotNil(t, records[0].Revocation) assert.Equal(t, devicesession.RevokeReasonUserBlocked, records[0].Revocation.ReasonCode) cacheRecord := env.MustReadGatewayCacheRecord(t, sessionID) assert.Equal(t, "revoked", cacheRecord.Status) require.NotNil(t, cacheRecord.RevokedAtMS) followupSend := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", map[string]string{ "email": gatewayCompatibilityEmail, }) assert.Equal(t, http.StatusOK, followupSend.StatusCode) var sendBody struct { ChallengeID string `json:"challenge_id"` } require.NoError(t, json.Unmarshal([]byte(followupSend.Body), &sendBody)) assert.NotEmpty(t, sendBody.ChallengeID) assert.Len(t, app.mailSender.RecordedAttempts(), 1) followupConfirm := gatewayCompatibilityPostJSONValue( t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", gatewayCompatibilityConfirmRequest(sendBody.ChallengeID, gatewayCompatibilityCode, gatewayCompatibilityClientPublicKey), ) assert.Equal(t, http.StatusForbidden, followupConfirm.StatusCode) assert.JSONEq(t, `{"error":{"code":"blocked_by_policy","message":"authentication is blocked by policy"}}`, followupConfirm.Body) }