package authsession import ( "bytes" "context" "encoding/json" "net/http" "path/filepath" "runtime" "slices" "testing" "galaxy/authsession/internal/service/shared" "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/require" ) func TestPublicOpenAPISpecValidates(t *testing.T) { t.Parallel() loadSpec(t, "api", "public-openapi.yaml") } func TestInternalOpenAPISpecValidates(t *testing.T) { t.Parallel() loadSpec(t, "api", "internal-openapi.yaml") } func TestPublicOpenAPISpecMatchesGatewayPublicAuthContract(t *testing.T) { t.Parallel() authDoc := loadSpec(t, "api", "public-openapi.yaml") gatewayDoc := loadSpec(t, "..", "gateway", "openapi.yaml") authErrorEnvelope := componentSchemaRef(t, authDoc, "ErrorResponse") gatewayProjectedEnvelope := defaultResponseSchemaRef(t, getOperation(t, gatewayDoc, "/api/v1/public/auth/send-email-code", http.MethodPost)) const errorResponseRef = "#/components/schemas/ErrorResponse" paths := []string{ "/api/v1/public/auth/send-email-code", "/api/v1/public/auth/confirm-email-code", } for _, path := range paths { authOperation := getOperation(t, authDoc, path, http.MethodPost) gatewayOperation := getOperation(t, gatewayDoc, path, http.MethodPost) if authOperation.OperationID != gatewayOperation.OperationID { require.Failf(t, "test failed", "operation %s: got operationId %q, want %q", path, authOperation.OperationID, gatewayOperation.OperationID) } compareSchemaRefs( t, requestSchemaRef(t, authOperation), requestSchemaRef(t, gatewayOperation), "path "+path+" request schema", ) compareSchemaRefs( t, responseSchemaRef(t, authOperation, http.StatusOK), responseSchemaRef(t, gatewayOperation, http.StatusOK), "path "+path+" success response schema", ) for _, status := range publicErrorStatuses(path) { assertSchemaRef(t, responseSchemaRef(t, authOperation, status), errorResponseRef, "path "+path+" error response "+http.StatusText(status)+" envelope") } } compareSchemaRefs( t, authErrorEnvelope, componentSchemaRef(t, gatewayDoc, "ErrorResponse"), "ErrorResponse schema", ) compareSchemaRefs( t, componentSchemaRef(t, authDoc, "ErrorBody"), componentSchemaRef(t, gatewayDoc, "ErrorBody"), "ErrorBody schema", ) assertSchemaRef(t, gatewayProjectedEnvelope, errorResponseRef, "projected gateway auth error envelope") } func TestPublicOpenAPISpecErrorExamplesMatchStablePublicErrors(t *testing.T) { t.Parallel() doc := loadSpec(t, "api", "public-openapi.yaml") tests := []struct { name string responseName string exampleName string projection shared.PublicErrorProjection }{ { name: "send invalid request", responseName: "SendEmailCodeBadRequestError", exampleName: "invalidRequest", projection: shared.ProjectPublicError(shared.InvalidRequest("email must be a single valid email address")), }, { name: "confirm invalid request", responseName: "ConfirmEmailCodeBadRequestError", exampleName: "invalidRequest", projection: shared.ProjectPublicError(shared.InvalidRequest("challenge_id must not be empty")), }, { name: "confirm invalid code", responseName: "ConfirmEmailCodeBadRequestError", exampleName: "invalidCode", projection: shared.ProjectPublicError(shared.InvalidCode()), }, { name: "confirm invalid client public key", responseName: "ConfirmEmailCodeBadRequestError", exampleName: "invalidClientPublicKey", projection: shared.ProjectPublicError(shared.InvalidClientPublicKey()), }, { name: "confirm invalid time zone", responseName: "ConfirmEmailCodeBadRequestError", exampleName: "invalidTimeZone", projection: shared.ProjectPublicError(shared.InvalidRequest("time_zone must be a valid IANA time zone name")), }, { name: "challenge not found", responseName: "ChallengeNotFoundError", exampleName: "notFound", projection: shared.ProjectPublicError(shared.ChallengeNotFound()), }, { name: "challenge expired", responseName: "ChallengeExpiredError", exampleName: "expired", projection: shared.ProjectPublicError(shared.ChallengeExpired()), }, { name: "blocked by policy", responseName: "BlockedByPolicyError", exampleName: "blocked", projection: shared.ProjectPublicError(shared.BlockedByPolicy()), }, { name: "session limit exceeded", responseName: "SessionLimitExceededError", exampleName: "limitExceeded", projection: shared.ProjectPublicError(shared.SessionLimitExceeded()), }, { name: "service unavailable", responseName: "ServiceUnavailableError", exampleName: "unavailable", projection: shared.ProjectPublicError(shared.ServiceUnavailable(nil)), }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() got := responseExampleValue(t, doc, tt.responseName, tt.exampleName) want := map[string]any{ "error": map[string]any{ "code": tt.projection.Code, "message": tt.projection.Message, }, } require.JSONEq(t, string(mustJSON(t, want)), string(mustJSON(t, got))) }) } } func TestInternalOpenAPISpecFreezesMutationContracts(t *testing.T) { t.Parallel() doc := loadSpec(t, "api", "internal-openapi.yaml") blockUser := componentSchemaRef(t, doc, "BlockUserRequest") if got := len(blockUser.Value.OneOf); got != 2 { require.Failf(t, "test failed", "BlockUserRequest oneOf length = %d, want 2", got) } refs := []string{ blockUser.Value.OneOf[0].Ref, blockUser.Value.OneOf[1].Ref, } slices.Sort(refs) wantRefs := []string{ "#/components/schemas/BlockUserByEmailRequest", "#/components/schemas/BlockUserByUserIDRequest", } if !slices.Equal(refs, wantRefs) { require.Failf(t, "test failed", "BlockUserRequest oneOf refs = %v, want %v", refs, wantRefs) } assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserByUserIDRequest"), "reason_code", "actor", "user_id") assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserByEmailRequest"), "reason_code", "actor", "email") assertRequiredFields(t, componentSchemaRef(t, doc, "RevokeDeviceSessionResponse"), "outcome", "device_session_id", "affected_session_count") assertRequiredFields(t, componentSchemaRef(t, doc, "RevokeAllUserSessionsResponse"), "outcome", "user_id", "affected_session_count", "affected_device_session_ids") assertRequiredFields(t, componentSchemaRef(t, doc, "BlockUserResponse"), "outcome", "subject_kind", "subject_value", "affected_session_count", "affected_device_session_ids") assertStringEnum(t, componentSchemaRef(t, doc, "RevokeDeviceSessionResponse"), "outcome", "revoked", "already_revoked") assertStringEnum(t, componentSchemaRef(t, doc, "RevokeAllUserSessionsResponse"), "outcome", "revoked", "no_active_sessions") assertStringEnum(t, componentSchemaRef(t, doc, "BlockUserResponse"), "outcome", "blocked", "already_blocked") } func loadSpec(t *testing.T, pathElems ...string) *openapi3.T { t.Helper() _, thisFile, _, ok := runtime.Caller(0) if !ok { require.FailNow(t, "runtime.Caller failed") } specPath := filepath.Join(append([]string{filepath.Dir(thisFile)}, pathElems...)...) loader := openapi3.NewLoader() doc, err := loader.LoadFromFile(specPath) if err != nil { require.Failf(t, "test failed", "load spec %s: %v", specPath, err) } if doc == nil { require.Failf(t, "test failed", "load spec %s: returned nil document", specPath) } if doc.Info == nil { require.Failf(t, "test failed", "load spec %s: missing info section", specPath) } if doc.Info.Version != "v1" { require.Failf(t, "test failed", "spec %s version = %q, want v1", specPath, doc.Info.Version) } if err := doc.Validate(context.Background()); err != nil { require.Failf(t, "test failed", "validate spec %s: %v", specPath, err) } return doc } func getOperation(t *testing.T, doc *openapi3.T, path string, method string) *openapi3.Operation { t.Helper() if doc.Paths == nil { require.Failf(t, "test failed", "spec is missing paths while looking up %s %s", method, path) } pathItem := doc.Paths.Value(path) if pathItem == nil { require.Failf(t, "test failed", "spec is missing path %s", path) } operation := pathItem.GetOperation(method) if operation == nil { require.Failf(t, "test failed", "spec is missing %s operation for path %s", method, path) } return operation } func requestSchemaRef(t *testing.T, operation *openapi3.Operation) *openapi3.SchemaRef { t.Helper() if operation.RequestBody == nil || operation.RequestBody.Value == nil { require.FailNow(t, "operation is missing request body") } mediaType := operation.RequestBody.Value.Content.Get("application/json") if mediaType == nil || mediaType.Schema == nil { require.FailNow(t, "operation is missing application/json request schema") } return mediaType.Schema } func responseSchemaRef(t *testing.T, operation *openapi3.Operation, status int) *openapi3.SchemaRef { t.Helper() if operation.Responses == nil { require.Failf(t, "test failed", "operation is missing responses for status %d", status) } response := operation.Responses.Status(status) if response == nil || response.Value == nil { require.Failf(t, "test failed", "operation is missing response for status %d", status) } mediaType := response.Value.Content.Get("application/json") if mediaType == nil || mediaType.Schema == nil { require.Failf(t, "test failed", "operation response %d is missing application/json schema", status) } return mediaType.Schema } func defaultResponseSchemaRef(t *testing.T, operation *openapi3.Operation) *openapi3.SchemaRef { t.Helper() if operation.Responses == nil { require.FailNow(t, "operation is missing default responses") } response := operation.Responses.Default() if response == nil || response.Value == nil { require.FailNow(t, "operation is missing default response") } mediaType := response.Value.Content.Get("application/json") if mediaType == nil || mediaType.Schema == nil { require.FailNow(t, "operation default response is missing application/json schema") } return mediaType.Schema } func componentSchemaRef(t *testing.T, doc *openapi3.T, name string) *openapi3.SchemaRef { t.Helper() if doc.Components == nil { require.Failf(t, "test failed", "spec is missing components while looking up schema %s", name) } schema := doc.Components.Schemas[name] if schema == nil || schema.Value == nil { require.Failf(t, "test failed", "spec is missing schema %s", name) } return schema } func responseExampleValue(t *testing.T, doc *openapi3.T, responseName string, exampleName string) any { t.Helper() if doc.Components == nil { require.Failf(t, "test failed", "spec is missing components while looking up response %s", responseName) } response := doc.Components.Responses[responseName] if response == nil || response.Value == nil { require.Failf(t, "test failed", "spec is missing response %s", responseName) } mediaType := response.Value.Content.Get("application/json") if mediaType == nil { require.Failf(t, "test failed", "response %s is missing application/json content", responseName) } example := mediaType.Examples[exampleName] if example == nil || example.Value == nil { require.Failf(t, "test failed", "response %s is missing example %s", responseName, exampleName) } return example.Value.Value } func compareSchemaRefs(t *testing.T, got *openapi3.SchemaRef, want *openapi3.SchemaRef, name string) { t.Helper() gotJSON := mustJSON(t, got) wantJSON := mustJSON(t, want) if !bytes.Equal(gotJSON, wantJSON) { require.Failf(t, "test failed", "%s mismatch:\n got: %s\nwant: %s", name, gotJSON, wantJSON) } } func assertSchemaRef(t *testing.T, schemaRef *openapi3.SchemaRef, want string, name string) { t.Helper() if schemaRef.Ref != want { require.Failf(t, "test failed", "%s ref = %q, want %q", name, schemaRef.Ref, want) } } func assertRequiredFields(t *testing.T, schemaRef *openapi3.SchemaRef, fields ...string) { t.Helper() required := append([]string(nil), schemaRef.Value.Required...) slices.Sort(required) want := append([]string(nil), fields...) slices.Sort(want) if !slices.Equal(required, want) { require.Failf(t, "test failed", "schema required fields = %v, want %v", required, want) } } func assertStringEnum(t *testing.T, schemaRef *openapi3.SchemaRef, property string, values ...string) { t.Helper() prop := schemaRef.Value.Properties[property] if prop == nil || prop.Value == nil { require.Failf(t, "test failed", "schema is missing property %s", property) } got := make([]string, 0, len(prop.Value.Enum)) for _, raw := range prop.Value.Enum { value, ok := raw.(string) if !ok { require.Failf(t, "test failed", "property %s enum contains non-string value %T", property, raw) } got = append(got, value) } if !slices.Equal(got, values) { require.Failf(t, "test failed", "property %s enum = %v, want %v", property, got, values) } } func mustJSON(t *testing.T, value any) []byte { t.Helper() data, err := json.Marshal(value) if err != nil { require.Failf(t, "test failed", "marshal JSON: %v", err) } return data } func publicErrorStatuses(path string) []int { switch path { case "/api/v1/public/auth/send-email-code": return []int{http.StatusBadRequest, http.StatusServiceUnavailable} case "/api/v1/public/auth/confirm-email-code": return []int{ http.StatusBadRequest, http.StatusForbidden, http.StatusNotFound, http.StatusConflict, http.StatusGone, http.StatusServiceUnavailable, } default: panic("unexpected public auth path: " + path) } }