package postgres import ( "context" "database/sql" "errors" "fmt" "github.com/XSAM/otelsql" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/trace" ) // dbSystemAttribute identifies the wrapped backend in OpenTelemetry spans // without locking the package to a specific semconv release. var dbSystemAttribute = attribute.String("db.system", "postgresql") // Option configures the OpenTelemetry providers attached to a connection by // OpenPrimary, OpenReplicas, and InstrumentDBStats. Unset providers fall back // to the OpenTelemetry global providers. type Option func(*options) type options struct { tracerProvider trace.TracerProvider meterProvider metric.MeterProvider } // WithTracerProvider sets the tracer provider used for SQL statement spans. func WithTracerProvider(tp trace.TracerProvider) Option { return func(o *options) { o.tracerProvider = tp } } // WithMeterProvider sets the meter provider used for connection-pool stats. func WithMeterProvider(mp metric.MeterProvider) Option { return func(o *options) { o.meterProvider = mp } } func evalOptions(opts []Option) options { var resolved options for _, opt := range opts { if opt == nil { continue } opt(&resolved) } return resolved } func (o options) otelsqlOpenOptions() []otelsql.Option { out := []otelsql.Option{otelsql.WithAttributes(dbSystemAttribute)} if o.tracerProvider != nil { out = append(out, otelsql.WithTracerProvider(o.tracerProvider)) } if o.meterProvider != nil { out = append(out, otelsql.WithMeterProvider(o.meterProvider)) } return out } // OpenPrimary opens the primary `*sql.DB` from cfg. ctx bounds individual // pgx connect attempts via the parsed pgx config's ConnectTimeout (set to // cfg.OperationTimeout). The returned pool has SetMaxOpenConns, // SetMaxIdleConns and SetConnMaxLifetime applied. func OpenPrimary(ctx context.Context, cfg Config, opts ...Option) (*sql.DB, error) { if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("open postgres primary: %w", err) } db, err := openDB(ctx, cfg, cfg.PrimaryDSN, evalOptions(opts)) if err != nil { return nil, fmt.Errorf("open postgres primary: %w", err) } return db, nil } // OpenReplicas opens one `*sql.DB` per replica DSN. It returns nil when no // replicas are configured. When opening a replica fails mid-way, every // already-opened replica is closed before returning the error. func OpenReplicas(ctx context.Context, cfg Config, opts ...Option) ([]*sql.DB, error) { if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("open postgres replicas: %w", err) } if len(cfg.ReplicaDSNs) == 0 { return nil, nil } resolved := evalOptions(opts) pools := make([]*sql.DB, 0, len(cfg.ReplicaDSNs)) for index, dsn := range cfg.ReplicaDSNs { db, err := openDB(ctx, cfg, dsn, resolved) if err != nil { for _, opened := range pools { _ = opened.Close() } return nil, fmt.Errorf("open postgres replica at index %d: %w", index, err) } pools = append(pools, db) } return pools, nil } func openDB(ctx context.Context, cfg Config, dsn string, opts options) (*sql.DB, error) { if ctx.Err() != nil { return nil, ctx.Err() } pgxCfg, err := pgx.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("parse dsn: %w", err) } pgxCfg.ConnectTimeout = cfg.OperationTimeout registeredName := stdlib.RegisterConnConfig(pgxCfg) db, err := otelsql.Open("pgx", registeredName, opts.otelsqlOpenOptions()...) if err != nil { stdlib.UnregisterConnConfig(registeredName) return nil, fmt.Errorf("otelsql open: %w", err) } if db == nil { stdlib.UnregisterConnConfig(registeredName) return nil, errors.New("otelsql open returned nil db") } db.SetMaxOpenConns(cfg.MaxOpenConns) db.SetMaxIdleConns(cfg.MaxIdleConns) db.SetConnMaxLifetime(cfg.ConnMaxLifetime) return db, nil }