From bbb08c788f70d96cd84da5336dfca15fd9a2759f Mon Sep 17 00:00:00 2001 From: ali Date: Wed, 22 Feb 2023 14:20:22 +0300 Subject: [PATCH 1/2] refactor jwt provider for better testing --- services/jwt-provider/handler_create_jwt.go | 34 ++++++++++++--------- services/jwt-provider/provider.go | 13 -------- services/jwt-provider/provider_test.go | 10 ++++++ 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/services/jwt-provider/handler_create_jwt.go b/services/jwt-provider/handler_create_jwt.go index 7a80fd27..7238a41d 100644 --- a/services/jwt-provider/handler_create_jwt.go +++ b/services/jwt-provider/handler_create_jwt.go @@ -1,6 +1,7 @@ package jwt_provider import ( + "context" "encoding/json" "fmt" "net" @@ -31,29 +32,34 @@ func (j *JWTProvider) createJWTHandler(w http.ResponseWriter, req *http.Request) } } - ipAddr, _, err := net.SplitHostPort(req.RemoteAddr) + jwt, err := j.doCreateJWT(req.Context(), req.RemoteAddr, msg.Claims) if err != nil { w.WriteHeader(http.StatusBadRequest) - _, _ = fmt.Fprintf(w, "can't extract ip from request %s", req.RemoteAddr) + _, _ = fmt.Fprintf(w, "can not create jwt: %v", err) return } - agentID, err := j.agentIDReverseLookup(req.Context(), ipAddr) + resp, err := json.Marshal(CreateJWTResponse{Token: jwt}) + + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "%s", resp) +} + +func (j *JWTProvider) doCreateJWT(ctx context.Context, remoteAddr string, claims map[string]interface{}) (string, error) { + ipAddr, _, err := net.SplitHostPort(remoteAddr) if err != nil { - w.WriteHeader(http.StatusBadRequest) - _, _ = fmt.Fprintf(w, "can't find bot id from request source %s, err: %v", ipAddr, err) - return + return "", fmt.Errorf("can't extract ip from request %s", remoteAddr) } - jwt, err := CreateBotJWT(j.cfg.Key, agentID, msg.Claims) + agentID, err := j.agentIDReverseLookup(ctx, ipAddr) if err != nil { - w.WriteHeader(http.StatusBadRequest) - _, _ = fmt.Fprint(w, errFailedToCreateJWT) - return + return "", fmt.Errorf("can't find bot id from request source %s, err: %v", ipAddr, err) } - resp, err := json.Marshal(CreateJWTResponse{Token: jwt}) + jwt, err := CreateBotJWT(j.cfg.Key, agentID, claims) + if err != nil { + return "", fmt.Errorf(errFailedToCreateJWT) + } - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprintf(w, "%s", resp) -} + return jwt, nil +} \ No newline at end of file diff --git a/services/jwt-provider/provider.go b/services/jwt-provider/provider.go index 37cbd5bd..0c4cd1a8 100644 --- a/services/jwt-provider/provider.go +++ b/services/jwt-provider/provider.go @@ -6,13 +6,10 @@ import ( "fmt" "net/http" "strings" - "sync" "time" "github.com/docker/docker/api/types" "github.com/ethereum/go-ethereum/accounts/keystore" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" "github.com/forta-network/forta-core-go/clients/health" "github.com/forta-network/forta-core-go/security" "github.com/forta-network/forta-node/clients" @@ -23,9 +20,6 @@ import ( // JWTProvider provides jwt tokens to bots, signed with node's private key.. type JWTProvider struct { - botConfigs []config.AgentConfig - botConfigsMutex sync.RWMutex - // to match request ip <-> bot id dockerClient clients.DockerClient @@ -192,13 +186,6 @@ func (j *JWTProvider) Health() health.Reports { } } -// requestHash used for "hash" claim in JWT token -func requestHash(uri string, payload []byte) common.Hash { - requestStr := fmt.Sprintf("%s%s", uri, payload) - - return crypto.Keccak256Hash([]byte(requestStr)) -} - // CreateBotJWT returns a bot JWT token. Basically security.ScannerJWT with bot&request info. func CreateBotJWT(key *keystore.Key, agentID string, claims map[string]interface{}) (string, error) { if claims == nil { diff --git a/services/jwt-provider/provider_test.go b/services/jwt-provider/provider_test.go index 8f179032..b5134be1 100644 --- a/services/jwt-provider/provider_test.go +++ b/services/jwt-provider/provider_test.go @@ -2,14 +2,24 @@ package jwt_provider import ( "context" + "fmt" "testing" "time" "github.com/ethereum/go-ethereum/accounts/keystore" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/forta-network/forta-core-go/security" "github.com/golang-jwt/jwt/v4" ) +// requestHash used for "hash" claim in JWT token +func requestHash(uri string, payload []byte) common.Hash { + requestStr := fmt.Sprintf("%s%s", uri, payload) + + return crypto.Keccak256Hash([]byte(requestStr)) +} + func Test_createBotJWT(t *testing.T) { dir := t.TempDir() ks := keystore.NewKeyStore(dir, keystore.StandardScryptN, keystore.StandardScryptP) From 4b6de9f81a88311517d4666cd2bc661737d1abf2 Mon Sep 17 00:00:00 2001 From: ali Date: Wed, 22 Feb 2023 15:04:17 +0300 Subject: [PATCH 2/2] add tests to jwt provider --- services/jwt-provider/handler_create_jwt.go | 57 ++++- .../jwt-provider/handler_create_jwt_test.go | 216 ++++++++++++++++++ services/jwt-provider/provider.go | 80 ++----- 3 files changed, 287 insertions(+), 66 deletions(-) create mode 100644 services/jwt-provider/handler_create_jwt_test.go diff --git a/services/jwt-provider/handler_create_jwt.go b/services/jwt-provider/handler_create_jwt.go index 7238a41d..14fd1af0 100644 --- a/services/jwt-provider/handler_create_jwt.go +++ b/services/jwt-provider/handler_create_jwt.go @@ -6,6 +6,10 @@ import ( "fmt" "net" "net/http" + "strings" + + "github.com/docker/docker/api/types" + "github.com/forta-network/forta-node/config" ) type CreateJWTMessage struct { @@ -15,11 +19,62 @@ type CreateJWTResponse struct { Token string `json:"token"` } +const envPrefix = config.EnvFortaBotID + "=" + const ( errBadCreateMessage = "bad create jwt message body" errFailedToCreateJWT = "can't find bot id from request source" ) +// agentIDReverseLookup reverse lookup from ip to agent id. +func (j *JWTProvider) agentIDReverseLookup(ctx context.Context, ipAddr string) (string, error) { + container, err := j.findContainerByIP(ctx, ipAddr) + if err != nil { + return "", err + } + + botID, err := j.extractBotIDFromContainer(ctx, container) + if err != nil { + return "", err + } + + return botID, nil +} + +func (j *JWTProvider) extractBotIDFromContainer(ctx context.Context, container types.Container) (string, error) { + // container struct doesn't have the "env" information, inspection required. + c, err := j.dockerClient.InspectContainer(ctx, container.ID) + if err != nil { + return "", err + } + + // find the env variable with bot id + for _, s := range c.Config.Env { + if env := strings.SplitAfter(s, envPrefix); len(env) == 2 { + return env[1], nil + } + } + + return "", fmt.Errorf("can't extract bot id from container") +} + +func (j *JWTProvider) findContainerByIP(ctx context.Context, ipAddr string) (types.Container, error) { + containers, err := j.dockerClient.GetContainers(ctx) + if err != nil { + return types.Container{}, err + } + + // find the container that has the same ip + for _, container := range containers { + for _, network := range container.NetworkSettings.Networks { + if network.IPAddress == ipAddr { + return container, nil + } + } + } + return types.Container{}, fmt.Errorf("can't find container %s", ipAddr) +} + // createJWTHandler returns a scanner jwt token with claims [hash] = hash(uri,payload) and [bot] = "bot id" func (j *JWTProvider) createJWTHandler(w http.ResponseWriter, req *http.Request) { var msg CreateJWTMessage @@ -58,7 +113,7 @@ func (j *JWTProvider) doCreateJWT(ctx context.Context, remoteAddr string, claims jwt, err := CreateBotJWT(j.cfg.Key, agentID, claims) if err != nil { - return "", fmt.Errorf(errFailedToCreateJWT) + return "", fmt.Errorf("%s: %v", errFailedToCreateJWT, err) } return jwt, nil diff --git a/services/jwt-provider/handler_create_jwt_test.go b/services/jwt-provider/handler_create_jwt_test.go new file mode 100644 index 00000000..417de468 --- /dev/null +++ b/services/jwt-provider/handler_create_jwt_test.go @@ -0,0 +1,216 @@ +package jwt_provider + +import ( + "bytes" + "context" + "encoding/json" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/network" + "github.com/ethereum/go-ethereum/accounts/keystore" + "github.com/forta-network/forta-core-go/security" + mock_clients "github.com/forta-network/forta-node/clients/mocks" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func mockJWTProvider(t *testing.T) (*JWTProvider, *mock_clients.MockDockerClient) { + ctrl := gomock.NewController(t) + + mockDockerClient := mock_clients.NewMockDockerClient(ctrl) + + dir := t.TempDir() + ks := keystore.NewKeyStore(dir, keystore.StandardScryptN, keystore.StandardScryptP) + + _, err := ks.NewAccount("Forta123") + if err != nil { + t.Fatal(err) + } + + mockKey, err := security.LoadKeyWithPassphrase(dir, "Forta123") + if err != nil { + t.Fatal(err) + } + + return &JWTProvider{ + cfg: &JWTProviderConfig{Key: mockKey}, + dockerClient: mockDockerClient, + }, mockDockerClient + +} + +func TestJWTProvider_jwtHandler(t *testing.T) { + j, mockDockerClient := mockJWTProvider(t) + + // test context + ctx := context.Background() + + // + // Test Case 1: Bot can retrieve JWT + // + mockBotIP := "172.0.1.2" + mockBotRemoteAddr := mockBotIP + ":8423" + mockBotID := "1" + mockBotContainerInfo := []types.Container{ + { + NetworkSettings: &types.SummaryNetworkSettings{ + Networks: map[string]*network.EndpointSettings{ + "local": {IPAddress: mockBotIP}, + }, + }, + }, + } + mockBotContainerJSON := &types.ContainerJSON{ + Config: &container.Config{ + Env: []string{ + envPrefix + mockBotID, + }, + }, + } + + mockDockerClient.EXPECT().GetContainers(gomock.Any()).Return(mockBotContainerInfo, nil) + mockDockerClient.EXPECT().InspectContainer(gomock.Any(), gomock.Any()).Return(mockBotContainerJSON, nil) + + _, err := j.doCreateJWT(ctx, mockBotRemoteAddr, nil) + assert.NoError(t, err) + + // + // Test Case 2: Unknown sources can't get JWT + // + mockBotIP = "172.0.1.2" + mockBotRemoteAddr = mockBotIP + ":8423" + mockBotID = "1" + mockBotContainerInfo = []types.Container{ + { + NetworkSettings: &types.SummaryNetworkSettings{ + Networks: map[string]*network.EndpointSettings{}, + }, + }, + } + mockBotContainerJSON = &types.ContainerJSON{ + Config: &container.Config{ + Env: []string{ + envPrefix + mockBotID, + }, + }, + } + + mockDockerClient.EXPECT().GetContainers(gomock.Any()).Return(mockBotContainerInfo, nil) + + _, err = j.doCreateJWT(ctx, mockBotRemoteAddr, nil) + assert.Error(t, err) + + // + // Test Case 3: Source with bad remote address + // + mockBotIP = "www.X.Y.Z" + mockBotRemoteAddr = mockBotIP + mockBotID = "1" + mockBotContainerInfo = []types.Container{ + { + NetworkSettings: &types.SummaryNetworkSettings{ + Networks: map[string]*network.EndpointSettings{}, + }, + }, + } + mockBotContainerJSON = &types.ContainerJSON{ + Config: &container.Config{ + Env: []string{ + envPrefix + mockBotID, + }, + }, + } + + _, err = j.doCreateJWT(ctx, mockBotRemoteAddr, nil) + assert.Error(t, err) + + // + // Test Case 4: Source with bad remote address + // + mockBotIP = "172.0.1.2" + mockBotRemoteAddr = mockBotIP + ":8423" + mockBotID = "1" + mockBotContainerInfo = []types.Container{ + { + NetworkSettings: &types.SummaryNetworkSettings{ + Networks: map[string]*network.EndpointSettings{ + "local": {IPAddress: mockBotIP}, + }, + }, + }, + } + mockBotContainerJSON = &types.ContainerJSON{ + Config: &container.Config{ + Env: []string{ + envPrefix + mockBotID, + }, + }, + } + + j2, mockDockerClient2 := mockJWTProvider(t) + j2.cfg.Key = nil + mockDockerClient2.EXPECT().GetContainers(gomock.Any()).Return(mockBotContainerInfo, nil) + mockDockerClient2.EXPECT().InspectContainer(gomock.Any(), gomock.Any()).Return(mockBotContainerJSON, nil) + + _, err = j2.doCreateJWT(ctx, mockBotRemoteAddr, nil) + assert.Error(t, err) + +} + +func TestJWTProvider_createJWTHandler(t *testing.T) { + j, mockDockerClient := mockJWTProvider(t) + + body := CreateJWTMessage{map[string]interface{}{"test-claim": "success"}} + b := bytes.NewBuffer([]byte{}) + _ = json.NewEncoder(b).Encode(body) + req := httptest.NewRequest(http.MethodPost, "/create", b) + w := httptest.NewRecorder() + + mockBotID := "1" + mockBotIP, _, _ := net.SplitHostPort(req.RemoteAddr) + mockBotContainerInfo := []types.Container{ + { + NetworkSettings: &types.SummaryNetworkSettings{ + Networks: map[string]*network.EndpointSettings{ + "local": { + IPAddress: mockBotIP, + }, + }, + }, + }, + } + mockBotContainerJSON := &types.ContainerJSON{ + Config: &container.Config{ + Env: []string{ + envPrefix + mockBotID, + }, + }, + } + + // Test case 1: can retrieve token + mockDockerClient.EXPECT().GetContainers(gomock.Any()).Return(mockBotContainerInfo, nil) + mockDockerClient.EXPECT().InspectContainer(gomock.Any(), gomock.Any()).Return(mockBotContainerJSON, nil) + j.createJWTHandler(w, req) + + // Test case 2: bad request body + b2 := bytes.NewBuffer([]byte("xxxxxx")) + req2 := httptest.NewRequest(http.MethodPost, "/create", b2) + j.createJWTHandler(w, req2) + + // Test case 3: can not retrieve token, bad source + body3 := CreateJWTMessage{map[string]interface{}{"test-claim": "success"}} + b3 := bytes.NewBuffer([]byte{}) + _ = json.NewEncoder(b3).Encode(body3) + req3 := httptest.NewRequest(http.MethodPost, "/create", b3) + req3.RemoteAddr = "1.1.1.1:4444" + w3 := httptest.NewRecorder() + + mockDockerClient.EXPECT().GetContainers(gomock.Any()).Return(mockBotContainerInfo, nil) + + j.createJWTHandler(w3, req3) +} diff --git a/services/jwt-provider/provider.go b/services/jwt-provider/provider.go index 0c4cd1a8..8527e0c6 100644 --- a/services/jwt-provider/provider.go +++ b/services/jwt-provider/provider.go @@ -5,10 +5,8 @@ import ( "errors" "fmt" "net/http" - "strings" "time" - "github.com/docker/docker/api/types" "github.com/ethereum/go-ethereum/accounts/keystore" "github.com/forta-network/forta-core-go/clients/health" "github.com/forta-network/forta-core-go/security" @@ -35,6 +33,20 @@ type JWTProviderConfig struct { Config config.Config } +// CreateBotJWT returns a bot JWT token. Basically security.ScannerJWT with bot&request info. +func CreateBotJWT(key *keystore.Key, agentID string, claims map[string]interface{}) (string, error) { + if key == nil { + return "", fmt.Errorf("provider has no private key") + } + if claims == nil { + claims = make(map[string]interface{}) + } + + claims["bot-id"] = agentID + + return security.CreateScannerJWT(key, claims) +} + func NewJWTProvider( cfg config.Config, ) (*JWTProvider, error) { @@ -113,57 +125,6 @@ func (j *JWTProvider) listenAndServeWithContext(ctx context.Context) error { } } -// agentIDReverseLookup reverse lookup from ip to agent id. -func (j *JWTProvider) agentIDReverseLookup(ctx context.Context, ipAddr string) (string, error) { - container, err := j.findContainerByIP(ctx, ipAddr) - if err != nil { - return "", err - } - - botID, err := j.extractBotIDFromContainer(ctx, container) - if err != nil { - return "", err - } - - return botID, nil -} - -const envPrefix = config.EnvFortaBotID + "=" - -func (j *JWTProvider) extractBotIDFromContainer(ctx context.Context, container types.Container) (string, error) { - // container struct doesn't have the "env" information, inspection required. - c, err := j.dockerClient.InspectContainer(ctx, container.ID) - if err != nil { - return "", err - } - - // find the env variable with bot id - for _, s := range c.Config.Env { - if env := strings.SplitAfter(s, envPrefix); len(env) == 2 { - return env[1], nil - } - } - - return "", fmt.Errorf("can't extract bot id from container") -} - -func (j *JWTProvider) findContainerByIP(ctx context.Context, ipAddr string) (types.Container, error) { - containers, err := j.dockerClient.GetContainers(ctx) - if err != nil { - return types.Container{}, err - } - - // find the container that has the same ip - for _, container := range containers { - for _, network := range container.NetworkSettings.Networks { - if network.IPAddress == ipAddr { - return container, nil - } - } - } - return types.Container{}, fmt.Errorf("can't find container %s", ipAddr) -} - func (j *JWTProvider) testAPI(_ context.Context) { j.lastErr.Set(nil) } @@ -184,15 +145,4 @@ func (j *JWTProvider) Health() health.Reports { return health.Reports{ j.lastErr.GetReport("api"), } -} - -// CreateBotJWT returns a bot JWT token. Basically security.ScannerJWT with bot&request info. -func CreateBotJWT(key *keystore.Key, agentID string, claims map[string]interface{}) (string, error) { - if claims == nil { - claims = make(map[string]interface{}) - } - - claims["bot-id"] = agentID - - return security.CreateScannerJWT(key, claims) -} +} \ No newline at end of file