From 66066d3bbf3a84e006a5bda43f28b21b0d2bdaec Mon Sep 17 00:00:00 2001 From: wucm667 Date: Sat, 25 Apr 2026 14:17:54 +0800 Subject: [PATCH] feat(tool): add parallel web search --- tool/builtin_tools/web_search/web_search.go | 140 +++++++++++++++--- .../web_search/web_search_test.go | 114 ++++++++++++++ 2 files changed, 232 insertions(+), 22 deletions(-) diff --git a/tool/builtin_tools/web_search/web_search.go b/tool/builtin_tools/web_search/web_search.go index 089f18f..9a3aa9c 100644 --- a/tool/builtin_tools/web_search/web_search.go +++ b/tool/builtin_tools/web_search/web_search.go @@ -20,6 +20,7 @@ import ( "fmt" "net/http" "strings" + "sync" "github.com/volcengine/veadk-go/auth/veauth" "github.com/volcengine/veadk-go/common" @@ -41,6 +42,10 @@ const ( var ErrWebSearchConfig = errors.New("web search config error") +var doWebSearchRequest = func(client *ve_sign.VeRequest) ([]byte, error) { + return client.DoRequest() +} + func NewClient() *ve_sign.VeRequest { return &ve_sign.VeRequest{ Method: http.MethodPost, @@ -62,56 +67,81 @@ type WebSearchResult struct { Result []string `json:"result,omitempty"` } +type ParallelWebSearchArgs struct { + Queries []string `json:"queries" jsonschema:"The queries to search in parallel"` +} + +type ParallelWebSearchResult struct { + Result map[string][]string `json:"result,omitempty"` + Errors map[string]string `json:"errors,omitempty"` +} + type Config struct { TopK int } -func (c Config) webSearchHandler(ctx tool.Context, args WebSearchArgs) (WebSearchResult, error) { - var ak string - var sk string - var header map[string]string - var result *WebSearchResponse - var out = WebSearchResult{Result: make([]string, 0)} +type webSearchCredential struct { + AK string + SK string + Header map[string]string +} - client := NewClient() +func resolveWebSearchCredential(ctx tool.Context) webSearchCredential { + var header map[string]string + credential := webSearchCredential{} if ctx != nil { - client.AK = utils.GetStringFromToolContext(ctx, common.VOLCENGINE_ACCESS_KEY) - client.SK = utils.GetStringFromToolContext(ctx, common.VOLCENGINE_SECRET_KEY) + credential.AK = utils.GetStringFromToolContext(ctx, common.VOLCENGINE_ACCESS_KEY) + credential.SK = utils.GetStringFromToolContext(ctx, common.VOLCENGINE_SECRET_KEY) } - if strings.TrimSpace(ak) == "" || strings.TrimSpace(sk) == "" { - client.AK = utils.GetEnvWithDefault(common.VOLCENGINE_ACCESS_KEY, configs.GetGlobalConfig().Volcengine.AK) - client.SK = utils.GetEnvWithDefault(common.VOLCENGINE_SECRET_KEY, configs.GetGlobalConfig().Volcengine.SK) + if strings.TrimSpace(credential.AK) == "" { + credential.AK = utils.GetEnvWithDefault(common.VOLCENGINE_ACCESS_KEY, configs.GetGlobalConfig().Volcengine.AK) + } + if strings.TrimSpace(credential.SK) == "" { + credential.SK = utils.GetEnvWithDefault(common.VOLCENGINE_SECRET_KEY, configs.GetGlobalConfig().Volcengine.SK) } - if strings.TrimSpace(client.AK) == "" || strings.TrimSpace(client.SK) == "" { + if strings.TrimSpace(credential.AK) == "" || strings.TrimSpace(credential.SK) == "" { iam, err := veauth.GetCredentialFromVeFaaSIAM() if err != nil { log.Warn(fmt.Sprintf("%s : GetCredential error: %s", ErrWebSearchConfig.Error(), err.Error())) } else { - client.AK = iam.AccessKeyID - client.SK = iam.SecretAccessKey + credential.AK = iam.AccessKeyID + credential.SK = iam.SecretAccessKey if iam.SessionToken != "" { header = map[string]string{"X-Security-Token": iam.SessionToken} } } } + credential.Header = header + return credential +} - client.Header = header - +func (c Config) topK() int { if c.TopK <= 0 { - c.TopK = DefaultTopK + return DefaultTopK } + return c.TopK +} + +func (c Config) search(query string, credential webSearchCredential) ([]string, error) { + var result *WebSearchResponse + out := make([]string, 0) + + client := NewClient() + client.AK = credential.AK + client.SK = credential.SK + client.Header = credential.Header body := map[string]any{ - "Query": args.Query, + "Query": query, "SearchType": "web", - "Count": c.TopK, + "Count": c.topK(), "NeedSummary": true, } client.Body = body - resp, err := client.DoRequest() + resp, err := doWebSearchRequest(client) if err != nil { return out, err } @@ -124,13 +154,63 @@ func (c Config) webSearchHandler(ctx tool.Context, args WebSearchArgs) (WebSearc return out, fmt.Errorf("web search result is empty") } for _, item := range result.Result.WebResults { - out.Result = append(out.Result, item.Summary) + out = append(out, item.Summary) + } + + return out, nil +} + +func (c Config) webSearchHandler(ctx tool.Context, args WebSearchArgs) (WebSearchResult, error) { + result, err := c.search(args.Query, resolveWebSearchCredential(ctx)) + if err != nil { + return WebSearchResult{Result: make([]string, 0)}, err + } + return WebSearchResult{Result: result}, nil +} + +func (c Config) parallelWebSearchHandler(ctx tool.Context, args ParallelWebSearchArgs) (ParallelWebSearchResult, error) { + out := ParallelWebSearchResult{ + Result: make(map[string][]string), + Errors: make(map[string]string), + } + if len(args.Queries) == 0 { + return out, nil } + credential := resolveWebSearchCredential(ctx) + var wg sync.WaitGroup + var mu sync.Mutex + for _, query := range args.Queries { + query = strings.TrimSpace(query) + if query == "" { + continue + } + + wg.Add(1) + go func(q string) { + defer wg.Done() + result, err := c.search(q, credential) + mu.Lock() + defer mu.Unlock() + if err != nil { + out.Errors[q] = err.Error() + return + } + out.Result[q] = result + }(query) + } + wg.Wait() + + if len(out.Errors) == 0 { + out.Errors = nil + } return out, nil } func NewWebSearchTool(cfg *Config) (tool.Tool, error) { + if cfg == nil { + cfg = &Config{} + } return functiontool.New( functiontool.Config{ Name: "web_search", @@ -142,3 +222,19 @@ Returns: }, cfg.webSearchHandler) } + +func NewParallelWebSearchTool(cfg *Config) (tool.Tool, error) { + if cfg == nil { + cfg = &Config{} + } + return functiontool.New( + functiontool.Config{ + Name: "parallel_web_search", + Description: `Search multiple queries from websites in parallel. +Args: +queries: The queries to search. Each query will be searched in parallel. +Returns: +A map of query to result documents, plus per-query errors if any.`, + }, + cfg.parallelWebSearchHandler) +} diff --git a/tool/builtin_tools/web_search/web_search_test.go b/tool/builtin_tools/web_search/web_search_test.go index 0f29855..905f0c1 100644 --- a/tool/builtin_tools/web_search/web_search_test.go +++ b/tool/builtin_tools/web_search/web_search_test.go @@ -15,14 +15,128 @@ package web_search import ( + "errors" "github.com/volcengine/veadk-go/log" + "sync" "testing" + "github.com/stretchr/testify/assert" "github.com/volcengine/veadk-go/common" + "github.com/volcengine/veadk-go/integrations/ve_sign" "github.com/volcengine/veadk-go/utils" ) +func mockWebSearchRequest(t *testing.T, fn func(client *ve_sign.VeRequest) ([]byte, error)) { + t.Helper() + old := doWebSearchRequest + doWebSearchRequest = fn + t.Cleanup(func() { + doWebSearchRequest = old + }) +} + +func setWebSearchCredentialEnv(t *testing.T) { + t.Helper() + t.Setenv(common.VOLCENGINE_ACCESS_KEY, "test-ak") + t.Setenv(common.VOLCENGINE_SECRET_KEY, "test-sk") +} + +func TestWebSearchHandler(t *testing.T) { + setWebSearchCredentialEnv(t) + mockWebSearchRequest(t, func(client *ve_sign.VeRequest) ([]byte, error) { + body := client.Body.(map[string]any) + assert.Equal(t, "golang", body["Query"]) + assert.Equal(t, "web", body["SearchType"]) + assert.Equal(t, DefaultTopK, body["Count"]) + assert.Equal(t, true, body["NeedSummary"]) + assert.Equal(t, "test-ak", client.AK) + assert.Equal(t, "test-sk", client.SK) + return []byte(`{"Result":{"WebResults":[{"Summary":"summary one"},{"Summary":"summary two"}]}}`), nil + }) + + result, err := Config{}.webSearchHandler(nil, WebSearchArgs{Query: "golang"}) + + assert.NoError(t, err) + assert.Equal(t, []string{"summary one", "summary two"}, result.Result) +} + +func TestWebSearchHandlerCustomTopK(t *testing.T) { + setWebSearchCredentialEnv(t) + mockWebSearchRequest(t, func(client *ve_sign.VeRequest) ([]byte, error) { + body := client.Body.(map[string]any) + assert.Equal(t, 3, body["Count"]) + return []byte(`{"Result":{"WebResults":[{"Summary":"summary"}]}}`), nil + }) + + result, err := Config{TopK: 3}.webSearchHandler(nil, WebSearchArgs{Query: "golang"}) + + assert.NoError(t, err) + assert.Equal(t, []string{"summary"}, result.Result) +} + +func TestWebSearchHandlerEmptyResult(t *testing.T) { + setWebSearchCredentialEnv(t) + mockWebSearchRequest(t, func(client *ve_sign.VeRequest) ([]byte, error) { + return []byte(`{"Result":{"WebResults":[]}}`), nil + }) + + result, err := Config{}.webSearchHandler(nil, WebSearchArgs{Query: "golang"}) + + assert.Error(t, err) + assert.Empty(t, result.Result) +} + +func TestParallelWebSearchHandler(t *testing.T) { + setWebSearchCredentialEnv(t) + var mu sync.Mutex + queries := make(map[string]bool) + mockWebSearchRequest(t, func(client *ve_sign.VeRequest) ([]byte, error) { + body := client.Body.(map[string]any) + query := body["Query"].(string) + mu.Lock() + queries[query] = true + mu.Unlock() + + if query == "bad" { + return nil, errors.New("search failed") + } + return []byte(`{"Result":{"WebResults":[{"Summary":"summary for ` + query + `"}]}}`), nil + }) + + result, err := Config{TopK: 2}.parallelWebSearchHandler(nil, ParallelWebSearchArgs{ + Queries: []string{"alpha", "bad", " ", "beta"}, + }) + + assert.NoError(t, err) + assert.Equal(t, []string{"summary for alpha"}, result.Result["alpha"]) + assert.Equal(t, []string{"summary for beta"}, result.Result["beta"]) + assert.Contains(t, result.Errors["bad"], "search failed") + assert.NotContains(t, result.Result, "") + assert.True(t, queries["alpha"]) + assert.True(t, queries["beta"]) + assert.True(t, queries["bad"]) + assert.False(t, queries[""]) +} + +func TestParallelWebSearchHandlerEmptyQueries(t *testing.T) { + result, err := Config{}.parallelWebSearchHandler(nil, ParallelWebSearchArgs{}) + + assert.NoError(t, err) + assert.Empty(t, result.Result) + assert.Empty(t, result.Errors) +} + +func TestNewWebSearchTools(t *testing.T) { + webSearchTool, err := NewWebSearchTool(nil) + assert.NoError(t, err) + assert.NotNil(t, webSearchTool) + + parallelTool, err := NewParallelWebSearchTool(nil) + assert.NoError(t, err) + assert.NotNil(t, parallelTool) +} + func TestClient_DoRequest(t *testing.T) { ak := utils.GetEnvWithDefault(common.VOLCENGINE_ACCESS_KEY, "")