package geoip import ( "errors" "net/netip" "os" "testing" "github.com/oschwald/geoip2-golang/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const countryFixturePath = "test-data/test-data/GeoIP2-Country-Test.mmdb" type fakeReader struct { countryFunc func(netip.Addr) (*geoip2.Country, error) closeFunc func() error } func (f fakeReader) Country(addr netip.Addr) (*geoip2.Country, error) { if f.countryFunc == nil { return nil, nil } return f.countryFunc(addr) } func (f fakeReader) Close() error { if f.closeFunc == nil { return nil } return f.closeFunc() } func TestOpenRejectsEmptyPath(t *testing.T) { t.Parallel() tests := []struct { name string path string }{ {name: "empty", path: ""}, {name: "whitespace", path: " \t\n"}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() resolver, err := Open(tt.path) require.Error(t, err) assert.Nil(t, resolver) assert.Contains(t, err.Error(), "database path must not be empty") }) } } func TestOpenMissingFile(t *testing.T) { t.Parallel() resolver, err := Open("test-data/test-data/does-not-exist.mmdb") require.Error(t, err) assert.Nil(t, resolver) assert.Contains(t, err.Error(), `open geoip database "test-data/test-data/does-not-exist.mmdb"`) } func TestResolverFixtureLookups(t *testing.T) { t.Parallel() tests := []struct { name string lookup func(*Resolver) (string, error) want string }{ { name: "addr lookup", lookup: func(resolver *Resolver) (string, error) { return resolver.Country(netip.MustParseAddr("81.2.69.160")) }, want: "GB", }, { name: "string lookup", lookup: func(resolver *Resolver) (string, error) { return resolver.CountryString(" 81.2.69.160 ") }, want: "GB", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() resolver := openFixtureResolver(t) got, err := tt.lookup(resolver) require.NoError(t, err) assert.Equal(t, tt.want, got) }) } } func TestNewGeoIPLegacyAlias(t *testing.T) { t.Parallel() requireFixtureDatabase(t) resolver, err := NewGeoIP(countryFixturePath) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, resolver.Close()) }) got, err := resolver.CountryString("81.2.69.160") require.NoError(t, err) assert.Equal(t, "GB", got) } func TestResolverFixtureReturnsNotFoundForPrivateAddress(t *testing.T) { t.Parallel() resolver := openFixtureResolver(t) _, err := resolver.Country(netip.MustParseAddr("192.168.1.1")) require.Error(t, err) assert.ErrorIs(t, err, ErrCountryNotFound) } func TestResolverCloseIsIdempotent(t *testing.T) { t.Parallel() resolver := openFixtureResolver(t) require.NoError(t, resolver.Close()) require.NoError(t, resolver.Close()) } func TestResolverCloseIsNilSafe(t *testing.T) { t.Parallel() var resolver *Resolver require.NoError(t, resolver.Close()) } func TestResolverCountryAfterClose(t *testing.T) { t.Parallel() resolver := openFixtureResolver(t) require.NoError(t, resolver.Close()) _, err := resolver.Country(netip.MustParseAddr("81.2.69.160")) require.Error(t, err) assert.ErrorIs(t, err, ErrClosed) } func TestResolverCountryRejectsZeroAddress(t *testing.T) { t.Parallel() resolver := newResolver(fakeReader{}) _, err := resolver.Country(netip.Addr{}) require.Error(t, err) assert.ErrorIs(t, err, ErrInvalidAddress) } func TestResolverCountryStringRejectsInvalidAddress(t *testing.T) { t.Parallel() tests := []string{ "", "not-an-ip", "999.0.0.1", } for _, raw := range tests { raw := raw t.Run(raw, func(t *testing.T) { t.Parallel() resolver := newResolver(fakeReader{}) _, err := resolver.CountryString(raw) require.Error(t, err) assert.ErrorIs(t, err, ErrInvalidAddress) }) } } func TestResolverCountryWrapsReaderError(t *testing.T) { t.Parallel() lookupErr := errors.New("lookup failed") resolver := newResolver(fakeReader{ countryFunc: func(netip.Addr) (*geoip2.Country, error) { return nil, lookupErr }, }) _, err := resolver.Country(netip.MustParseAddr("203.0.113.10")) require.Error(t, err) assert.ErrorIs(t, err, lookupErr) assert.Contains(t, err.Error(), "lookup country for 203.0.113.10") } func TestResolverCountryReturnsNotFoundWhenISOCodeIsEmpty(t *testing.T) { t.Parallel() resolver := newResolver(fakeReader{ countryFunc: func(netip.Addr) (*geoip2.Country, error) { return &geoip2.Country{ Continent: geoip2.Continent{Code: "EU"}, }, nil }, }) _, err := resolver.Country(netip.MustParseAddr("203.0.113.10")) require.Error(t, err) assert.ErrorIs(t, err, ErrCountryNotFound) } func TestResolverCountryNormalizesCountryCode(t *testing.T) { t.Parallel() resolver := newResolver(fakeReader{ countryFunc: func(netip.Addr) (*geoip2.Country, error) { return &geoip2.Country{ Country: geoip2.CountryRecord{ISOCode: "gb"}, }, nil }, }) got, err := resolver.Country(netip.MustParseAddr("203.0.113.10")) require.NoError(t, err) assert.Equal(t, "GB", got) } func TestResolverCountryRejectsInvalidISOCode(t *testing.T) { t.Parallel() tests := []struct { name string code string }{ {name: "digit", code: "G1"}, {name: "too long", code: "USA"}, {name: "non ascii", code: "éé"}, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() resolver := newResolver(fakeReader{ countryFunc: func(netip.Addr) (*geoip2.Country, error) { return &geoip2.Country{ Country: geoip2.CountryRecord{ISOCode: tt.code}, }, nil }, }) _, err := resolver.Country(netip.MustParseAddr("203.0.113.10")) require.Error(t, err) assert.NotErrorIs(t, err, ErrCountryNotFound) assert.Contains(t, err.Error(), "invalid ISO 3166-1 alpha-2 code") }) } } func openFixtureResolver(t *testing.T) *Resolver { t.Helper() requireFixtureDatabase(t) resolver, err := Open(countryFixturePath) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, resolver.Close()) }) return resolver } func requireFixtureDatabase(t *testing.T) { t.Helper() _, err := os.Stat(countryFixturePath) require.NoErrorf( t, err, "fixture database %q is unavailable; run `git submodule update --init --recursive` before running tests", countryFixturePath, ) }