Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 118 additions & 22 deletions tool/builtin_tools/web_search/web_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"

"github.com/volcengine/veadk-go/auth/veauth"
"github.com/volcengine/veadk-go/common"
Expand All @@ -41,6 +42,10 @@ const (

var ErrWebSearchConfig = errors.New("web search config error")

var doWebSearchRequest = func(client *ve_sign.VeRequest) ([]byte, error) {
return client.DoRequest()
}

func NewClient() *ve_sign.VeRequest {
return &ve_sign.VeRequest{
Method: http.MethodPost,
Expand All @@ -62,56 +67,81 @@ type WebSearchResult struct {
Result []string `json:"result,omitempty"`
}

type ParallelWebSearchArgs struct {
Queries []string `json:"queries" jsonschema:"The queries to search in parallel"`
}

type ParallelWebSearchResult struct {
Result map[string][]string `json:"result,omitempty"`
Errors map[string]string `json:"errors,omitempty"`
}

type Config struct {
TopK int
}

func (c Config) webSearchHandler(ctx tool.Context, args WebSearchArgs) (WebSearchResult, error) {
var ak string
var sk string
var header map[string]string
var result *WebSearchResponse
var out = WebSearchResult{Result: make([]string, 0)}
type webSearchCredential struct {
AK string
SK string
Header map[string]string
}

client := NewClient()
func resolveWebSearchCredential(ctx tool.Context) webSearchCredential {
var header map[string]string
credential := webSearchCredential{}
if ctx != nil {
client.AK = utils.GetStringFromToolContext(ctx, common.VOLCENGINE_ACCESS_KEY)
client.SK = utils.GetStringFromToolContext(ctx, common.VOLCENGINE_SECRET_KEY)
credential.AK = utils.GetStringFromToolContext(ctx, common.VOLCENGINE_ACCESS_KEY)
credential.SK = utils.GetStringFromToolContext(ctx, common.VOLCENGINE_SECRET_KEY)
}

if strings.TrimSpace(ak) == "" || strings.TrimSpace(sk) == "" {
client.AK = utils.GetEnvWithDefault(common.VOLCENGINE_ACCESS_KEY, configs.GetGlobalConfig().Volcengine.AK)
client.SK = utils.GetEnvWithDefault(common.VOLCENGINE_SECRET_KEY, configs.GetGlobalConfig().Volcengine.SK)
if strings.TrimSpace(credential.AK) == "" {
credential.AK = utils.GetEnvWithDefault(common.VOLCENGINE_ACCESS_KEY, configs.GetGlobalConfig().Volcengine.AK)
}
if strings.TrimSpace(credential.SK) == "" {
credential.SK = utils.GetEnvWithDefault(common.VOLCENGINE_SECRET_KEY, configs.GetGlobalConfig().Volcengine.SK)
}

if strings.TrimSpace(client.AK) == "" || strings.TrimSpace(client.SK) == "" {
if strings.TrimSpace(credential.AK) == "" || strings.TrimSpace(credential.SK) == "" {
iam, err := veauth.GetCredentialFromVeFaaSIAM()
if err != nil {
log.Warn(fmt.Sprintf("%s : GetCredential error: %s", ErrWebSearchConfig.Error(), err.Error()))
} else {
client.AK = iam.AccessKeyID
client.SK = iam.SecretAccessKey
credential.AK = iam.AccessKeyID
credential.SK = iam.SecretAccessKey
if iam.SessionToken != "" {
header = map[string]string{"X-Security-Token": iam.SessionToken}
}
}
}
credential.Header = header
return credential
}

client.Header = header

func (c Config) topK() int {
if c.TopK <= 0 {
c.TopK = DefaultTopK
return DefaultTopK
}
return c.TopK
}

func (c Config) search(query string, credential webSearchCredential) ([]string, error) {
var result *WebSearchResponse
out := make([]string, 0)

client := NewClient()
client.AK = credential.AK
client.SK = credential.SK
client.Header = credential.Header

body := map[string]any{
"Query": args.Query,
"Query": query,
"SearchType": "web",
"Count": c.TopK,
"Count": c.topK(),
"NeedSummary": true,
}
client.Body = body

resp, err := client.DoRequest()
resp, err := doWebSearchRequest(client)
if err != nil {
return out, err
}
Expand All @@ -124,13 +154,63 @@ func (c Config) webSearchHandler(ctx tool.Context, args WebSearchArgs) (WebSearc
return out, fmt.Errorf("web search result is empty")
}
for _, item := range result.Result.WebResults {
out.Result = append(out.Result, item.Summary)
out = append(out, item.Summary)
}

return out, nil
}

func (c Config) webSearchHandler(ctx tool.Context, args WebSearchArgs) (WebSearchResult, error) {
result, err := c.search(args.Query, resolveWebSearchCredential(ctx))
if err != nil {
return WebSearchResult{Result: make([]string, 0)}, err
}
return WebSearchResult{Result: result}, nil
}

func (c Config) parallelWebSearchHandler(ctx tool.Context, args ParallelWebSearchArgs) (ParallelWebSearchResult, error) {
out := ParallelWebSearchResult{
Result: make(map[string][]string),
Errors: make(map[string]string),
}
if len(args.Queries) == 0 {
return out, nil
}

credential := resolveWebSearchCredential(ctx)
var wg sync.WaitGroup
var mu sync.Mutex
for _, query := range args.Queries {
query = strings.TrimSpace(query)
if query == "" {
continue
}

wg.Add(1)
go func(q string) {
defer wg.Done()
result, err := c.search(q, credential)
mu.Lock()
defer mu.Unlock()
if err != nil {
out.Errors[q] = err.Error()
return
}
out.Result[q] = result
}(query)
}
wg.Wait()

if len(out.Errors) == 0 {
out.Errors = nil
}
return out, nil
}

func NewWebSearchTool(cfg *Config) (tool.Tool, error) {
if cfg == nil {
cfg = &Config{}
}
return functiontool.New(
functiontool.Config{
Name: "web_search",
Expand All @@ -142,3 +222,19 @@ Returns:
},
cfg.webSearchHandler)
}

func NewParallelWebSearchTool(cfg *Config) (tool.Tool, error) {
if cfg == nil {
cfg = &Config{}
}
return functiontool.New(
functiontool.Config{
Name: "parallel_web_search",
Description: `Search multiple queries from websites in parallel.
Args:
queries: The queries to search. Each query will be searched in parallel.
Returns:
A map of query to result documents, plus per-query errors if any.`,
},
cfg.parallelWebSearchHandler)
}
114 changes: 114 additions & 0 deletions tool/builtin_tools/web_search/web_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,128 @@
package web_search

import (
"errors"
"github.com/volcengine/veadk-go/log"
"sync"

"testing"

"github.com/stretchr/testify/assert"
"github.com/volcengine/veadk-go/common"
"github.com/volcengine/veadk-go/integrations/ve_sign"
"github.com/volcengine/veadk-go/utils"
)

func mockWebSearchRequest(t *testing.T, fn func(client *ve_sign.VeRequest) ([]byte, error)) {
t.Helper()
old := doWebSearchRequest
doWebSearchRequest = fn
t.Cleanup(func() {
doWebSearchRequest = old
})
}

func setWebSearchCredentialEnv(t *testing.T) {
t.Helper()
t.Setenv(common.VOLCENGINE_ACCESS_KEY, "test-ak")
t.Setenv(common.VOLCENGINE_SECRET_KEY, "test-sk")
}

func TestWebSearchHandler(t *testing.T) {
setWebSearchCredentialEnv(t)
mockWebSearchRequest(t, func(client *ve_sign.VeRequest) ([]byte, error) {
body := client.Body.(map[string]any)
assert.Equal(t, "golang", body["Query"])
assert.Equal(t, "web", body["SearchType"])
assert.Equal(t, DefaultTopK, body["Count"])
assert.Equal(t, true, body["NeedSummary"])
assert.Equal(t, "test-ak", client.AK)
assert.Equal(t, "test-sk", client.SK)
return []byte(`{"Result":{"WebResults":[{"Summary":"summary one"},{"Summary":"summary two"}]}}`), nil
})

result, err := Config{}.webSearchHandler(nil, WebSearchArgs{Query: "golang"})

assert.NoError(t, err)
assert.Equal(t, []string{"summary one", "summary two"}, result.Result)
}

func TestWebSearchHandlerCustomTopK(t *testing.T) {
setWebSearchCredentialEnv(t)
mockWebSearchRequest(t, func(client *ve_sign.VeRequest) ([]byte, error) {
body := client.Body.(map[string]any)
assert.Equal(t, 3, body["Count"])
return []byte(`{"Result":{"WebResults":[{"Summary":"summary"}]}}`), nil
})

result, err := Config{TopK: 3}.webSearchHandler(nil, WebSearchArgs{Query: "golang"})

assert.NoError(t, err)
assert.Equal(t, []string{"summary"}, result.Result)
}

func TestWebSearchHandlerEmptyResult(t *testing.T) {
setWebSearchCredentialEnv(t)
mockWebSearchRequest(t, func(client *ve_sign.VeRequest) ([]byte, error) {
return []byte(`{"Result":{"WebResults":[]}}`), nil
})

result, err := Config{}.webSearchHandler(nil, WebSearchArgs{Query: "golang"})

assert.Error(t, err)
assert.Empty(t, result.Result)
}

func TestParallelWebSearchHandler(t *testing.T) {
setWebSearchCredentialEnv(t)
var mu sync.Mutex
queries := make(map[string]bool)
mockWebSearchRequest(t, func(client *ve_sign.VeRequest) ([]byte, error) {
body := client.Body.(map[string]any)
query := body["Query"].(string)
mu.Lock()
queries[query] = true
mu.Unlock()

if query == "bad" {
return nil, errors.New("search failed")
}
return []byte(`{"Result":{"WebResults":[{"Summary":"summary for ` + query + `"}]}}`), nil
})

result, err := Config{TopK: 2}.parallelWebSearchHandler(nil, ParallelWebSearchArgs{
Queries: []string{"alpha", "bad", " ", "beta"},
})

assert.NoError(t, err)
assert.Equal(t, []string{"summary for alpha"}, result.Result["alpha"])
assert.Equal(t, []string{"summary for beta"}, result.Result["beta"])
assert.Contains(t, result.Errors["bad"], "search failed")
assert.NotContains(t, result.Result, "")
assert.True(t, queries["alpha"])
assert.True(t, queries["beta"])
assert.True(t, queries["bad"])
assert.False(t, queries[""])
}

func TestParallelWebSearchHandlerEmptyQueries(t *testing.T) {
result, err := Config{}.parallelWebSearchHandler(nil, ParallelWebSearchArgs{})

assert.NoError(t, err)
assert.Empty(t, result.Result)
assert.Empty(t, result.Errors)
}

func TestNewWebSearchTools(t *testing.T) {
webSearchTool, err := NewWebSearchTool(nil)
assert.NoError(t, err)
assert.NotNil(t, webSearchTool)

parallelTool, err := NewParallelWebSearchTool(nil)
assert.NoError(t, err)
assert.NotNil(t, parallelTool)
}

func TestClient_DoRequest(t *testing.T) {

ak := utils.GetEnvWithDefault(common.VOLCENGINE_ACCESS_KEY, "")
Expand Down
Loading