Files
galaxy-game/pkg/geoip/geoip.go
T
Ilia Denisov 84eeaf5184 feat: geoip
2026-04-09 15:16:36 +03:00

165 lines
4.2 KiB
Go

package geoip
import (
"errors"
"fmt"
"net/netip"
"strings"
"sync"
"github.com/oschwald/geoip2-golang/v2"
)
// Sentinel errors reported by Resolver. The lookup methods wrap them with
// additional context, so callers should match them with errors.Is rather than
// comparing with ==.
var (
// ErrInvalidAddress reports that a caller supplied an invalid IP address.
ErrInvalidAddress = errors.New("invalid IP address")
// ErrCountryNotFound reports that the GeoIP database contains no country
// record for a valid IP address.
ErrCountryNotFound = errors.New("country not found")
// ErrClosed reports that the resolver has already released its database
// resources and can no longer serve lookups. Country also reports it when
// called on a nil *Resolver.
ErrClosed = errors.New("geoip resolver is closed")
)
// countryReader is the minimal database interface Resolver depends on: a
// country lookup for a single address plus Close to release the database.
//
// Open hands the reader returned by geoip2.Open to newResolver, which accepts
// any countryReader implementation.
type countryReader interface {
// Country returns the country record for the given address.
Country(netip.Addr) (*geoip2.Country, error)
// Close releases the reader; Resolver.Close calls it exactly once.
Close() error
}
// Resolver resolves ISO 3166-1 alpha-2 country codes from a local MaxMind
// country-capable database.
//
// Resolver is safe for concurrent lookups. Call Close when the process no
// longer needs the memory-mapped database file.
type Resolver struct {
// mu guards reader: Country holds the read lock for the duration of a
// lookup, and Close holds the write lock while it clears and closes reader.
mu sync.RWMutex
// reader is the open database. Close sets it to nil, after which Country
// fails with ErrClosed.
reader countryReader
}
// Open returns a Resolver that serves lookups from the MaxMind database file
// located at databasePath.
//
// The file must be a local country-capable .mmdb database such as GeoLite2
// Country. Leading and trailing whitespace in databasePath is ignored; a path
// that is empty after trimming is rejected without touching the filesystem.
// Any failure reported by geoip2.Open is wrapped together with the trimmed
// path.
func Open(databasePath string) (*Resolver, error) {
	trimmed := strings.TrimSpace(databasePath)
	if len(trimmed) == 0 {
		return nil, errors.New("open geoip database: database path must not be empty")
	}

	db, openErr := geoip2.Open(trimmed)
	if openErr != nil {
		return nil, fmt.Errorf("open geoip database %q: %w", trimmed, openErr)
	}

	resolver := newResolver(db)
	return resolver, nil
}
// NewGeoIP opens the database at databasePath exactly as Open does and
// returns Open's results unchanged. It is kept for existing callers.
//
// Deprecated: use Open for new code.
func NewGeoIP(databasePath string) (*Resolver, error) {
	resolver, err := Open(databasePath)
	return resolver, err
}
// newResolver wraps reader in a Resolver. The Resolver takes over reader and
// closes it when Resolver.Close is called.
func newResolver(reader countryReader) *Resolver {
	resolver := new(Resolver)
	resolver.reader = reader
	return resolver
}
// Country looks addr up in the database and returns its ISO 3166-1 alpha-2
// country code in upper case.
//
// The returned errors are wrapped; match them with errors.Is:
//   - ErrInvalidAddress when addr.IsValid reports false;
//   - ErrClosed when r is nil or Close has already been called (Close marks
//     the resolver closed even if releasing the database fails);
//   - ErrCountryNotFound when the record carries no data or a blank ISO code.
//
// A failure from the underlying reader is wrapped as-is. A nil record, or an
// ISO code rejected by normalizeCountryCode, is reported as a plain error.
func (r *Resolver) Country(addr netip.Addr) (string, error) {
	if !addr.IsValid() {
		return "", fmt.Errorf("lookup country: %w", ErrInvalidAddress)
	}

	// wrap prefixes cause with the address being resolved.
	wrap := func(cause error) error {
		return fmt.Errorf("lookup country for %s: %w", addr, cause)
	}

	if r == nil {
		return "", wrap(ErrClosed)
	}

	// Hold the read lock for the whole lookup so Close cannot release the
	// database while it is being read.
	r.mu.RLock()
	defer r.mu.RUnlock()

	db := r.reader
	if db == nil {
		return "", wrap(ErrClosed)
	}

	rec, lookupErr := db.Country(addr)
	switch {
	case lookupErr != nil:
		return "", wrap(lookupErr)
	case rec == nil:
		return "", fmt.Errorf("lookup country for %s: nil country record", addr)
	case !rec.HasData(), strings.TrimSpace(rec.Country.ISOCode) == "":
		return "", wrap(ErrCountryNotFound)
	}

	iso, normErr := normalizeCountryCode(rec.Country.ISOCode)
	if normErr != nil {
		return "", wrap(normErr)
	}
	return iso, nil
}
// CountryString parses raw as an IP address and resolves it to an uppercase
// ISO 3166-1 alpha-2 country code via Country.
//
// Whitespace around raw is ignored when parsing. If parsing fails, the
// returned error quotes raw as given and wraps both ErrInvalidAddress and the
// parser's own error, so errors.Is(err, ErrInvalidAddress) holds.
func (r *Resolver) CountryString(raw string) (string, error) {
	addr, parseErr := netip.ParseAddr(strings.TrimSpace(raw))
	if parseErr != nil {
		cause := errors.Join(ErrInvalidAddress, parseErr)
		return "", fmt.Errorf("parse IP address %q: %w", raw, cause)
	}
	return r.Country(addr)
}
// Close releases the database held by the resolver.
//
// Calling Close on a nil Resolver, or calling it again after the first call,
// does nothing and returns nil. An error from the underlying reader's Close is
// returned wrapped; the resolver is marked closed in that case as well.
func (r *Resolver) Close() error {
	if r == nil {
		return nil
	}

	// The write lock waits for in-flight lookups to finish.
	r.mu.Lock()
	defer r.mu.Unlock()

	db := r.reader
	if db == nil {
		return nil
	}

	// Drop the reference before closing so the reader is never closed twice.
	r.reader = nil

	closeErr := db.Close()
	if closeErr == nil {
		return nil
	}
	return fmt.Errorf("close geoip resolver: %w", closeErr)
}
// normalizeCountryCode converts raw into canonical ISO 3166-1 alpha-2 form.
//
// Surrounding whitespace is trimmed and ASCII letters are upper-cased. The
// trimmed value must consist of exactly two ASCII letters; otherwise an error
// quoting raw as given is returned.
//
// Case folding is restricted to ASCII on purpose. strings.ToUpper applies
// Unicode case mapping, which turns U+0131 (dotless i) into 'I' and U+017F
// (long s) into 'S', so non-ASCII input such as "ıs" would otherwise be
// accepted and returned as "IS".
func normalizeCountryCode(raw string) (string, error) {
	trimmed := strings.TrimSpace(raw)
	if len(trimmed) != 2 {
		return "", fmt.Errorf("invalid ISO 3166-1 alpha-2 code %q", raw)
	}

	var code [2]byte
	for i := range code {
		c := trimmed[i]
		switch {
		case c >= 'A' && c <= 'Z':
			// Already canonical.
		case c >= 'a' && c <= 'z':
			c -= 'a' - 'A'
		default:
			return "", fmt.Errorf("invalid ISO 3166-1 alpha-2 code %q", raw)
		}
		code[i] = c
	}
	return string(code[:]), nil
}