package authn

import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestVerifyPayloadHash checks that VerifyPayloadHash accepts a digest equal to
// the SHA-256 of the payload, and rejects digests of the wrong length with
// ErrInvalidPayloadHash and well-formed but different digests with
// ErrPayloadHashMismatch.
func TestVerifyPayloadHash(t *testing.T) {
	t.Parallel()

	payloadSum := sha256.Sum256([]byte("payload"))
	emptySum := sha256.Sum256(nil)
	otherSum := sha256.Sum256([]byte("other"))

	tests := []struct {
		name        string
		payload     []byte
		payloadHash []byte
		wantErr     error
	}{
		{
			name:        "matches non-empty payload",
			payload:     []byte("payload"),
			payloadHash: payloadSum[:],
		},
		{
			name:        "matches empty payload",
			payload:     nil,
			payloadHash: emptySum[:],
		},
		{
			name:        "rejects digest with invalid length",
			payload:     []byte("payload"),
			payloadHash: []byte("short"),
			wantErr:     ErrInvalidPayloadHash,
		},
		{
			// The correct digest followed by one extra byte: an implementation
			// that only compares the first sha256.Size bytes would accept it.
			name:        "rejects over-long digest with valid prefix",
			payload:     []byte("payload"),
			payloadHash: append(mustSHA256([]byte("payload")), 0x00),
			wantErr:     ErrInvalidPayloadHash,
		},
		{
			name:        "rejects digest mismatch",
			payload:     []byte("payload"),
			payloadHash: otherSum[:],
			wantErr:     ErrPayloadHashMismatch,
		},
	}

	for _, tt := range tests {
		tt := tt // pre-Go 1.22 loop-variable capture for the parallel subtest
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := VerifyPayloadHash(tt.payload, tt.payloadHash)
			if tt.wantErr == nil {
				require.NoError(t, err)
				return
			}
			require.ErrorIs(t, err, tt.wantErr)
		})
	}
}

// TestBuildRequestSigningInput pins the exact bytes produced by
// BuildRequestSigningInput against a golden vector, so any change to the
// encoding is caught.
func TestBuildRequestSigningInput(t *testing.T) {
	t.Parallel()

	fields := RequestSigningFields{
		ProtocolVersion: "v1",
		DeviceSessionID: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMS:     123456789,
		RequestID:       "request-123",
		PayloadHash:     mustSHA256([]byte("payload")),
	}

	got := BuildRequestSigningInput(fields)

	// Golden vector, split per component. Each variable-length component is
	// preceded by its length (a single byte for values this short).
	wantHex := strings.Join([]string{
		// Constant tag "galaxy-request-v1" (not one of the fields).
		"11" + "67616c6178792d726571756573742d7631",
		// ProtocolVersion "v1".
		"02" + "7631",
		// DeviceSessionID "device-session-123".
		"12" + "6465766963652d73657373696f6e2d313233",
		// MessageType "fleet.move".
		"0a" + "666c6565742e6d6f7665",
		// TimestampMS 123456789 as an 8-byte big-endian integer.
		"00000000075bcd15",
		// RequestID "request-123".
		"0b" + "726571756573742d313233",
		// PayloadHash: the 32-byte SHA-256 of "payload".
		"20" + "239f59ed55e737c77147cf55ad0c1b030b6d7ee748a7426952f9b852d5a935e5",
	}, "")
	want, err := hex.DecodeString(wantHex)
	require.NoError(t, err)
	assert.Equal(t, want, got)
}

// TestBuildRequestSigningInputChangesWhenSignedFieldChanges verifies that every
// field of RequestSigningFields contributes to the signing input. It also
// verifies that moving bytes across a field boundary changes the input, so the
// encoding is unambiguous.
func TestBuildRequestSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
	t.Parallel()

	base := RequestSigningFields{
		ProtocolVersion: "v1",
		DeviceSessionID: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMS:     123456789,
		RequestID:       "request-123",
		PayloadHash:     mustSHA256([]byte("payload")),
	}
	baseInput := BuildRequestSigningInput(base)

	// Control: the same fields must yield the same bytes. Without this check,
	// a non-deterministic builder would make every "changed" assertion below
	// pass vacuously.
	require.Equal(t, baseInput, BuildRequestSigningInput(base))

	// Each mutate receives a copy of base. Slice fields must be replaced, not
	// written into, because base.PayloadHash's backing array is shared by the
	// parallel subtests.
	tests := []struct {
		name   string
		mutate func(RequestSigningFields) RequestSigningFields
	}{
		{
			name: "protocol version",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.ProtocolVersion = "v2"
				return fields
			},
		},
		{
			name: "device session id",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.DeviceSessionID = "device-session-456"
				return fields
			},
		},
		{
			name: "message type",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.MessageType = "fleet.attack"
				return fields
			},
		},
		{
			name: "timestamp",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.TimestampMS++
				return fields
			},
		},
		{
			name: "request id",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.RequestID = "request-456"
				return fields
			},
		},
		{
			name: "payload hash",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.PayloadHash = mustSHA256([]byte("other"))
				return fields
			},
		},
		{
			// Move the trailing "3" of DeviceSessionID to the front of
			// MessageType. The concatenated field bytes are unchanged
			// ("device-session-123fleet.move"), so only the length prefixes
			// can tell the two inputs apart.
			name: "field boundary shift",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.DeviceSessionID = "device-session-12"
				fields.MessageType = "3fleet.move"
				return fields
			},
		},
	}

	for _, tt := range tests {
		tt := tt // pre-Go 1.22 loop-variable capture for the parallel subtest
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			mutated := BuildRequestSigningInput(tt.mutate(base))
			assert.False(t, bytes.Equal(baseInput, mutated),
				"signing input unchanged after mutating %s", tt.name)
		})
	}
}

// mustSHA256 returns the SHA-256 digest of payload as a slice backed by its own
// array, so callers may append to or keep it without aliasing other digests.
func mustSHA256(payload []byte) []byte {
	sum := sha256.Sum256(payload)
	return sum[:]
}