165 lines
4.2 KiB
Go
165 lines
4.2 KiB
Go
package geoip
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/netip"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/oschwald/geoip2-golang/v2"
|
|
)
|
|
|
|
var (
	// ErrInvalidAddress is returned when the caller passes an IP address
	// that is missing or cannot be parsed.
	ErrInvalidAddress = errors.New("invalid IP address")

	// ErrCountryNotFound is returned when the address is well-formed but the
	// GeoIP database has no country record for it.
	ErrCountryNotFound = errors.New("country not found")

	// ErrClosed is returned by lookups once the resolver has given up its
	// database, i.e. after Close.
	ErrClosed = errors.New("geoip resolver is closed")
)
|
|
|
|
type countryReader interface {
|
|
Country(netip.Addr) (*geoip2.Country, error)
|
|
Close() error
|
|
}
|
|
|
|
// Resolver resolves ISO 3166-1 alpha-2 country codes from a local MaxMind
|
|
// country-capable database.
|
|
//
|
|
// Resolver is safe for concurrent lookups. Call Close when the process no
|
|
// longer needs the memory-mapped database file.
|
|
type Resolver struct {
|
|
mu sync.RWMutex
|
|
reader countryReader
|
|
}
|
|
|
|
// Open constructs a Resolver backed by the MaxMind database file at
|
|
// databasePath.
|
|
//
|
|
// databasePath must point to a local country-capable .mmdb file such as
|
|
// GeoLite2 Country. Open trims surrounding whitespace from databasePath before
|
|
// opening the file.
|
|
func Open(databasePath string) (*Resolver, error) {
|
|
path := strings.TrimSpace(databasePath)
|
|
if path == "" {
|
|
return nil, errors.New("open geoip database: database path must not be empty")
|
|
}
|
|
|
|
reader, err := geoip2.Open(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open geoip database %q: %w", path, err)
|
|
}
|
|
|
|
return newResolver(reader), nil
|
|
}
|
|
|
|
// NewGeoIP is a legacy alias for Open.
|
|
//
|
|
// Deprecated: use Open for new code.
|
|
func NewGeoIP(databasePath string) (*Resolver, error) {
|
|
return Open(databasePath)
|
|
}
|
|
|
|
func newResolver(reader countryReader) *Resolver {
|
|
return &Resolver{reader: reader}
|
|
}
|
|
|
|
// Country resolves addr to an uppercase ISO 3166-1 alpha-2 country code.
|
|
//
|
|
// Country returns ErrInvalidAddress when addr is not valid, ErrCountryNotFound
|
|
// when the database contains no country for addr, and ErrClosed after Close
|
|
// has been called successfully.
|
|
func (r *Resolver) Country(addr netip.Addr) (string, error) {
|
|
if !addr.IsValid() {
|
|
return "", fmt.Errorf("lookup country: %w", ErrInvalidAddress)
|
|
}
|
|
if r == nil {
|
|
return "", fmt.Errorf("lookup country for %s: %w", addr, ErrClosed)
|
|
}
|
|
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
if r.reader == nil {
|
|
return "", fmt.Errorf("lookup country for %s: %w", addr, ErrClosed)
|
|
}
|
|
|
|
record, err := r.reader.Country(addr)
|
|
if err != nil {
|
|
return "", fmt.Errorf("lookup country for %s: %w", addr, err)
|
|
}
|
|
if record == nil {
|
|
return "", fmt.Errorf("lookup country for %s: nil country record", addr)
|
|
}
|
|
if !record.HasData() || strings.TrimSpace(record.Country.ISOCode) == "" {
|
|
return "", fmt.Errorf("lookup country for %s: %w", addr, ErrCountryNotFound)
|
|
}
|
|
|
|
code, err := normalizeCountryCode(record.Country.ISOCode)
|
|
if err != nil {
|
|
return "", fmt.Errorf("lookup country for %s: %w", addr, err)
|
|
}
|
|
|
|
return code, nil
|
|
}
|
|
|
|
// CountryString resolves raw to an uppercase ISO 3166-1 alpha-2 country code.
|
|
//
|
|
// CountryString trims surrounding whitespace from raw before parsing it as an
|
|
// IP address and then delegates to Country.
|
|
func (r *Resolver) CountryString(raw string) (string, error) {
|
|
trimmed := strings.TrimSpace(raw)
|
|
|
|
addr, err := netip.ParseAddr(trimmed)
|
|
if err != nil {
|
|
return "", fmt.Errorf("parse IP address %q: %w", raw, errors.Join(ErrInvalidAddress, err))
|
|
}
|
|
|
|
return r.Country(addr)
|
|
}
|
|
|
|
// Close releases the underlying database resources.
|
|
//
|
|
// Close is idempotent and nil-safe.
|
|
func (r *Resolver) Close() error {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
if r.reader == nil {
|
|
return nil
|
|
}
|
|
|
|
reader := r.reader
|
|
r.reader = nil
|
|
|
|
if err := reader.Close(); err != nil {
|
|
return fmt.Errorf("close geoip resolver: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// normalizeCountryCode validates raw as an ISO 3166-1 alpha-2 code and
// returns it in uppercase.
//
// Surrounding whitespace is ignored. After trimming, the value must be
// exactly two ASCII letters, in either case. Anything else is rejected with
// an error that quotes the original raw input. This includes non-ASCII
// letters whose Unicode uppercase form is ASCII, such as the dotless "ı".
func normalizeCountryCode(raw string) (string, error) {
	code := strings.TrimSpace(raw)

	// Validate the raw bytes before changing case. strings.ToUpper is
	// Unicode-aware and would turn "ıt" (3 bytes) into "IT" (2 bytes),
	// letting a non-ASCII value slip past the checks.
	if len(code) != 2 {
		return "", fmt.Errorf("invalid ISO 3166-1 alpha-2 code %q", raw)
	}
	for i := 0; i < len(code); i++ {
		c := code[i]
		isUpper := 'A' <= c && c <= 'Z'
		isLower := 'a' <= c && c <= 'z'
		if !isUpper && !isLower {
			return "", fmt.Errorf("invalid ISO 3166-1 alpha-2 code %q", raw)
		}
	}

	// Only ASCII letters remain, so ToUpper is a plain ASCII fold here.
	return strings.ToUpper(code), nil
}
|