package ratewatch import ( "context" "fmt" "testing" "time" "github.com/google/uuid" ) // fakeFlagger records flag calls and reports them as newly set. type fakeFlagger struct { calls []uuid.UUID } func (f *fakeFlagger) FlagHighRate(_ context.Context, id uuid.UUID, _ time.Time) (bool, error) { f.calls = append(f.calls, id) return true, nil } // watchAt returns a Watch with a controllable clock. func watchAt(cfg Config, flagger Flagger, at *time.Time) *Watch { w := New(cfg, flagger, nil) w.now = func() time.Time { return *at } return w } // TestIngestAggregatesAndRecent verifies episodes accumulate per (class, key), // invalid entries are skipped, and Recent orders by last rejection. func TestIngestAggregatesAndRecent(t *testing.T) { now := time.Date(2026, 6, 10, 12, 0, 0, 0, time.UTC) w := watchAt(DefaultConfig(), nil, &now) ctx := context.Background() w.Ingest(ctx, []Entry{ {Class: "public", Key: "10.0.0.1", Rejected: 3}, {Class: "user", Key: "u-1", Rejected: 5}, {Class: "", Key: "x", Rejected: 1}, {Class: "user", Key: "", Rejected: 1}, {Class: "user", Key: "u-1", Rejected: 0}, }) now = now.Add(30 * time.Second) w.Ingest(ctx, []Entry{{Class: "public", Key: "10.0.0.1", Rejected: 4}}) got := w.Recent() if len(got) != 2 { t.Fatalf("Recent returned %d episodes, want 2", len(got)) } if got[0].Class != "public" || got[0].Key != "10.0.0.1" || got[0].Rejected != 7 { t.Errorf("first episode = %+v, want public/10.0.0.1 rejected=7", got[0]) } if !got[0].LastSeen.After(got[0].FirstSeen) { t.Errorf("episode span = [%v, %v], want a positive span", got[0].FirstSeen, got[0].LastSeen) } if got[1].Class != "user" || got[1].Rejected != 5 { t.Errorf("second episode = %+v, want user rejected=5", got[1]) } } // TestAutoFlagThreshold verifies the flag fires only for a user-class series // crossing the threshold within the window, with a parseable account id. func TestAutoFlagThreshold(t *testing.T) { now := time.Date(2026, 6, 10, 12, 0, 0, 0, time.UTC) flagged := &fakeFlagger{} id := uuid.New() w := watchAt(Config{FlagThreshold: 100, FlagWindow: 10 * time.Minute}, flagged, &now) ctx := context.Background() w.Ingest(ctx, []Entry{ {Class: "user", Key: id.String(), Rejected: 99}, {Class: "public", Key: "10.0.0.1", Rejected: 1000}, {Class: "user", Key: "not-a-uuid", Rejected: 1000}, }) if len(flagged.calls) != 0 { t.Fatalf("flagged %v below the threshold", flagged.calls) } now = now.Add(30 * time.Second) w.Ingest(ctx, []Entry{{Class: "user", Key: id.String(), Rejected: 1}}) if len(flagged.calls) != 1 || flagged.calls[0] != id { t.Fatalf("flag calls = %v, want exactly [%s]", flagged.calls, id) } } // TestAutoFlagWindowExpiry verifies rejections age out of the rolling window. func TestAutoFlagWindowExpiry(t *testing.T) { now := time.Date(2026, 6, 10, 12, 0, 0, 0, time.UTC) flagged := &fakeFlagger{} id := uuid.New() w := watchAt(Config{FlagThreshold: 100, FlagWindow: 10 * time.Minute}, flagged, &now) ctx := context.Background() w.Ingest(ctx, []Entry{{Class: "user", Key: id.String(), Rejected: 60}}) now = now.Add(11 * time.Minute) w.Ingest(ctx, []Entry{{Class: "user", Key: id.String(), Rejected: 60}}) if len(flagged.calls) != 0 { t.Fatalf("flagged %v across an expired window", flagged.calls) } now = now.Add(time.Minute) w.Ingest(ctx, []Entry{{Class: "user", Key: id.String(), Rejected: 50}}) if len(flagged.calls) != 1 { t.Fatalf("flag calls = %v, want one in-window crossing", flagged.calls) } } // TestSeriesBound verifies the episode map stays bounded by evicting the // least-recently-throttled series. func TestSeriesBound(t *testing.T) { now := time.Date(2026, 6, 10, 12, 0, 0, 0, time.UTC) w := watchAt(DefaultConfig(), nil, &now) ctx := context.Background() for i := range maxSeries + 10 { now = now.Add(time.Second) w.Ingest(ctx, []Entry{{Class: "public", Key: fmt.Sprintf("10.0.%d.%d", i/256, i%256), Rejected: 1}}) } got := w.Recent() if len(got) != maxSeries { t.Fatalf("retained %d series, want %d", len(got), maxSeries) } for _, ep := range got { if ep.Key == "10.0.0.0" { t.Fatal("the least-recently-throttled series survived the bound") } } } // TestConfigValidate covers the tuning guards. func TestConfigValidate(t *testing.T) { if err := DefaultConfig().Validate(); err != nil { t.Errorf("default config invalid: %v", err) } if err := (Config{FlagThreshold: 0, FlagWindow: time.Minute}).Validate(); err == nil { t.Error("zero threshold passed validation") } if err := (Config{FlagThreshold: 1, FlagWindow: 0}).Validate(); err == nil { t.Error("zero window passed validation") } }