diff --git a/tw/twdataloader/helpers.go b/tw/twdataloader/helpers.go index adf1971..bcfb10f 100644 --- a/tw/twdataloader/helpers.go +++ b/tw/twdataloader/helpers.go @@ -25,7 +25,14 @@ func uncompressAndReadCsvLines(r io.Reader) ([][]string, error) { } type handlers struct { - getServers http.HandlerFunc + getServers http.HandlerFunc + killAll http.HandlerFunc + killAtt http.HandlerFunc + killDef http.HandlerFunc + killSup http.HandlerFunc + killAllTribe http.HandlerFunc + killAttTribe http.HandlerFunc + killDefTribe http.HandlerFunc } func (h *handlers) init() { @@ -35,6 +42,27 @@ func (h *handlers) init() { if h.getServers == nil { h.getServers = noop } + if h.killAll == nil { + h.killAll = noop + } + if h.killAtt == nil { + h.killAtt = noop + } + if h.killDef == nil { + h.killDef = noop + } + if h.killSup == nil { + h.killSup = noop + } + if h.killAllTribe == nil { + h.killAllTribe = noop + } + if h.killAttTribe == nil { + h.killAttTribe = noop + } + if h.killDefTribe == nil { + h.killDefTribe = noop + } } func prepareTestServer(h *handlers) *httptest.Server { @@ -42,24 +70,32 @@ func prepareTestServer(h *handlers) *httptest.Server { h = &handlers{} } h.init() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case EndpointGetServers: h.getServers(w, r) return case EndpointKillAll: + h.killAll(w, r) return case EndpointKillAtt: + h.killAtt(w, r) return case EndpointKillDef: + h.killDef(w, r) return case EndpointKillSup: + h.killSup(w, r) return case EndpointKillAllTribe: + h.killAllTribe(w, r) return case EndpointKillAttTribe: + h.killAttTribe(w, r) return case EndpointKillDefTribe: + h.killDefTribe(w, r) return default: w.WriteHeader(http.StatusNotFound) @@ -75,3 +111,17 @@ func createWriteStringHandler(resp string) http.HandlerFunc { } }) } + +func createWriteCompressedStringHandler(resp string) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gzipWriter := gzip.NewWriter(w) + defer gzipWriter.Close() + _, err := gzipWriter.Write([]byte(resp)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + if err := gzipWriter.Flush(); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + }) +} diff --git a/tw/twdataloader/server_data_loader_test.go b/tw/twdataloader/server_data_loader_test.go index 6c61716..7e9227a 100644 --- a/tw/twdataloader/server_data_loader_test.go +++ b/tw/twdataloader/server_data_loader_test.go @@ -1,19 +1,156 @@ package twdataloader import ( + "github.com/stretchr/testify/assert" "testing" + + "github.com/tribalwarshelp/shared/tw/twmodel" ) func TestLoadOD(t *testing.T) { - t.Run("fallback to the not gzipped endpoint", func(t *testing.T) { + type scenario struct { + respKillAll string + respKillAtt string + respKillDef string + respKillSup string + respKillAllTribe string + respKillAttTribe string + respKillDefTribe string + tribe bool + expectedResult map[int]*twmodel.OpponentsDefeated + expectedErrMsg string + } - }) + scenarios := []scenario{ + { + respKillAll: "1,1", + expectedErrMsg: "invalid line format (should be rank,id,score)", + }, + { + respKillAllTribe: "1,1", + expectedErrMsg: "invalid line format (should be rank,id,score)", + tribe: true, + }, + { + respKillAll: "1,1,1", + respKillAtt: "1,1,1", + respKillDef: "1,1,1", + respKillSup: "1,1", + expectedErrMsg: "invalid line format (should be rank,id,score)", + }, + { + respKillAllTribe: "1,1,1", + respKillAttTribe: "1,1,1", + respKillDefTribe: "1,1", + expectedErrMsg: "invalid line format (should be rank,id,score)", + tribe: true, + }, + { + respKillAll: "1,1,1\n2,2,2\n3,3,3", + respKillAtt: "1,1,1\n2,2,2\n3,3,3", + respKillDef: "1,1,1\n2,2,2\n3,3,3", + respKillSup: "1,1,1\n2,2,2\n3,3,3", + expectedResult: map[int]*twmodel.OpponentsDefeated{ + 1: { + RankAtt: 1, + ScoreAtt: 1, + RankDef: 1, + ScoreDef: 1, + RankSup: 1, + ScoreSup: 1, + RankTotal: 1, + ScoreTotal: 1, + }, + 2: { + RankAtt: 2, + ScoreAtt: 2, + RankDef: 2, + ScoreDef: 2, + RankSup: 2, + ScoreSup: 2, + ScoreTotal: 2, + RankTotal: 2, + }, + 3: { + RankAtt: 3, + ScoreAtt: 3, + RankDef: 3, + ScoreDef: 3, + RankSup: 3, + ScoreSup: 3, + ScoreTotal: 3, + RankTotal: 3, + }, + }, + }, + { + respKillAllTribe: "1,1,1\n2,2,2\n3,3,3", + respKillAttTribe: "1,1,1\n2,2,2\n3,3,3", + respKillDefTribe: "1,1,1\n2,2,2\n3,3,3", + expectedResult: map[int]*twmodel.OpponentsDefeated{ + 1: { + RankAtt: 1, + ScoreAtt: 1, + RankDef: 1, + ScoreDef: 1, + ScoreTotal: 1, + RankTotal: 1, + }, + 2: { + RankAtt: 2, + ScoreAtt: 2, + RankDef: 2, + ScoreDef: 2, + ScoreTotal: 2, + RankTotal: 2, + }, + 3: { + RankAtt: 3, + ScoreAtt: 3, + RankDef: 3, + ScoreDef: 3, + ScoreTotal: 3, + RankTotal: 3, + }, + }, + tribe: true, + }, + } - t.Run("invalid line format", func(t *testing.T) { + for _, scenario := range scenarios { + ts := prepareTestServer(&handlers{ + killAll: createWriteCompressedStringHandler(scenario.respKillAll), + killAtt: createWriteCompressedStringHandler(scenario.respKillAtt), + killDef: createWriteCompressedStringHandler(scenario.respKillDef), + killSup: createWriteCompressedStringHandler(scenario.respKillSup), + killAllTribe: createWriteCompressedStringHandler(scenario.respKillAllTribe), + killAttTribe: createWriteCompressedStringHandler(scenario.respKillAttTribe), + killDefTribe: createWriteCompressedStringHandler(scenario.respKillDefTribe), + }) - }) + dl := NewServerDataLoader(&ServerDataLoaderConfig{ + BaseURL: ts.URL, + Client: ts.Client(), + }) - t.Run("success", func(t *testing.T) { + res, err := dl.LoadOD(scenario.tribe) + if scenario.expectedErrMsg != "" { + assert.NotNil(t, err) + assert.Contains(t, err.Error(), scenario.expectedErrMsg) + } else { + assert.Nil(t, err) + } - }) + if scenario.expectedResult != nil { + assert.Len(t, res, len(scenario.expectedResult)) + for id, singleResult := range res { + expected, ok := scenario.expectedResult[id] + assert.True(t, ok) + assert.NotNil(t, expected) + assert.EqualValues(t, expected, singleResult) + } + } + + ts.Close() + } } diff --git a/tw/twdataloader/version_data_loader_test.go b/tw/twdataloader/version_data_loader_test.go index aa69808..1b3f633 100644 --- a/tw/twdataloader/version_data_loader_test.go +++ b/tw/twdataloader/version_data_loader_test.go @@ -43,7 +43,7 @@ func TestLoadServers(t *testing.T) { }) dl := NewVersionDataLoader(&VersionDataLoaderConfig{ - Host: strings.ReplaceAll(ts.URL, "https://", ""), + Host: ts.URL, Client: ts.Client(), }) @@ -66,7 +66,7 @@ func TestLoadServers(t *testing.T) { assert.Nil(t, err) dl := NewVersionDataLoader(&VersionDataLoaderConfig{ - Host: strings.ReplaceAll(ts.URL, "https://", ""), + Host: ts.URL, Client: ts.Client(), })