// File viewer metadata: 189 lines, 5.7 KiB, Go.
package grpcapi
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"galaxy/gateway/internal/session"
|
|
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
func TestExecuteCommandRejectsInvalidSignature(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
delegate := &recordingEdgeGatewayService{}
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Service: delegate,
|
|
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
req := newValidExecuteCommandRequest()
|
|
req.Signature[0] ^= 0xff
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
_, err := client.ExecuteCommand(context.Background(), req)
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
|
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
|
assert.Zero(t, delegate.executeCalls)
|
|
}
|
|
|
|
func TestExecuteCommandRejectsWrongKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
delegate := &recordingEdgeGatewayService{}
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Service: delegate,
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
record := newActiveSessionRecord()
|
|
record.ClientPublicKey = alternateTestClientPublicKeyBase64()
|
|
return record, nil
|
|
},
|
|
},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
|
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
|
assert.Zero(t, delegate.executeCalls)
|
|
}
|
|
|
|
func TestExecuteCommandRejectsInvalidCachedPublicKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
delegate := &recordingEdgeGatewayService{}
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Service: delegate,
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
record := newActiveSessionRecord()
|
|
record.ClientPublicKey = "%%%not-base64%%%"
|
|
return record, nil
|
|
},
|
|
},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unavailable, status.Code(err))
|
|
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
|
assert.Zero(t, delegate.executeCalls)
|
|
}
|
|
|
|
func TestSubscribeEventsRejectsInvalidSignature(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
delegate := &recordingEdgeGatewayService{}
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Service: delegate,
|
|
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
req := newValidSubscribeEventsRequest()
|
|
req.Signature[0] ^= 0xff
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
err := subscribeEventsError(t, context.Background(), client, req)
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
|
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
|
assert.Zero(t, delegate.subscribeCalls)
|
|
}
|
|
|
|
func TestSubscribeEventsRejectsWrongKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
delegate := &recordingEdgeGatewayService{}
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Service: delegate,
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
record := newActiveSessionRecord()
|
|
record.ClientPublicKey = alternateTestClientPublicKeyBase64()
|
|
return record, nil
|
|
},
|
|
},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
|
assert.Equal(t, "invalid request signature", status.Convert(err).Message())
|
|
assert.Zero(t, delegate.subscribeCalls)
|
|
}
|
|
|
|
func TestSubscribeEventsRejectsInvalidCachedPublicKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
delegate := &recordingEdgeGatewayService{}
|
|
server, runGateway := newTestGateway(t, ServerDependencies{
|
|
Service: delegate,
|
|
SessionCache: staticSessionCache{
|
|
lookupFunc: func(context.Context, string) (session.Record, error) {
|
|
record := newActiveSessionRecord()
|
|
record.ClientPublicKey = "%%%not-base64%%%"
|
|
return record, nil
|
|
},
|
|
},
|
|
})
|
|
defer runGateway.stop(t)
|
|
|
|
addr := waitForListenAddr(t, server)
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.Unavailable, status.Code(err))
|
|
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
|
assert.Zero(t, delegate.subscribeCalls)
|
|
}
|