394 lines
11 KiB
Go
394 lines
11 KiB
Go
package shared
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/authsession/internal/domain/common"
|
|
"galaxy/authsession/internal/domain/devicesession"
|
|
"galaxy/authsession/internal/domain/gatewayprojection"
|
|
"galaxy/authsession/internal/domain/sessionlimit"
|
|
"galaxy/authsession/internal/ports"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNormalizeString(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, "pilot@example.com", NormalizeString(" pilot@example.com \n"))
|
|
}
|
|
|
|
func TestParseClientPublicKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
key, err := ParseClientPublicKey(" AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8= ")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", key.String())
|
|
|
|
_, err = ParseClientPublicKey("invalid")
|
|
require.Error(t, err)
|
|
assert.Equal(t, ErrorCodeInvalidClientPublicKey, CodeOf(err))
|
|
}
|
|
|
|
func TestParseTimeZone(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
timeZone, err := ParseTimeZone(" Europe/Kaliningrad ")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "Europe/Kaliningrad", timeZone)
|
|
|
|
_, err = ParseTimeZone("Mars/Olympus")
|
|
require.Error(t, err)
|
|
assert.Equal(t, ErrorCodeInvalidRequest, CodeOf(err))
|
|
assert.Equal(t, "time_zone must be a valid IANA time zone name", err.Error())
|
|
}
|
|
|
|
func TestToSession(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
record := revokedSessionFixture()
|
|
|
|
dto, err := ToSession(record)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, record.ID.String(), dto.DeviceSessionID)
|
|
require.NotNil(t, dto.RevokedAt)
|
|
assert.Equal(t, record.Revocation.At.UTC().Format(time.RFC3339), *dto.RevokedAt)
|
|
}
|
|
|
|
func TestToGatewayProjectionSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
record := revokedSessionFixture()
|
|
|
|
snapshot, err := ToGatewayProjectionSnapshot(record)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, gatewayprojection.StatusRevoked, snapshot.Status)
|
|
}
|
|
|
|
func TestEvaluateSessionLimit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
limit := 2
|
|
|
|
tests := []struct {
|
|
name string
|
|
config ports.SessionLimitConfig
|
|
active int
|
|
want sessionlimit.Kind
|
|
}{
|
|
{name: "disabled", config: ports.SessionLimitConfig{}, active: 3, want: sessionlimit.KindDisabled},
|
|
{name: "allowed", config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, active: 1, want: sessionlimit.KindAllowed},
|
|
{name: "exceeded", config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, active: 2, want: sessionlimit.KindExceeded},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
decision, err := EvaluateSessionLimit(tt.config, tt.active)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.want, decision.Kind)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestServiceErrorCodePreservation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
baseErr := errors.New("base")
|
|
err := ServiceUnavailable(baseErr)
|
|
|
|
assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err))
|
|
assert.ErrorIs(t, err, baseErr)
|
|
}
|
|
|
|
func TestErrorCodeClassification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
publicCodes := []string{
|
|
ErrorCodeInvalidRequest,
|
|
ErrorCodeChallengeNotFound,
|
|
ErrorCodeChallengeExpired,
|
|
ErrorCodeInvalidCode,
|
|
ErrorCodeInvalidClientPublicKey,
|
|
ErrorCodeBlockedByPolicy,
|
|
ErrorCodeSessionLimitExceeded,
|
|
ErrorCodeServiceUnavailable,
|
|
}
|
|
for _, code := range publicCodes {
|
|
assert.Truef(t, IsPublicErrorCode(code), "IsPublicErrorCode(%q)", code)
|
|
assert.Falsef(t, IsInternalOnlyErrorCode(code), "IsInternalOnlyErrorCode(%q)", code)
|
|
}
|
|
|
|
internalOnlyCodes := []string{
|
|
ErrorCodeSessionNotFound,
|
|
ErrorCodeSubjectNotFound,
|
|
ErrorCodeInternalError,
|
|
}
|
|
for _, code := range internalOnlyCodes {
|
|
assert.Falsef(t, IsPublicErrorCode(code), "IsPublicErrorCode(%q)", code)
|
|
assert.Truef(t, IsInternalOnlyErrorCode(code), "IsInternalOnlyErrorCode(%q)", code)
|
|
}
|
|
}
|
|
|
|
func TestPublicUseCaseErrorCodeSets(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.True(t, IsSendEmailCodePublicErrorCode(ErrorCodeInvalidRequest))
|
|
assert.True(t, IsSendEmailCodePublicErrorCode(ErrorCodeServiceUnavailable))
|
|
assert.False(t, IsSendEmailCodePublicErrorCode(ErrorCodeBlockedByPolicy))
|
|
assert.False(t, IsSendEmailCodePublicErrorCode(ErrorCodeChallengeNotFound))
|
|
|
|
confirmCodes := []string{
|
|
ErrorCodeInvalidRequest,
|
|
ErrorCodeChallengeNotFound,
|
|
ErrorCodeChallengeExpired,
|
|
ErrorCodeInvalidCode,
|
|
ErrorCodeInvalidClientPublicKey,
|
|
ErrorCodeBlockedByPolicy,
|
|
ErrorCodeSessionLimitExceeded,
|
|
ErrorCodeServiceUnavailable,
|
|
}
|
|
for _, code := range confirmCodes {
|
|
assert.Truef(t, IsConfirmEmailCodePublicErrorCode(code), "IsConfirmEmailCodePublicErrorCode(%q)", code)
|
|
}
|
|
assert.False(t, IsConfirmEmailCodePublicErrorCode(ErrorCodeInternalError))
|
|
assert.False(t, IsConfirmEmailCodePublicErrorCode(ErrorCodeSessionNotFound))
|
|
}
|
|
|
|
func TestPublicHTTPStatusCode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
code string
|
|
want int
|
|
}{
|
|
{name: "invalid request", code: ErrorCodeInvalidRequest, want: http.StatusBadRequest},
|
|
{name: "invalid client public key", code: ErrorCodeInvalidClientPublicKey, want: http.StatusBadRequest},
|
|
{name: "invalid code", code: ErrorCodeInvalidCode, want: http.StatusBadRequest},
|
|
{name: "challenge not found", code: ErrorCodeChallengeNotFound, want: http.StatusNotFound},
|
|
{name: "challenge expired", code: ErrorCodeChallengeExpired, want: http.StatusGone},
|
|
{name: "blocked by policy", code: ErrorCodeBlockedByPolicy, want: http.StatusForbidden},
|
|
{name: "session limit exceeded", code: ErrorCodeSessionLimitExceeded, want: http.StatusConflict},
|
|
{name: "service unavailable", code: ErrorCodeServiceUnavailable, want: http.StatusServiceUnavailable},
|
|
{name: "internal error normalized", code: ErrorCodeInternalError, want: http.StatusServiceUnavailable},
|
|
{name: "unknown normalized", code: "unknown", want: http.StatusServiceUnavailable},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, tt.want, PublicHTTPStatusCode(tt.code))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestInternalHTTPStatusCode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
code string
|
|
want int
|
|
}{
|
|
{name: "invalid request", code: ErrorCodeInvalidRequest, want: http.StatusBadRequest},
|
|
{name: "session not found", code: ErrorCodeSessionNotFound, want: http.StatusNotFound},
|
|
{name: "subject not found", code: ErrorCodeSubjectNotFound, want: http.StatusNotFound},
|
|
{name: "service unavailable", code: ErrorCodeServiceUnavailable, want: http.StatusServiceUnavailable},
|
|
{name: "internal error", code: ErrorCodeInternalError, want: http.StatusInternalServerError},
|
|
{name: "unknown normalized", code: "unknown", want: http.StatusInternalServerError},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, tt.want, InternalHTTPStatusCode(tt.code))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProjectPublicError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
want PublicErrorProjection
|
|
}{
|
|
{
|
|
name: "invalid request keeps detailed message",
|
|
err: InvalidRequest("email must be a single valid email address"),
|
|
want: PublicErrorProjection{
|
|
StatusCode: http.StatusBadRequest,
|
|
Code: ErrorCodeInvalidRequest,
|
|
Message: "email must be a single valid email address",
|
|
},
|
|
},
|
|
{
|
|
name: "invalid code keeps canonical message",
|
|
err: NewServiceError(ErrorCodeInvalidCode, "custom detail should not leak", nil),
|
|
want: PublicErrorProjection{
|
|
StatusCode: http.StatusBadRequest,
|
|
Code: ErrorCodeInvalidCode,
|
|
Message: "confirmation code is invalid",
|
|
},
|
|
},
|
|
{
|
|
name: "service unavailable keeps generic message",
|
|
err: NewServiceError(ErrorCodeServiceUnavailable, "dependency timeout", errors.New("dependency timeout")),
|
|
want: PublicErrorProjection{
|
|
StatusCode: http.StatusServiceUnavailable,
|
|
Code: ErrorCodeServiceUnavailable,
|
|
Message: "service is unavailable",
|
|
},
|
|
},
|
|
{
|
|
name: "internal error is hidden",
|
|
err: InternalError(errors.New("broken invariant")),
|
|
want: PublicErrorProjection{
|
|
StatusCode: http.StatusServiceUnavailable,
|
|
Code: ErrorCodeServiceUnavailable,
|
|
Message: "service is unavailable",
|
|
},
|
|
},
|
|
{
|
|
name: "internal only session not found is hidden",
|
|
err: SessionNotFound(),
|
|
want: PublicErrorProjection{
|
|
StatusCode: http.StatusServiceUnavailable,
|
|
Code: ErrorCodeServiceUnavailable,
|
|
Message: "service is unavailable",
|
|
},
|
|
},
|
|
{
|
|
name: "non service error is hidden",
|
|
err: errors.New("boom"),
|
|
want: PublicErrorProjection{
|
|
StatusCode: http.StatusServiceUnavailable,
|
|
Code: ErrorCodeServiceUnavailable,
|
|
Message: "service is unavailable",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, tt.want, ProjectPublicError(tt.err))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProjectInternalError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
want InternalErrorProjection
|
|
}{
|
|
{
|
|
name: "invalid request keeps detailed message",
|
|
err: InvalidRequest("reason_code must not be empty"),
|
|
want: InternalErrorProjection{
|
|
StatusCode: http.StatusBadRequest,
|
|
Code: ErrorCodeInvalidRequest,
|
|
Message: "reason_code must not be empty",
|
|
},
|
|
},
|
|
{
|
|
name: "session not found keeps canonical message",
|
|
err: NewServiceError(ErrorCodeSessionNotFound, "custom detail should not leak", nil),
|
|
want: InternalErrorProjection{
|
|
StatusCode: http.StatusNotFound,
|
|
Code: ErrorCodeSessionNotFound,
|
|
Message: "session not found",
|
|
},
|
|
},
|
|
{
|
|
name: "subject not found keeps canonical message",
|
|
err: SubjectNotFound(),
|
|
want: InternalErrorProjection{
|
|
StatusCode: http.StatusNotFound,
|
|
Code: ErrorCodeSubjectNotFound,
|
|
Message: "subject not found",
|
|
},
|
|
},
|
|
{
|
|
name: "service unavailable keeps generic message",
|
|
err: NewServiceError(ErrorCodeServiceUnavailable, "redis timeout", errors.New("redis timeout")),
|
|
want: InternalErrorProjection{
|
|
StatusCode: http.StatusServiceUnavailable,
|
|
Code: ErrorCodeServiceUnavailable,
|
|
Message: "service is unavailable",
|
|
},
|
|
},
|
|
{
|
|
name: "internal error uses internal server error message",
|
|
err: InternalError(errors.New("broken invariant")),
|
|
want: InternalErrorProjection{
|
|
StatusCode: http.StatusInternalServerError,
|
|
Code: ErrorCodeInternalError,
|
|
Message: "internal server error",
|
|
},
|
|
},
|
|
{
|
|
name: "unexpected error is hidden",
|
|
err: errors.New("boom"),
|
|
want: InternalErrorProjection{
|
|
StatusCode: http.StatusInternalServerError,
|
|
Code: ErrorCodeInternalError,
|
|
Message: "internal server error",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, tt.want, ProjectInternalError(tt.err))
|
|
})
|
|
}
|
|
}
|
|
|
|
func revokedSessionFixture() devicesession.Session {
|
|
key, err := common.NewClientPublicKey(make([]byte, 32))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
revokedAt := time.Unix(20, 0).UTC()
|
|
return devicesession.Session{
|
|
ID: common.DeviceSessionID("device-session-1"),
|
|
UserID: common.UserID("user-1"),
|
|
ClientPublicKey: key,
|
|
Status: devicesession.StatusRevoked,
|
|
CreatedAt: time.Unix(10, 0).UTC(),
|
|
Revocation: &devicesession.Revocation{
|
|
At: revokedAt,
|
|
ReasonCode: devicesession.RevokeReasonLogoutAll,
|
|
ActorType: common.RevokeActorType("system"),
|
|
ActorID: "actor-1",
|
|
},
|
|
}
|
|
}
|