package router_test

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"sync/atomic"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/iliadenisov/galaxy/internal/model/rest"
	"github.com/iliadenisov/galaxy/internal/router"
	"github.com/stretchr/testify/assert"
)

// TestRouter sends PUT /api/v1/command requests through the engine returned
// by router.SetupRouter and checks the response status: one payload that is
// expected to be accepted, followed by payloads expected to be rejected with
// 400 Bad Request.
//
// All cases share a single router instance and run sequentially, in the order
// listed.
func TestRouter(t *testing.T) {
	r := router.SetupRouter()

	cases := []struct {
		name string
		cmd  rest.Command
		want int
	}{
		{
			name: "single vote command",
			cmd: rest.Command{
				Race: "SomeRace",
				Vote: &rest.CommandVote{Recipient: "AnotherRace"},
			},
			want: http.StatusOK,
		},
		{
			// error: notblank validator
			name: "empty race",
			cmd: rest.Command{
				Race: "",
				Vote: &rest.CommandVote{Recipient: "AnotherRace"},
			},
			want: http.StatusBadRequest,
		},
		{
			// error: notblank validator
			name: "whitespace-only race",
			cmd: rest.Command{
				Race: " ",
				Vote: &rest.CommandVote{Recipient: "AnotherRace"},
			},
			want: http.StatusBadRequest,
		},
		{
			// error: no commands
			name: "no commands",
			cmd:  rest.Command{Race: "SomeRace"},
			want: http.StatusBadRequest,
		},
		{
			// error: more than one command
			name: "more than one command",
			cmd: rest.Command{
				Race:         "SomeRace",
				Vote:         &rest.CommandVote{Recipient: "AnotherRace"},
				DeclarePeace: &rest.CommandDeclarePeace{Opponent: "OpponentRace"},
			},
			want: http.StatusBadRequest,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			req, err := http.NewRequest(http.MethodPut, "/api/v1/command", cmdBody(tc.cmd))
			if err != nil {
				t.Fatalf("building request: %v", err)
			}
			r.ServeHTTP(w, req)
			// The response body is attached as the failure message.
			assert.Equal(t, tc.want, w.Code, w.Body.String())
		})
	}
}

// TestLimitConnections fires 1000 concurrent GET /limited requests at the
// engine built by limitTestingRouter and expects every one of them to answer
// 200 OK. The route itself reports 409 Conflict when its counter check fails,
// so any non-200 response fails the test.
func TestLimitConnections(t *testing.T) {
	r := limitTestingRouter()

	const requests = 1000

	// start is a gate: every goroutine blocks on it until all of them have
	// been launched, so that the requests hit the router at roughly the same
	// time.
	start := make(chan struct{})

	var wg sync.WaitGroup
	for range requests {
		wg.Go(func() {
			w := httptest.NewRecorder()
			<-start
			req, err := http.NewRequest(http.MethodGet, "/limited", nil)
			if err != nil {
				// t.Fatal must only be called from the test goroutine, so
				// record the failure and bail out of this worker instead.
				t.Errorf("building request: %v", err)
				return
			}
			r.ServeHTTP(w, req)
			assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
		})
	}

	close(start)
	wg.Wait()
}

// cmdBody JSON-encodes cmd and returns a reader over the encoded text,
// suitable for use as an HTTP request body.
//
// It panics if cmd cannot be marshaled. Silently continuing would send an
// empty body, which the "expect 400" cases could mistake for a pass.
func cmdBody(cmd rest.Command) *strings.Reader {
	payload, err := json.Marshal(cmd)
	if err != nil {
		panic("marshaling command: " + err.Error())
	}
	return strings.NewReader(string(payload))
}

// limitTestingRouter builds a gin engine (release mode, recovery middleware)
// with a single GET /limited route used to exercise router.LimitMiddleware.
//
// The route chains three handlers:
//  1. router.LimitMiddleware(1) — presumably admits at most one request at a
//     time; confirm against the middleware's implementation.
//  2. a checker that snapshots the shared counter, runs the remaining
//     handlers and replies 409 Conflict unless the counter grew by exactly one
//     in the meantime;
//  3. a handler that increments the counter and sets status 200.
//
// Note that gin.SetMode changes process-wide gin state.
func limitTestingRouter() *gin.Engine {
	gin.SetMode(gin.ReleaseMode)

	engine := gin.New()
	engine.Use(gin.Recovery())

	// counter is shared by all requests served by this engine.
	var counter atomic.Int32

	// verify stores the counter value and tests the increment after the
	// remaining handlers have executed.
	verify := func(c *gin.Context) {
		expected := counter.Load() + 1
		c.Next()
		if current := counter.Load(); current != expected {
			c.String(http.StatusConflict, "expected: %d, got: %d", expected, current)
		}
	}

	// increment bumps the counter and marks the request as successful.
	increment := func(c *gin.Context) {
		counter.Add(1)
		c.Status(http.StatusOK)
	}

	engine.GET("/limited",
		// limiting all incoming connections
		router.LimitMiddleware(1),
		verify,
		increment,
	)

	return engine
}