477 lines
15 KiB
Go
477 lines
15 KiB
Go
package authsession
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"path/filepath"
|
|
"runtime"
|
|
"slices"
|
|
"testing"
|
|
|
|
"galaxy/authsession/internal/service/shared"
|
|
|
|
"github.com/getkin/kin-openapi/openapi3"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestPublicOpenAPISpecValidates(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
loadSpec(t, "api", "public-openapi.yaml")
|
|
}
|
|
|
|
func TestInternalOpenAPISpecValidates(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
loadSpec(t, "api", "internal-openapi.yaml")
|
|
}
|
|
|
|
func TestPublicOpenAPISpecMatchesGatewayPublicAuthContract(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
authDoc := loadSpec(t, "api", "public-openapi.yaml")
|
|
gatewayDoc := loadSpec(t, "..", "gateway", "openapi.yaml")
|
|
authErrorEnvelope := componentSchemaRef(t, authDoc, "ErrorResponse")
|
|
gatewayProjectedEnvelope := defaultResponseSchemaRef(t, getOperation(t, gatewayDoc, "/api/v1/public/auth/send-email-code", http.MethodPost))
|
|
const errorResponseRef = "#/components/schemas/ErrorResponse"
|
|
|
|
paths := []string{
|
|
"/api/v1/public/auth/send-email-code",
|
|
"/api/v1/public/auth/confirm-email-code",
|
|
}
|
|
|
|
for _, path := range paths {
|
|
authOperation := getOperation(t, authDoc, path, http.MethodPost)
|
|
gatewayOperation := getOperation(t, gatewayDoc, path, http.MethodPost)
|
|
|
|
if authOperation.OperationID != gatewayOperation.OperationID {
|
|
require.Failf(t, "test failed", "operation %s: got operationId %q, want %q", path, authOperation.OperationID, gatewayOperation.OperationID)
|
|
}
|
|
|
|
compareSchemaRefs(
|
|
t,
|
|
requestSchemaRef(t, authOperation),
|
|
requestSchemaRef(t, gatewayOperation),
|
|
"path "+path+" request schema",
|
|
)
|
|
compareSchemaRefs(
|
|
t,
|
|
responseSchemaRef(t, authOperation, http.StatusOK),
|
|
responseSchemaRef(t, gatewayOperation, http.StatusOK),
|
|
"path "+path+" success response schema",
|
|
)
|
|
compareParameterRefs(
|
|
t,
|
|
authOperation.Parameters,
|
|
gatewayOperation.Parameters,
|
|
"path "+path+" parameters",
|
|
)
|
|
|
|
for _, status := range publicErrorStatuses(path) {
|
|
assertSchemaRef(t, responseSchemaRef(t, authOperation, status), errorResponseRef, "path "+path+" error response "+http.StatusText(status)+" envelope")
|
|
}
|
|
}
|
|
|
|
assertOperationParameterRefs(
|
|
t,
|
|
getOperation(t, authDoc, "/api/v1/public/auth/send-email-code", http.MethodPost),
|
|
"#/components/parameters/AcceptLanguage",
|
|
)
|
|
assertOperationParameterRefs(
|
|
t,
|
|
getOperation(t, gatewayDoc, "/api/v1/public/auth/send-email-code", http.MethodPost),
|
|
"#/components/parameters/AcceptLanguage",
|
|
)
|
|
assertOperationParameterRefs(
|
|
t,
|
|
getOperation(t, authDoc, "/api/v1/public/auth/confirm-email-code", http.MethodPost),
|
|
)
|
|
assertOperationParameterRefs(
|
|
t,
|
|
getOperation(t, gatewayDoc, "/api/v1/public/auth/confirm-email-code", http.MethodPost),
|
|
)
|
|
|
|
compareSchemaRefs(
|
|
t,
|
|
authErrorEnvelope,
|
|
componentSchemaRef(t, gatewayDoc, "ErrorResponse"),
|
|
"ErrorResponse schema",
|
|
)
|
|
compareSchemaRefs(
|
|
t,
|
|
componentSchemaRef(t, authDoc, "ErrorBody"),
|
|
componentSchemaRef(t, gatewayDoc, "ErrorBody"),
|
|
"ErrorBody schema",
|
|
)
|
|
assertSchemaRef(t, gatewayProjectedEnvelope, errorResponseRef, "projected gateway auth error envelope")
|
|
}
|
|
|
|
func TestPublicOpenAPISpecErrorExamplesMatchStablePublicErrors(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
doc := loadSpec(t, "api", "public-openapi.yaml")
|
|
|
|
tests := []struct {
|
|
name string
|
|
responseName string
|
|
exampleName string
|
|
projection shared.PublicErrorProjection
|
|
}{
|
|
{
|
|
name: "send invalid request",
|
|
responseName: "SendEmailCodeBadRequestError",
|
|
exampleName: "invalidRequest",
|
|
projection: shared.ProjectPublicError(shared.InvalidRequest("email must be a single valid email address")),
|
|
},
|
|
{
|
|
name: "confirm invalid request",
|
|
responseName: "ConfirmEmailCodeBadRequestError",
|
|
exampleName: "invalidRequest",
|
|
projection: shared.ProjectPublicError(shared.InvalidRequest("challenge_id must not be empty")),
|
|
},
|
|
{
|
|
name: "confirm invalid code",
|
|
responseName: "ConfirmEmailCodeBadRequestError",
|
|
exampleName: "invalidCode",
|
|
projection: shared.ProjectPublicError(shared.InvalidCode()),
|
|
},
|
|
{
|
|
name: "confirm invalid client public key",
|
|
responseName: "ConfirmEmailCodeBadRequestError",
|
|
exampleName: "invalidClientPublicKey",
|
|
projection: shared.ProjectPublicError(shared.InvalidClientPublicKey()),
|
|
},
|
|
{
|
|
name: "confirm invalid time zone",
|
|
responseName: "ConfirmEmailCodeBadRequestError",
|
|
exampleName: "invalidTimeZone",
|
|
projection: shared.ProjectPublicError(shared.InvalidRequest("time_zone must be a valid IANA time zone name")),
|
|
},
|
|
{
|
|
name: "challenge not found",
|
|
responseName: "ChallengeNotFoundError",
|
|
exampleName: "notFound",
|
|
projection: shared.ProjectPublicError(shared.ChallengeNotFound()),
|
|
},
|
|
{
|
|
name: "challenge expired",
|
|
responseName: "ChallengeExpiredError",
|
|
exampleName: "expired",
|
|
projection: shared.ProjectPublicError(shared.ChallengeExpired()),
|
|
},
|
|
{
|
|
name: "blocked by policy",
|
|
responseName: "BlockedByPolicyError",
|
|
exampleName: "blocked",
|
|
projection: shared.ProjectPublicError(shared.BlockedByPolicy()),
|
|
},
|
|
{
|
|
name: "session limit exceeded",
|
|
responseName: "SessionLimitExceededError",
|
|
exampleName: "limitExceeded",
|
|
projection: shared.ProjectPublicError(shared.SessionLimitExceeded()),
|
|
},
|
|
{
|
|
name: "service unavailable",
|
|
responseName: "ServiceUnavailableError",
|
|
exampleName: "unavailable",
|
|
projection: shared.ProjectPublicError(shared.ServiceUnavailable(nil)),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := responseExampleValue(t, doc, tt.responseName, tt.exampleName)
|
|
want := map[string]any{
|
|
"error": map[string]any{
|
|
"code": tt.projection.Code,
|
|
"message": tt.projection.Message,
|
|
},
|
|
}
|
|
|
|
require.JSONEq(t, string(mustJSON(t, want)), string(mustJSON(t, got)))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestInternalOpenAPISpecFreezesMutationContracts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
doc := loadSpec(t, "api", "internal-openapi.yaml")
|
|
|
|
blockUser := componentSchemaRef(t, doc, "BlockUserRequest")
|
|
if got := len(blockUser.Value.OneOf); got != 2 {
|
|
require.Failf(t, "test failed", "BlockUserRequest oneOf length = %d, want 2", got)
|
|
}
|
|
|
|
refs := []string{
|
|
blockUser.Value.OneOf[0].Ref,
|
|
blockUser.Value.OneOf[1].Ref,
|
|
}
|
|
slices.Sort(refs)
|
|
wantRefs := []string{
|
|
"#/components/schemas/BlockUserByEmailRequest",
|
|
"#/components/schemas/BlockUserByUserIDRequest",
|
|
}
|
|
if !slices.Equal(refs, wantRefs) {
|
|
require.Failf(t, "test failed", "BlockUserRequest oneOf refs = %v, want %v", refs, wantRefs)
|
|
}
|
|
|
|
assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserByUserIDRequest"), "reason_code", "actor", "user_id")
|
|
assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserByEmailRequest"), "reason_code", "actor", "email")
|
|
assertRequiredFields(t, componentSchemaRef(t, doc, "RevokeDeviceSessionResponse"), "outcome", "device_session_id", "affected_session_count")
|
|
assertRequiredFields(t, componentSchemaRef(t, doc, "RevokeAllUserSessionsResponse"), "outcome", "user_id", "affected_session_count", "affected_device_session_ids")
|
|
assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserResponse"), "outcome", "subject_kind", "subject_value", "affected_session_count", "affected_device_session_ids")
|
|
|
|
assertStringEnum(t, componentSchemaRef(t, doc, "RevokeDeviceSessionResponse"), "outcome", "revoked", "already_revoked")
|
|
assertStringEnum(t, componentSchemaRef(t, doc, "RevokeAllUserSessionsResponse"), "outcome", "revoked", "no_active_sessions")
|
|
assertStringEnum(t, componentSchemaRef(t, doc, "BlockUserResponse"), "outcome", "blocked", "already_blocked")
|
|
}
|
|
|
|
func loadSpec(t *testing.T, pathElems ...string) *openapi3.T {
|
|
t.Helper()
|
|
|
|
_, thisFile, _, ok := runtime.Caller(0)
|
|
if !ok {
|
|
require.FailNow(t, "runtime.Caller failed")
|
|
}
|
|
|
|
specPath := filepath.Join(append([]string{filepath.Dir(thisFile)}, pathElems...)...)
|
|
loader := openapi3.NewLoader()
|
|
doc, err := loader.LoadFromFile(specPath)
|
|
if err != nil {
|
|
require.Failf(t, "test failed", "load spec %s: %v", specPath, err)
|
|
}
|
|
if doc == nil {
|
|
require.Failf(t, "test failed", "load spec %s: returned nil document", specPath)
|
|
}
|
|
if doc.Info == nil {
|
|
require.Failf(t, "test failed", "load spec %s: missing info section", specPath)
|
|
}
|
|
if doc.Info.Version != "v1" {
|
|
require.Failf(t, "test failed", "spec %s version = %q, want v1", specPath, doc.Info.Version)
|
|
}
|
|
if err := doc.Validate(context.Background()); err != nil {
|
|
require.Failf(t, "test failed", "validate spec %s: %v", specPath, err)
|
|
}
|
|
|
|
return doc
|
|
}
|
|
|
|
func getOperation(t *testing.T, doc *openapi3.T, path string, method string) *openapi3.Operation {
|
|
t.Helper()
|
|
|
|
if doc.Paths == nil {
|
|
require.Failf(t, "test failed", "spec is missing paths while looking up %s %s", method, path)
|
|
}
|
|
pathItem := doc.Paths.Value(path)
|
|
if pathItem == nil {
|
|
require.Failf(t, "test failed", "spec is missing path %s", path)
|
|
}
|
|
operation := pathItem.GetOperation(method)
|
|
if operation == nil {
|
|
require.Failf(t, "test failed", "spec is missing %s operation for path %s", method, path)
|
|
}
|
|
|
|
return operation
|
|
}
|
|
|
|
func requestSchemaRef(t *testing.T, operation *openapi3.Operation) *openapi3.SchemaRef {
|
|
t.Helper()
|
|
|
|
if operation.RequestBody == nil || operation.RequestBody.Value == nil {
|
|
require.FailNow(t, "operation is missing request body")
|
|
}
|
|
mediaType := operation.RequestBody.Value.Content.Get("application/json")
|
|
if mediaType == nil || mediaType.Schema == nil {
|
|
require.FailNow(t, "operation is missing application/json request schema")
|
|
}
|
|
|
|
return mediaType.Schema
|
|
}
|
|
|
|
func responseSchemaRef(t *testing.T, operation *openapi3.Operation, status int) *openapi3.SchemaRef {
|
|
t.Helper()
|
|
|
|
if operation.Responses == nil {
|
|
require.Failf(t, "test failed", "operation is missing responses for status %d", status)
|
|
}
|
|
response := operation.Responses.Status(status)
|
|
if response == nil || response.Value == nil {
|
|
require.Failf(t, "test failed", "operation is missing response for status %d", status)
|
|
}
|
|
mediaType := response.Value.Content.Get("application/json")
|
|
if mediaType == nil || mediaType.Schema == nil {
|
|
require.Failf(t, "test failed", "operation response %d is missing application/json schema", status)
|
|
}
|
|
|
|
return mediaType.Schema
|
|
}
|
|
|
|
func defaultResponseSchemaRef(t *testing.T, operation *openapi3.Operation) *openapi3.SchemaRef {
|
|
t.Helper()
|
|
|
|
if operation.Responses == nil {
|
|
require.FailNow(t, "operation is missing default responses")
|
|
}
|
|
response := operation.Responses.Default()
|
|
if response == nil || response.Value == nil {
|
|
require.FailNow(t, "operation is missing default response")
|
|
}
|
|
mediaType := response.Value.Content.Get("application/json")
|
|
if mediaType == nil || mediaType.Schema == nil {
|
|
require.FailNow(t, "operation default response is missing application/json schema")
|
|
}
|
|
|
|
return mediaType.Schema
|
|
}
|
|
|
|
func componentSchemaRef(t *testing.T, doc *openapi3.T, name string) *openapi3.SchemaRef {
|
|
t.Helper()
|
|
|
|
if doc.Components == nil {
|
|
require.Failf(t, "test failed", "spec is missing components while looking up schema %s", name)
|
|
}
|
|
schema := doc.Components.Schemas[name]
|
|
if schema == nil || schema.Value == nil {
|
|
require.Failf(t, "test failed", "spec is missing schema %s", name)
|
|
}
|
|
|
|
return schema
|
|
}
|
|
|
|
func responseExampleValue(t *testing.T, doc *openapi3.T, responseName string, exampleName string) any {
|
|
t.Helper()
|
|
|
|
if doc.Components == nil {
|
|
require.Failf(t, "test failed", "spec is missing components while looking up response %s", responseName)
|
|
}
|
|
response := doc.Components.Responses[responseName]
|
|
if response == nil || response.Value == nil {
|
|
require.Failf(t, "test failed", "spec is missing response %s", responseName)
|
|
}
|
|
mediaType := response.Value.Content.Get("application/json")
|
|
if mediaType == nil {
|
|
require.Failf(t, "test failed", "response %s is missing application/json content", responseName)
|
|
}
|
|
example := mediaType.Examples[exampleName]
|
|
if example == nil || example.Value == nil {
|
|
require.Failf(t, "test failed", "response %s is missing example %s", responseName, exampleName)
|
|
}
|
|
|
|
return example.Value.Value
|
|
}
|
|
|
|
func compareSchemaRefs(t *testing.T, got *openapi3.SchemaRef, want *openapi3.SchemaRef, name string) {
|
|
t.Helper()
|
|
|
|
gotJSON := mustJSON(t, got)
|
|
wantJSON := mustJSON(t, want)
|
|
if !bytes.Equal(gotJSON, wantJSON) {
|
|
require.Failf(t, "test failed", "%s mismatch:\n got: %s\nwant: %s", name, gotJSON, wantJSON)
|
|
}
|
|
}
|
|
|
|
func compareParameterRefs(t *testing.T, got openapi3.Parameters, want openapi3.Parameters, name string) {
|
|
t.Helper()
|
|
|
|
gotJSON := mustJSON(t, got)
|
|
wantJSON := mustJSON(t, want)
|
|
if !bytes.Equal(gotJSON, wantJSON) {
|
|
require.Failf(t, "test failed", "%s mismatch:\n got: %s\nwant: %s", name, gotJSON, wantJSON)
|
|
}
|
|
}
|
|
|
|
func assertSchemaRef(t *testing.T, schemaRef *openapi3.SchemaRef, want string, name string) {
|
|
t.Helper()
|
|
|
|
if schemaRef.Ref != want {
|
|
require.Failf(t, "test failed", "%s ref = %q, want %q", name, schemaRef.Ref, want)
|
|
}
|
|
}
|
|
|
|
func assertOperationParameterRefs(t *testing.T, operation *openapi3.Operation, refs ...string) {
|
|
t.Helper()
|
|
|
|
if len(operation.Parameters) != len(refs) {
|
|
require.Failf(t, "test failed", "operation parameter count = %d, want %d", len(operation.Parameters), len(refs))
|
|
}
|
|
|
|
for index, want := range refs {
|
|
if operation.Parameters[index] == nil {
|
|
require.Failf(t, "test failed", "operation parameter %d is nil", index)
|
|
}
|
|
if operation.Parameters[index].Ref != want {
|
|
require.Failf(t, "test failed", "operation parameter %d ref = %q, want %q", index, operation.Parameters[index].Ref, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func assertRequiredFields(t *testing.T, schemaRef *openapi3.SchemaRef, fields ...string) {
|
|
t.Helper()
|
|
|
|
required := append([]string(nil), schemaRef.Value.Required...)
|
|
slices.Sort(required)
|
|
want := append([]string(nil), fields...)
|
|
slices.Sort(want)
|
|
if !slices.Equal(required, want) {
|
|
require.Failf(t, "test failed", "schema required fields = %v, want %v", required, want)
|
|
}
|
|
}
|
|
|
|
func assertStringEnum(t *testing.T, schemaRef *openapi3.SchemaRef, property string, values ...string) {
|
|
t.Helper()
|
|
|
|
prop := schemaRef.Value.Properties[property]
|
|
if prop == nil || prop.Value == nil {
|
|
require.Failf(t, "test failed", "schema is missing property %s", property)
|
|
}
|
|
|
|
got := make([]string, 0, len(prop.Value.Enum))
|
|
for _, raw := range prop.Value.Enum {
|
|
value, ok := raw.(string)
|
|
if !ok {
|
|
require.Failf(t, "test failed", "property %s enum contains non-string value %T", property, raw)
|
|
}
|
|
got = append(got, value)
|
|
}
|
|
|
|
if !slices.Equal(got, values) {
|
|
require.Failf(t, "test failed", "property %s enum = %v, want %v", property, got, values)
|
|
}
|
|
}
|
|
|
|
func mustJSON(t *testing.T, value any) []byte {
|
|
t.Helper()
|
|
|
|
data, err := json.Marshal(value)
|
|
if err != nil {
|
|
require.Failf(t, "test failed", "marshal JSON: %v", err)
|
|
}
|
|
|
|
return data
|
|
}
|
|
|
|
// publicErrorStatuses returns the HTTP error statuses listed for the given
// public auth path, as a freshly allocated slice on every call. It panics for
// any path other than the two known ones.
func publicErrorStatuses(path string) []int {
	const (
		sendEmailCodePath    = "/api/v1/public/auth/send-email-code"
		confirmEmailCodePath = "/api/v1/public/auth/confirm-email-code"
	)

	if path == sendEmailCodePath {
		return []int{http.StatusBadRequest, http.StatusServiceUnavailable}
	}

	if path == confirmEmailCodePath {
		return []int{
			http.StatusBadRequest,
			http.StatusForbidden,
			http.StatusNotFound,
			http.StatusConflict,
			http.StatusGone,
			http.StatusServiceUnavailable,
		}
	}

	panic("unexpected public auth path: " + path)
}
|