using subtests for TestValidateSecureEndpoints()

dependabot/go_modules/go.uber.org/atomic-1.10.0
eval-exec 2022-03-21 22:36:39 +08:00
parent 4786a72cfc
commit 88e1d6b126
1 changed files with 59 additions and 13 deletions

View File

@ -17,7 +17,7 @@ package transport
import (
"net/http"
"net/http/httptest"
"strings"
"reflect"
"testing"
)
@ -33,18 +33,64 @@ func TestValidateSecureEndpoints(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(remoteAddr))
defer srv.Close()
insecureEps := []string{
"http://" + srv.Listener.Addr().String(),
"invalid remote address",
tests := map[string]struct {
endPoints []string
expectedEndpoints []string
expectedErr bool
}{
"invalidEndPoints": {
endPoints: []string{
"invalid endpoint",
},
expectedEndpoints: nil,
expectedErr: true,
},
"insecureEndpoints": {
endPoints: []string{
"http://127.0.0.1:8000",
"http://" + srv.Listener.Addr().String(),
},
expectedEndpoints: nil,
expectedErr: true,
},
"secureEndPoints": {
endPoints: []string{
"https://" + srv.Listener.Addr().String(),
},
expectedEndpoints: []string{
"https://" + srv.Listener.Addr().String(),
},
expectedErr: false,
},
"mixEndPoints": {
endPoints: []string{
"https://" + srv.Listener.Addr().String(),
"http://" + srv.Listener.Addr().String(),
"invalid end points",
},
expectedEndpoints: []string{
"https://" + srv.Listener.Addr().String(),
},
expectedErr: true,
},
}
if _, err := ValidateSecureEndpoints(*tlsInfo, insecureEps); err == nil || !strings.Contains(err.Error(), "is insecure") {
t.Error("validate secure endpoints should fail")
}
secureEps := []string{
"https://" + srv.Listener.Addr().String(),
}
if _, err := ValidateSecureEndpoints(*tlsInfo, secureEps); err != nil {
t.Error("validate secure endpoints should succeed")
for name, test := range tests {
t.Run(name, func(t *testing.T) {
secureEps, err := ValidateSecureEndpoints(*tlsInfo, test.endPoints)
if test.expectedErr && err == nil {
t.Errorf("expected error")
}
if !test.expectedErr && err != nil {
t.Errorf("unexpected error: %v", err)
}
if err == nil && !test.expectedErr {
if len(secureEps) != len(test.expectedEndpoints) {
t.Errorf("expected %v endpoints, got %v", len(test.expectedEndpoints), len(secureEps))
}
if !reflect.DeepEqual(test.expectedEndpoints, secureEps) {
t.Errorf("expected endpoints %v, got %v", test.expectedEndpoints, secureEps)
}
}
})
}
}