From 232d9caad3ec4a559b45f546d1cfd67e703a716c Mon Sep 17 00:00:00 2001 From: Kichiyaki Date: Tue, 13 Jul 2021 11:09:59 +0200 Subject: [PATCH] add more checks in VersionDataLoader.LoadServers, add more tests --- tw/twdataloader/version_data_loader.go | 16 +++++-- tw/twdataloader/version_data_loader_test.go | 52 ++++++++++++++++----- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/tw/twdataloader/version_data_loader.go b/tw/twdataloader/version_data_loader.go index db26962..4b40017 100644 --- a/tw/twdataloader/version_data_loader.go +++ b/tw/twdataloader/version_data_loader.go @@ -59,11 +59,21 @@ func (dl *VersionDataLoader) LoadServers() ([]*Server, error) { } return nil, fmtedErr } + bodyMap, ok := body.(map[interface{}]interface{}) + if !ok { + return nil, errors.Errorf("expected map, got %T", body) + } var servers []*Server - for serverKey, url := range body.(map[interface{}]interface{}) { - serverKeyStr := serverKey.(string) - urlStr := url.(string) + for serverKey, url := range bodyMap { + serverKeyStr, ok := serverKey.(string) + if !ok { + return nil, errors.Errorf("expected string as the key of the map, got %T", serverKey) + } + urlStr, ok := url.(string) + if !ok { + return nil, errors.Errorf("expected string as the value of the map, got %T", url) + } if serverKeyStr != "" && urlStr != "" { servers = append(servers, &Server{ Key: serverKeyStr, diff --git a/tw/twdataloader/version_data_loader_test.go b/tw/twdataloader/version_data_loader_test.go index f08abed..76ac430 100644 --- a/tw/twdataloader/version_data_loader_test.go +++ b/tw/twdataloader/version_data_loader_test.go @@ -26,19 +26,49 @@ func prepareTestServer(resp string) *httptest.Server { } func TestLoadServers(t *testing.T) { - t.Run("invalid payload", func(t *testing.T) { - resp := `:"https://pl165.plemiona.pl";s:5:"pl166";s:25:"https://pl166.plemiona.pl";s:5:"pl167";s:25:"https://pl167.plemiona.pl";}` - ts := prepareTestServer(resp) - defer ts.Close() + t.Run("invalid response", func(t *testing.T) { + type scenario struct { + resp string + expectedErrMsg string + } - dl := NewVersionDataLoader(&VersionDataLoaderConfig{ - Host: strings.ReplaceAll(ts.URL, "https://", ""), - Client: ts.Client(), - }) + scenarios := []scenario{ + { + resp: `:"https://pl165.plemiona.pl";s:5:"pl166";s:25:"https://pl166.plemiona.pl";s:5:"pl167";s:25:"https://pl167.plemiona.pl";}`, + expectedErrMsg: "couldn't decode the response body into a go value", + }, + { + resp: `a:19:{s:5:"pl150"s:25"https://pl150.plemiona.pl";}`, + expectedErrMsg: "expected string as the value of the map, got ", + }, + { + resp: "a:3:{i:0;i:1;i:1;i:2;i:2;i:3;}", + expectedErrMsg: "expected string as the key of the map, got int64", + }, + { + resp: `O:8:"stdClass":0:{}`, + expectedErrMsg: "expected map, got *phpserialize.PhpObject", + }, + { + resp: `a:2:{s:3:"asd";i:123;s:4:"asd2";i:123;}`, + expectedErrMsg: "expected string as the value of the map, got int64", + }, + } - _, err := dl.LoadServers() - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "couldn't decode the response body into a go value") + for _, scenario := range scenarios { + ts := prepareTestServer(scenario.resp) + + dl := NewVersionDataLoader(&VersionDataLoaderConfig{ + Host: strings.ReplaceAll(ts.URL, "https://", ""), + Client: ts.Client(), + }) + + _, err := dl.LoadServers() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), scenario.expectedErrMsg) + + ts.Close() + } }) t.Run("success", func(t *testing.T) {