feat: edge gateway service
This commit is contained in:
@@ -0,0 +1,163 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyPayloadHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payloadSum := sha256.Sum256([]byte("payload"))
|
||||
emptySum := sha256.Sum256(nil)
|
||||
otherSum := sha256.Sum256([]byte("other"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
payloadHash []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "matches non-empty payload",
|
||||
payload: []byte("payload"),
|
||||
payloadHash: payloadSum[:],
|
||||
},
|
||||
{
|
||||
name: "matches empty payload",
|
||||
payload: nil,
|
||||
payloadHash: emptySum[:],
|
||||
},
|
||||
{
|
||||
name: "rejects digest with invalid length",
|
||||
payload: []byte("payload"),
|
||||
payloadHash: []byte("short"),
|
||||
wantErr: ErrInvalidPayloadHash,
|
||||
},
|
||||
{
|
||||
name: "rejects digest mismatch",
|
||||
payload: []byte("payload"),
|
||||
payloadHash: otherSum[:],
|
||||
wantErr: ErrPayloadHashMismatch,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := VerifyPayloadHash(tt.payload, tt.payloadHash)
|
||||
if tt.wantErr == nil {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequestSigningInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fields := RequestSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMS: 123456789,
|
||||
RequestID: "request-123",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
got := BuildRequestSigningInput(fields)
|
||||
|
||||
want, err := hex.DecodeString("1167616c6178792d726571756573742d7631027631126465766963652d73657373696f6e2d3132330a666c6565742e6d6f766500000000075bcd150b726571756573742d31323320239f59ed55e737c77147cf55ad0c1b030b6d7ee748a7426952f9b852d5a935e5")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestBuildRequestSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := RequestSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMS: 123456789,
|
||||
RequestID: "request-123",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
baseInput := BuildRequestSigningInput(base)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(RequestSigningFields) RequestSigningFields
|
||||
}{
|
||||
{
|
||||
name: "protocol version",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.ProtocolVersion = "v2"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "device session id",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.DeviceSessionID = "device-session-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message type",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.MessageType = "fleet.attack"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.TimestampMS++
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request id",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.RequestID = "request-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "payload hash",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.PayloadHash = mustSHA256([]byte("other"))
|
||||
return fields
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mutated := BuildRequestSigningInput(tt.mutate(base))
|
||||
assert.False(t, bytes.Equal(baseInput, mutated))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustSHA256(payload []byte) []byte {
|
||||
sum := sha256.Sum256(payload)
|
||||
return sum[:]
|
||||
}
|
||||
Reference in New Issue
Block a user