package postgres

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestIsTransientConnError runs isTransientConnError over a table of errors
// and checks that connection-level failures are reported as transient while
// nil and deterministic SQL errors are not.
func TestIsTransientConnError(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		err  error
		want bool
	}{
		{"nil", nil, false},
		{"driver.ErrBadConn", driver.ErrBadConn, true},
		{"wrapped ErrBadConn", fmt.Errorf("run migrations: %w", driver.ErrBadConn), true},
		// The exact shape observed flaking CI: goose surfaces the driver
		// error as a plain string, so errors.Is can't see ErrBadConn.
		{"bad connection string", errors.New(`apply backend migrations: run migrations: ERROR 00001_init.sql: CREATE TABLE race_names: driver: bad connection`), true},
		{"connection refused", errors.New("dial tcp 127.0.0.1:5432: connect: connection refused"), true},
		{"connection reset", errors.New("read tcp: connection reset by peer"), true},
		{"broken pipe", errors.New("write tcp: broken pipe"), true},
		{"server closed", errors.New("pq: server closed the connection unexpectedly"), true},
		{"syntax error is not transient", errors.New(`pq: syntax error at or near "TABL"`), false},
		{"constraint violation is not transient", errors.New("pq: duplicate key value violates unique constraint"), false},
	}

	for _, tt := range tests {
		// Per-iteration copy: parallel subtests are paused until the parent
		// test function returns, so on toolchains older than Go 1.22 (where
		// the range variable is shared across iterations) every subtest
		// would otherwise observe the final table row.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tt.want, isTransientConnError(tt.err))
		})
	}
}

// TestRetryOnTransientSucceedsAfterTransientFailures checks that
// retryOnTransient keeps invoking the callback while it returns a wrapped
// driver.ErrBadConn and returns nil once the callback succeeds.
func TestRetryOnTransientSucceedsAfterTransientFailures(t *testing.T) {
	t.Parallel()

	calls := 0
	err := retryOnTransient(context.Background(), 5, time.Millisecond, func() error {
		calls++
		// Fail the first two attempts, succeed on the third.
		if calls < 3 {
			return fmt.Errorf("attempt %d: %w", calls, driver.ErrBadConn)
		}
		return nil
	})

	require.NoError(t, err)
	assert.Equal(t, 3, calls, "should retry until the transient error clears")
}

// TestRetryOnTransientStopsOnNonTransient checks that a non-transient error
// is returned after a single callback invocation and still matches the
// original error via errors.Is.
func TestRetryOnTransientStopsOnNonTransient(t *testing.T) {
	t.Parallel()

	sentinel := errors.New(`pq: syntax error at or near "TABL"`)
	calls := 0
	err := retryOnTransient(context.Background(), 5, time.Millisecond, func() error {
		calls++
		return sentinel
	})

	require.ErrorIs(t, err, sentinel)
	assert.Equal(t, 1, calls, "a deterministic SQL error must not be retried")
}

// TestRetryOnTransientExhaustsAttempts checks that a callback which always
// fails with driver.ErrBadConn is invoked exactly as many times as the
// attempt count passed in (3 here) and that the returned error still matches
// driver.ErrBadConn.
func TestRetryOnTransientExhaustsAttempts(t *testing.T) {
	t.Parallel()

	calls := 0
	err := retryOnTransient(context.Background(), 3, time.Millisecond, func() error {
		calls++
		return driver.ErrBadConn
	})

	require.ErrorIs(t, err, driver.ErrBadConn)
	assert.Equal(t, 3, calls, "must stop after the attempt budget is spent")
}

// TestRetryOnTransientRespectsContextCancellation checks that, with an
// already-cancelled context, the callback runs once and the returned error
// matches both context.Canceled and the underlying driver.ErrBadConn.
func TestRetryOnTransientRespectsContextCancellation(t *testing.T) {
	t.Parallel()

	// Cancel before calling so the wait between attempts observes a done
	// context immediately.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	calls := 0
	// The one-hour delay means the test can only finish promptly if
	// cancellation interrupts the wait.
	err := retryOnTransient(ctx, 5, time.Hour, func() error {
		calls++
		return driver.ErrBadConn
	})

	require.ErrorIs(t, err, context.Canceled)
	require.ErrorIs(t, err, driver.ErrBadConn, "the underlying transient error is preserved")
	assert.Equal(t, 1, calls, "cancellation during backoff stops further attempts")
}