package router_test import ( "encoding/json" "fmt" "net/http" "net/http/httptest" "strings" "sync" "sync/atomic" "testing" "galaxy/model/rest" "github.com/gin-gonic/gin" "github.com/iliadenisov/galaxy/server/internal/router" "github.com/stretchr/testify/assert" ) func TestLimitConnections(t *testing.T) { r := limitTestingRouter() wg := sync.WaitGroup{} lock := sync.WaitGroup{} lock.Add(1) for range 1000 { wg.Go(func() { w := httptest.NewRecorder() lock.Wait() req, _ := http.NewRequest("GET", "/limited", nil) r.ServeHTTP(w, req) assert.Equal(t, 200, w.Code, w.Body) }) } lock.Done() wg.Wait() } func asBody(body any) *strings.Reader { commandJson, _ := json.Marshal(body) return strings.NewReader(string(commandJson)) } func limitTestingRouter() *gin.Engine { gin.SetMode(gin.ReleaseMode) r := gin.New() r.Use(gin.Recovery()) counter := atomic.Int32{} r.GET("/limited", // limiting all ingoing connections router.LimitMiddleware(1), // storing counter value and testing increment after executing Next handlers func(c *gin.Context) { expected := counter.Load() + 1 c.Next() current := counter.Load() if current != expected { c.String(http.StatusConflict, "expected: %d, got: %d", expected, current) } }, // increment counter func(c *gin.Context) { counter.Add(1) c.Status(http.StatusOK) }) return r } func generateInitRequest(races int) rest.Init { request := rest.Init{ Races: make([]rest.Race, races), } for i := range request.Races { request.Races[i] = rest.Race{Name: raceName(i)} } return request } func raceName(i int) string { return fmt.Sprintf("Race_%02d", i) }