331 lines
11 KiB
Go
331 lines
11 KiB
Go
package authsession
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/authsession/internal/domain/common"
|
|
"galaxy/authsession/internal/domain/devicesession"
|
|
"galaxy/authsession/internal/ports"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// gatedCreateSessionStore blocks the first target successful Create calls
|
|
// after they persist the session, which lets concurrency tests force overlap
|
|
// between confirm and competing revoke/block flows.
|
|
type gatedCreateSessionStore struct {
|
|
delegate ports.SessionStore
|
|
target int
|
|
|
|
arrived chan common.DeviceSessionID
|
|
release chan struct{}
|
|
|
|
mu sync.Mutex
|
|
seenCreates int
|
|
releaseOnce sync.Once
|
|
}
|
|
|
|
// newGatedCreateSessionStore wraps delegate with deterministic post-create
|
|
// gating for the first target successful session creations.
|
|
func newGatedCreateSessionStore(delegate ports.SessionStore, target int) *gatedCreateSessionStore {
|
|
return &gatedCreateSessionStore{
|
|
delegate: delegate,
|
|
target: target,
|
|
arrived: make(chan common.DeviceSessionID, target),
|
|
release: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// Create delegates persistence first and then blocks the first configured
|
|
// number of successful creations until Release is called.
|
|
func (s *gatedCreateSessionStore) Create(ctx context.Context, record devicesession.Session) error {
|
|
if err := s.delegate.Create(ctx, record); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.mu.Lock()
|
|
shouldGate := s.seenCreates < s.target
|
|
if shouldGate {
|
|
s.seenCreates++
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
if !shouldGate {
|
|
return nil
|
|
}
|
|
|
|
s.arrived <- record.ID
|
|
|
|
select {
|
|
case <-s.release:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
// WaitForCreates waits for count gated successful Create calls and returns the
|
|
// corresponding device session identifiers in arrival order.
|
|
func (s *gatedCreateSessionStore) WaitForCreates(t *testing.T, count int) []common.DeviceSessionID {
|
|
t.Helper()
|
|
|
|
ids := make([]common.DeviceSessionID, 0, count)
|
|
timeout := time.After(5 * time.Second)
|
|
|
|
for len(ids) < count {
|
|
select {
|
|
case id := <-s.arrived:
|
|
ids = append(ids, id)
|
|
case <-timeout:
|
|
require.FailNowf(t, "test failed", "timed out waiting for %d gated session creations", count)
|
|
}
|
|
}
|
|
|
|
return ids
|
|
}
|
|
|
|
// Release unblocks every gated Create call.
|
|
func (s *gatedCreateSessionStore) Release() {
|
|
s.releaseOnce.Do(func() {
|
|
close(s.release)
|
|
})
|
|
}
|
|
|
|
// Get delegates to the wrapped session store.
|
|
func (s *gatedCreateSessionStore) Get(ctx context.Context, deviceSessionID common.DeviceSessionID) (devicesession.Session, error) {
|
|
return s.delegate.Get(ctx, deviceSessionID)
|
|
}
|
|
|
|
// ListByUserID delegates to the wrapped session store.
|
|
func (s *gatedCreateSessionStore) ListByUserID(ctx context.Context, userID common.UserID) ([]devicesession.Session, error) {
|
|
return s.delegate.ListByUserID(ctx, userID)
|
|
}
|
|
|
|
// CountActiveByUserID delegates to the wrapped session store.
|
|
func (s *gatedCreateSessionStore) CountActiveByUserID(ctx context.Context, userID common.UserID) (int, error) {
|
|
return s.delegate.CountActiveByUserID(ctx, userID)
|
|
}
|
|
|
|
// Revoke delegates to the wrapped session store.
|
|
func (s *gatedCreateSessionStore) Revoke(ctx context.Context, input ports.RevokeSessionInput) (ports.RevokeSessionResult, error) {
|
|
return s.delegate.Revoke(ctx, input)
|
|
}
|
|
|
|
// RevokeAllByUserID delegates to the wrapped session store.
|
|
func (s *gatedCreateSessionStore) RevokeAllByUserID(ctx context.Context, input ports.RevokeUserSessionsInput) (ports.RevokeUserSessionsResult, error) {
|
|
return s.delegate.RevokeAllByUserID(ctx, input)
|
|
}
|
|
|
|
// Compile-time check that the gated wrapper still satisfies ports.SessionStore.
var (
	_ ports.SessionStore = (*gatedCreateSessionStore)(nil)
)
|
|
|
|
func TestProductionHardeningConcurrentIdenticalConfirmsConvergeToOneActiveSession(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
env := newHardeningEnvironment(t)
|
|
var gate *gatedCreateSessionStore
|
|
app := newHardeningApp(t, env, hardeningAppOptions{
|
|
SeedExistingUser: true,
|
|
WrapSessionStore: func(delegate ports.SessionStore) ports.SessionStore {
|
|
gate = newGatedCreateSessionStore(delegate, 2)
|
|
return gate
|
|
},
|
|
})
|
|
|
|
challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail)
|
|
requestBody := gatewayCompatibilityConfirmRequest(challengeID, code, gatewayCompatibilityClientPublicKey)
|
|
|
|
responses := make([]gatewayCompatibilityHTTPResponse, 2)
|
|
start := make(chan struct{})
|
|
|
|
var requests sync.WaitGroup
|
|
requests.Add(2)
|
|
for index := range responses {
|
|
go func(index int) {
|
|
defer requests.Done()
|
|
<-start
|
|
responses[index] = gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/confirm-email-code", requestBody)
|
|
}(index)
|
|
}
|
|
|
|
close(start)
|
|
createdIDs := gate.WaitForCreates(t, 2)
|
|
require.Len(t, createdIDs, 2)
|
|
assert.NotEqual(t, createdIDs[0], createdIDs[1])
|
|
|
|
gate.Release()
|
|
requests.Wait()
|
|
|
|
var deviceSessionIDs []string
|
|
for _, response := range responses {
|
|
assert.Equal(t, http.StatusOK, response.StatusCode)
|
|
|
|
var body struct {
|
|
DeviceSessionID string `json:"device_session_id"`
|
|
}
|
|
require.NoError(t, json.Unmarshal([]byte(response.Body), &body))
|
|
deviceSessionIDs = append(deviceSessionIDs, body.DeviceSessionID)
|
|
}
|
|
require.Len(t, deviceSessionIDs, 2)
|
|
assert.Equal(t, deviceSessionIDs[0], deviceSessionIDs[1])
|
|
|
|
records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1"))
|
|
require.NoError(t, err)
|
|
require.Len(t, records, 2)
|
|
|
|
activeCount := 0
|
|
revokedCount := 0
|
|
for _, record := range records {
|
|
switch record.Status {
|
|
case devicesession.StatusActive:
|
|
activeCount++
|
|
assert.Equal(t, common.DeviceSessionID(deviceSessionIDs[0]), record.ID)
|
|
case devicesession.StatusRevoked:
|
|
revokedCount++
|
|
require.NotNil(t, record.Revocation)
|
|
assert.Equal(t, common.RevokeReasonCode("confirm_race_repair"), record.Revocation.ReasonCode)
|
|
default:
|
|
require.Failf(t, "test failed", "unexpected final session status %q", record.Status)
|
|
}
|
|
}
|
|
assert.Equal(t, 1, activeCount)
|
|
assert.Equal(t, 1, revokedCount)
|
|
|
|
cacheRecord := env.MustReadGatewayCacheRecord(t, deviceSessionIDs[0])
|
|
assert.Equal(t, "active", cacheRecord.Status)
|
|
}
|
|
|
|
func TestProductionHardeningConcurrentConfirmAndRevokeAllKeepProjectionConsistent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
env := newHardeningEnvironment(t)
|
|
var gate *gatedCreateSessionStore
|
|
app := newHardeningApp(t, env, hardeningAppOptions{
|
|
SeedExistingUser: true,
|
|
WrapSessionStore: func(delegate ports.SessionStore) ports.SessionStore {
|
|
gate = newGatedCreateSessionStore(delegate, 1)
|
|
return gate
|
|
},
|
|
})
|
|
|
|
challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail)
|
|
confirmResponseCh := make(chan gatewayCompatibilityHTTPResponse, 1)
|
|
go func() {
|
|
confirmResponseCh <- gatewayCompatibilityPostJSONValue(
|
|
t,
|
|
app.publicBaseURL+"/api/v1/public/auth/confirm-email-code",
|
|
gatewayCompatibilityConfirmRequest(challengeID, code, gatewayCompatibilityClientPublicKey),
|
|
)
|
|
}()
|
|
|
|
createdIDs := gate.WaitForCreates(t, 1)
|
|
sessionID := createdIDs[0].String()
|
|
|
|
revokeAllResponse := gatewayCompatibilityPostJSON(
|
|
t,
|
|
app.internalBaseURL+"/api/v1/internal/users/user-1/sessions/revoke-all",
|
|
`{"reason_code":"logout_all","actor":{"type":"system"}}`,
|
|
)
|
|
assert.Equal(t, http.StatusOK, revokeAllResponse.StatusCode)
|
|
assert.JSONEq(t, `{"outcome":"revoked","user_id":"user-1","affected_session_count":1,"affected_device_session_ids":["`+sessionID+`"]}`, revokeAllResponse.Body)
|
|
|
|
gate.Release()
|
|
confirmResponse := <-confirmResponseCh
|
|
assert.Equal(t, http.StatusOK, confirmResponse.StatusCode)
|
|
|
|
var confirmBody struct {
|
|
DeviceSessionID string `json:"device_session_id"`
|
|
}
|
|
require.NoError(t, json.Unmarshal([]byte(confirmResponse.Body), &confirmBody))
|
|
assert.Equal(t, sessionID, confirmBody.DeviceSessionID)
|
|
|
|
records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1"))
|
|
require.NoError(t, err)
|
|
require.Len(t, records, 1)
|
|
assert.Equal(t, devicesession.StatusRevoked, records[0].Status)
|
|
require.NotNil(t, records[0].Revocation)
|
|
assert.Equal(t, devicesession.RevokeReasonLogoutAll, records[0].Revocation.ReasonCode)
|
|
|
|
cacheRecord := env.MustReadGatewayCacheRecord(t, sessionID)
|
|
assert.Equal(t, "revoked", cacheRecord.Status)
|
|
require.NotNil(t, cacheRecord.RevokedAtMS)
|
|
}
|
|
|
|
func TestProductionHardeningConcurrentBlockUserAndConfirmDoNotLeakActiveSession(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
env := newHardeningEnvironment(t)
|
|
var gate *gatedCreateSessionStore
|
|
app := newHardeningApp(t, env, hardeningAppOptions{
|
|
SeedExistingUser: true,
|
|
WrapSessionStore: func(delegate ports.SessionStore) ports.SessionStore {
|
|
gate = newGatedCreateSessionStore(delegate, 1)
|
|
return gate
|
|
},
|
|
})
|
|
|
|
challengeID, code := app.SendChallenge(t, gatewayCompatibilityEmail)
|
|
initialAttempts := app.mailSender.RecordedAttempts()
|
|
require.Len(t, initialAttempts, 1)
|
|
|
|
confirmResponseCh := make(chan gatewayCompatibilityHTTPResponse, 1)
|
|
go func() {
|
|
confirmResponseCh <- gatewayCompatibilityPostJSONValue(
|
|
t,
|
|
app.publicBaseURL+"/api/v1/public/auth/confirm-email-code",
|
|
gatewayCompatibilityConfirmRequest(challengeID, code, gatewayCompatibilityClientPublicKey),
|
|
)
|
|
}()
|
|
|
|
createdIDs := gate.WaitForCreates(t, 1)
|
|
sessionID := createdIDs[0].String()
|
|
|
|
blockResponse := gatewayCompatibilityPostJSON(
|
|
t,
|
|
app.internalBaseURL+"/api/v1/internal/user-blocks",
|
|
`{"email":"pilot@example.com","reason_code":"policy_blocked","actor":{"type":"admin"}}`,
|
|
)
|
|
assert.Equal(t, http.StatusOK, blockResponse.StatusCode)
|
|
assert.JSONEq(t, `{"outcome":"blocked","subject_kind":"email","subject_value":"pilot@example.com","affected_session_count":1,"affected_device_session_ids":["`+sessionID+`"]}`, blockResponse.Body)
|
|
|
|
gate.Release()
|
|
confirmResponse := <-confirmResponseCh
|
|
assert.Contains(t, []int{http.StatusOK, http.StatusForbidden}, confirmResponse.StatusCode)
|
|
|
|
records, err := app.sessionStore.ListByUserID(context.Background(), common.UserID("user-1"))
|
|
require.NoError(t, err)
|
|
require.Len(t, records, 1)
|
|
assert.Equal(t, devicesession.StatusRevoked, records[0].Status)
|
|
require.NotNil(t, records[0].Revocation)
|
|
assert.Equal(t, devicesession.RevokeReasonUserBlocked, records[0].Revocation.ReasonCode)
|
|
|
|
cacheRecord := env.MustReadGatewayCacheRecord(t, sessionID)
|
|
assert.Equal(t, "revoked", cacheRecord.Status)
|
|
require.NotNil(t, cacheRecord.RevokedAtMS)
|
|
|
|
followupSend := gatewayCompatibilityPostJSONValue(t, app.publicBaseURL+"/api/v1/public/auth/send-email-code", map[string]string{
|
|
"email": gatewayCompatibilityEmail,
|
|
})
|
|
assert.Equal(t, http.StatusOK, followupSend.StatusCode)
|
|
|
|
var sendBody struct {
|
|
ChallengeID string `json:"challenge_id"`
|
|
}
|
|
require.NoError(t, json.Unmarshal([]byte(followupSend.Body), &sendBody))
|
|
assert.NotEmpty(t, sendBody.ChallengeID)
|
|
assert.Len(t, app.mailSender.RecordedAttempts(), 1)
|
|
|
|
followupConfirm := gatewayCompatibilityPostJSONValue(
|
|
t,
|
|
app.publicBaseURL+"/api/v1/public/auth/confirm-email-code",
|
|
gatewayCompatibilityConfirmRequest(sendBody.ChallengeID, gatewayCompatibilityCode, gatewayCompatibilityClientPublicKey),
|
|
)
|
|
assert.Equal(t, http.StatusForbidden, followupConfirm.StatusCode)
|
|
assert.JSONEq(t, `{"error":{"code":"blocked_by_policy","message":"authentication is blocked by policy"}}`, followupConfirm.Body)
|
|
}
|