301 lines
6.2 KiB
Go
301 lines
6.2 KiB
Go
package geoip
|
|
|
|
import (
|
|
"errors"
|
|
"net/netip"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/oschwald/geoip2-golang/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// countryFixturePath is the GeoIP2 Country test database used by the
// fixture-backed tests. The path is relative, so it resolves against the
// package directory that `go test` runs in. The file is expected to come from
// a git submodule; requireFixtureDatabase fails the test with setup
// instructions when it is absent.
const countryFixturePath = "test-data/test-data/GeoIP2-Country-Test.mmdb"
|
|
|
|
type fakeReader struct {
|
|
countryFunc func(netip.Addr) (*geoip2.Country, error)
|
|
closeFunc func() error
|
|
}
|
|
|
|
func (f fakeReader) Country(addr netip.Addr) (*geoip2.Country, error) {
|
|
if f.countryFunc == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
return f.countryFunc(addr)
|
|
}
|
|
|
|
func (f fakeReader) Close() error {
|
|
if f.closeFunc == nil {
|
|
return nil
|
|
}
|
|
|
|
return f.closeFunc()
|
|
}
|
|
|
|
func TestOpenRejectsEmptyPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
path string
|
|
}{
|
|
{name: "empty", path: ""},
|
|
{name: "whitespace", path: " \t\n"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver, err := Open(tt.path)
|
|
require.Error(t, err)
|
|
assert.Nil(t, resolver)
|
|
assert.Contains(t, err.Error(), "database path must not be empty")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOpenMissingFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver, err := Open("test-data/test-data/does-not-exist.mmdb")
|
|
require.Error(t, err)
|
|
assert.Nil(t, resolver)
|
|
assert.Contains(t, err.Error(), `open geoip database "test-data/test-data/does-not-exist.mmdb"`)
|
|
}
|
|
|
|
func TestResolverFixtureLookups(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
lookup func(*Resolver) (string, error)
|
|
want string
|
|
}{
|
|
{
|
|
name: "addr lookup",
|
|
lookup: func(resolver *Resolver) (string, error) {
|
|
return resolver.Country(netip.MustParseAddr("81.2.69.160"))
|
|
},
|
|
want: "GB",
|
|
},
|
|
{
|
|
name: "string lookup",
|
|
lookup: func(resolver *Resolver) (string, error) {
|
|
return resolver.CountryString(" 81.2.69.160 ")
|
|
},
|
|
want: "GB",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := openFixtureResolver(t)
|
|
|
|
got, err := tt.lookup(resolver)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewGeoIPLegacyAlias(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
requireFixtureDatabase(t)
|
|
|
|
resolver, err := NewGeoIP(countryFixturePath)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
require.NoError(t, resolver.Close())
|
|
})
|
|
|
|
got, err := resolver.CountryString("81.2.69.160")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "GB", got)
|
|
}
|
|
|
|
func TestResolverFixtureReturnsNotFoundForPrivateAddress(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := openFixtureResolver(t)
|
|
|
|
_, err := resolver.Country(netip.MustParseAddr("192.168.1.1"))
|
|
require.Error(t, err)
|
|
assert.ErrorIs(t, err, ErrCountryNotFound)
|
|
}
|
|
|
|
func TestResolverCloseIsIdempotent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := openFixtureResolver(t)
|
|
|
|
require.NoError(t, resolver.Close())
|
|
require.NoError(t, resolver.Close())
|
|
}
|
|
|
|
func TestResolverCloseIsNilSafe(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var resolver *Resolver
|
|
require.NoError(t, resolver.Close())
|
|
}
|
|
|
|
func TestResolverCountryAfterClose(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := openFixtureResolver(t)
|
|
require.NoError(t, resolver.Close())
|
|
|
|
_, err := resolver.Country(netip.MustParseAddr("81.2.69.160"))
|
|
require.Error(t, err)
|
|
assert.ErrorIs(t, err, ErrClosed)
|
|
}
|
|
|
|
func TestResolverCountryRejectsZeroAddress(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := newResolver(fakeReader{})
|
|
|
|
_, err := resolver.Country(netip.Addr{})
|
|
require.Error(t, err)
|
|
assert.ErrorIs(t, err, ErrInvalidAddress)
|
|
}
|
|
|
|
func TestResolverCountryStringRejectsInvalidAddress(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []string{
|
|
"",
|
|
"not-an-ip",
|
|
"999.0.0.1",
|
|
}
|
|
|
|
for _, raw := range tests {
|
|
raw := raw
|
|
t.Run(raw, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := newResolver(fakeReader{})
|
|
|
|
_, err := resolver.CountryString(raw)
|
|
require.Error(t, err)
|
|
assert.ErrorIs(t, err, ErrInvalidAddress)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolverCountryWrapsReaderError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
lookupErr := errors.New("lookup failed")
|
|
resolver := newResolver(fakeReader{
|
|
countryFunc: func(netip.Addr) (*geoip2.Country, error) {
|
|
return nil, lookupErr
|
|
},
|
|
})
|
|
|
|
_, err := resolver.Country(netip.MustParseAddr("203.0.113.10"))
|
|
require.Error(t, err)
|
|
assert.ErrorIs(t, err, lookupErr)
|
|
assert.Contains(t, err.Error(), "lookup country for 203.0.113.10")
|
|
}
|
|
|
|
func TestResolverCountryReturnsNotFoundWhenISOCodeIsEmpty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := newResolver(fakeReader{
|
|
countryFunc: func(netip.Addr) (*geoip2.Country, error) {
|
|
return &geoip2.Country{
|
|
Continent: geoip2.Continent{Code: "EU"},
|
|
}, nil
|
|
},
|
|
})
|
|
|
|
_, err := resolver.Country(netip.MustParseAddr("203.0.113.10"))
|
|
require.Error(t, err)
|
|
assert.ErrorIs(t, err, ErrCountryNotFound)
|
|
}
|
|
|
|
func TestResolverCountryNormalizesCountryCode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := newResolver(fakeReader{
|
|
countryFunc: func(netip.Addr) (*geoip2.Country, error) {
|
|
return &geoip2.Country{
|
|
Country: geoip2.CountryRecord{ISOCode: "gb"},
|
|
}, nil
|
|
},
|
|
})
|
|
|
|
got, err := resolver.Country(netip.MustParseAddr("203.0.113.10"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "GB", got)
|
|
}
|
|
|
|
func TestResolverCountryRejectsInvalidISOCode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
code string
|
|
}{
|
|
{name: "digit", code: "G1"},
|
|
{name: "too long", code: "USA"},
|
|
{name: "non ascii", code: "éé"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolver := newResolver(fakeReader{
|
|
countryFunc: func(netip.Addr) (*geoip2.Country, error) {
|
|
return &geoip2.Country{
|
|
Country: geoip2.CountryRecord{ISOCode: tt.code},
|
|
}, nil
|
|
},
|
|
})
|
|
|
|
_, err := resolver.Country(netip.MustParseAddr("203.0.113.10"))
|
|
require.Error(t, err)
|
|
assert.NotErrorIs(t, err, ErrCountryNotFound)
|
|
assert.Contains(t, err.Error(), "invalid ISO 3166-1 alpha-2 code")
|
|
})
|
|
}
|
|
}
|
|
|
|
func openFixtureResolver(t *testing.T) *Resolver {
|
|
t.Helper()
|
|
|
|
requireFixtureDatabase(t)
|
|
|
|
resolver, err := Open(countryFixturePath)
|
|
require.NoError(t, err)
|
|
|
|
t.Cleanup(func() {
|
|
require.NoError(t, resolver.Close())
|
|
})
|
|
|
|
return resolver
|
|
}
|
|
|
|
func requireFixtureDatabase(t *testing.T) {
|
|
t.Helper()
|
|
|
|
_, err := os.Stat(countryFixturePath)
|
|
require.NoErrorf(
|
|
t,
|
|
err,
|
|
"fixture database %q is unavailable; run `git submodule update --init --recursive` before running tests",
|
|
countryFixturePath,
|
|
)
|
|
}
|