130 lines
3.9 KiB
Go
130 lines
3.9 KiB
Go
package grpcapi
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
|
|
"galaxy/gateway/internal/session"
|
|
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// resolvedSessionFromContext returns the session record previously attached to
|
|
// ctx by the session-lookup gateway wrapper.
|
|
func resolvedSessionFromContext(ctx context.Context) (session.Record, bool) {
|
|
if ctx == nil {
|
|
return session.Record{}, false
|
|
}
|
|
|
|
record, ok := ctx.Value(resolvedSessionContextKey{}).(session.Record)
|
|
if !ok {
|
|
return session.Record{}, false
|
|
}
|
|
|
|
return cloneSessionRecord(record), true
|
|
}
|
|
|
|
// sessionLookupService resolves the authenticated session from SessionCache
|
|
// after envelope parsing succeeds and before later auth steps run.
|
|
type sessionLookupService struct {
|
|
gatewayv1.UnimplementedEdgeGatewayServer
|
|
|
|
delegate gatewayv1.EdgeGatewayServer
|
|
cache session.Cache
|
|
}
|
|
|
|
// ExecuteCommand resolves the cached session for req and only then forwards it
|
|
// to the configured delegate with the resolved session attached to ctx.
|
|
func (s sessionLookupService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
|
record, err := s.lookupSession(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s.delegate.ExecuteCommand(context.WithValue(ctx, resolvedSessionContextKey{}, cloneSessionRecord(record)), req)
|
|
}
|
|
|
|
// SubscribeEvents resolves the cached session for req and only then forwards it
|
|
// to the configured delegate with the resolved session attached to the stream
|
|
// context.
|
|
func (s sessionLookupService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
|
record, err := s.lookupSession(stream.Context())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.delegate.SubscribeEvents(req, resolvedSessionContextStream{
|
|
ServerStreamingServer: stream,
|
|
ctx: context.WithValue(stream.Context(), resolvedSessionContextKey{}, cloneSessionRecord(record)),
|
|
})
|
|
}
|
|
|
|
// newSessionLookupService wraps delegate with the session-cache lookup gate.
|
|
func newSessionLookupService(delegate gatewayv1.EdgeGatewayServer, cache session.Cache) gatewayv1.EdgeGatewayServer {
|
|
return sessionLookupService{
|
|
delegate: delegate,
|
|
cache: cache,
|
|
}
|
|
}
|
|
|
|
func (s sessionLookupService) lookupSession(ctx context.Context) (session.Record, error) {
|
|
envelope, ok := parsedEnvelopeFromContext(ctx)
|
|
if !ok {
|
|
return session.Record{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
|
}
|
|
|
|
record, err := s.cache.Lookup(ctx, envelope.DeviceSessionID)
|
|
switch {
|
|
case err == nil:
|
|
case errors.Is(err, session.ErrNotFound):
|
|
return session.Record{}, status.Error(codes.Unauthenticated, "unknown device session")
|
|
default:
|
|
return session.Record{}, status.Error(codes.Unavailable, "session cache is unavailable")
|
|
}
|
|
|
|
if record.Status == session.StatusRevoked {
|
|
return session.Record{}, status.Error(codes.FailedPrecondition, "device session is revoked")
|
|
}
|
|
|
|
return cloneSessionRecord(record), nil
|
|
}
|
|
|
|
func cloneSessionRecord(record session.Record) session.Record {
|
|
cloned := record
|
|
if record.RevokedAtMS != nil {
|
|
value := *record.RevokedAtMS
|
|
cloned.RevokedAtMS = &value
|
|
}
|
|
|
|
return cloned
|
|
}
|
|
|
|
// resolvedSessionContextKey is the unexported context key under which the
// resolved session.Record is stored. It is a distinct, package-private struct
// type, so no other package can construct a colliding key.
type resolvedSessionContextKey struct{}
|
|
|
|
type resolvedSessionContextStream struct {
|
|
grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
|
|
ctx context.Context
|
|
}
|
|
|
|
func (s resolvedSessionContextStream) Context() context.Context {
|
|
if s.ctx == nil {
|
|
return context.Background()
|
|
}
|
|
|
|
return s.ctx
|
|
}
|
|
|
|
type unavailableSessionCache struct{}
|
|
|
|
func (unavailableSessionCache) Lookup(context.Context, string) (session.Record, error) {
|
|
return session.Record{}, errors.New("session cache is unavailable")
|
|
}
|
|
|
|
func (unavailableSessionCache) MarkRevoked(string) {}
|
|
func (unavailableSessionCache) MarkAllRevokedForUser(string) {}
|
|
|
|
// Compile-time assertion that the value type sessionLookupService (not only
// its pointer) satisfies gatewayv1.EdgeGatewayServer.
var _ gatewayv1.EdgeGatewayServer = sessionLookupService{}
|