From 82ab1ec60fda1e06081359537ce2b8f07c445517 Mon Sep 17 00:00:00 2001 From: wucm667 Date: Sat, 25 Apr 2026 15:21:36 +0800 Subject: [PATCH] fix(tool): support b64_json image generation results --- tool/builtin_tools/image_generate.go | 94 +++++++++++++++++++++-- tool/builtin_tools/image_generate_test.go | 82 ++++++++++++++++++++ 2 files changed, 171 insertions(+), 5 deletions(-) diff --git a/tool/builtin_tools/image_generate.go b/tool/builtin_tools/image_generate.go index 1f65750..8c682dd 100644 --- a/tool/builtin_tools/image_generate.go +++ b/tool/builtin_tools/image_generate.go @@ -15,6 +15,7 @@ package builtin_tools import ( + "encoding/base64" "fmt" "strings" "sync" @@ -22,6 +23,7 @@ import ( "github.com/volcengine/veadk-go/auth/veauth" "github.com/volcengine/veadk-go/common" "github.com/volcengine/veadk-go/configs" + "github.com/volcengine/veadk-go/integrations/ve_tos" "github.com/volcengine/veadk-go/log" "github.com/volcengine/veadk-go/utils" "github.com/volcengine/volcengine-go-sdk/service/arkruntime" @@ -118,6 +120,7 @@ type ImageGenerateConfig struct { ModelName string APIKey string BaseURL string + TosConfig *ve_tos.Config } type ImageGenerateToolRequest struct { @@ -155,8 +158,23 @@ type ImageGenerateToolChanelMessage struct { const ( ImageGenerateSuccessStatus = "success" ImageGenerateErrorStatus = "error" + imageGenerateTOSBucketPath = "image_generate" ) +type imageGenerateUploader interface { + UploadBytes(data []byte, objectKey string, metadata map[string]string) error + BuildTOSURL(objectKey string) string +} + +type imageGenerateUploaderGetter func() (imageGenerateUploader, error) + +var newImageGenerateTOSUploader = func(config *ve_tos.Config) (imageGenerateUploader, error) { + if config == nil { + config = &ve_tos.Config{} + } + return ve_tos.New(config) +} + func NewImageGenerateTool(config *ImageGenerateConfig) (tool.Tool, error) { if config == nil { config = &ImageGenerateConfig{} @@ -188,9 +206,18 @@ func NewImageGenerateTool(config *ImageGenerateConfig) (tool.Tool, error) { result := &ImageGenerateToolResult{} var wg sync.WaitGroup ch := make(chan *ImageGenerateToolChanelMessage) + var uploader imageGenerateUploader + var uploaderErr error + var uploaderOnce sync.Once + getUploader := func() (imageGenerateUploader, error) { + uploaderOnce.Do(func() { + uploader, uploaderErr = newImageGenerateTOSUploader(config.TosConfig) + }) + return uploader, uploaderErr + } for i, task := range toolRequest.Tasks { wg.Add(1) - go func(req GenerateImagesRequest) { + go func(taskIndex int, req GenerateImagesRequest) { defer func() { wg.Done() if r := recover(); r != nil { @@ -241,20 +268,29 @@ func NewImageGenerateTool(config *ImageGenerateConfig) (tool.Tool, error) { } for index, imageData := range resp.Data { - imageName := fmt.Sprintf("task_%d_image_%d", i, index) + imageName := fmt.Sprintf("task_%d_image_%d", taskIndex, index) imageUrl := "" + if imageData == nil { + ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateErrorStatus, ErrorMessage: "image result is empty", Result: &ImageResult{ImageName: imageName}} + continue + } if imageData.Url != nil { imageUrl = *imageData.Url } else if imageData.B64Json != nil { - // 上传到tos + imageUrl, err = uploadImageGenerateB64Result(getUploader, *imageData.B64Json, imageName) + if err != nil { + log.Error("Failed to upload generated image", "error", err) + ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateErrorStatus, ErrorMessage: err.Error(), Result: &ImageResult{ImageName: imageName}} + continue + } } else { - ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateSuccessStatus, ErrorMessage: "image url or b64_json is empty", Result: &ImageResult{ImageName: imageName}} + ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateErrorStatus, ErrorMessage: "image url or b64_json is empty", Result: &ImageResult{ImageName: imageName}} continue } ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateSuccessStatus, Result: &ImageResult{ImageName: imageName, Url: imageUrl}} } - }(task) + }(i, task) } go func() { @@ -288,3 +324,51 @@ func NewImageGenerateTool(config *ImageGenerateConfig) (tool.Tool, error) { handler) } + +func uploadImageGenerateB64Result(getUploader imageGenerateUploaderGetter, b64Image string, imageName string) (string, error) { + imageBytes, err := decodeImageGenerateB64(b64Image) + if err != nil { + return "", err + } + + uploader, err := getUploader() + if err != nil { + return "", fmt.Errorf("new TOS client error: %w", err) + } + + objectKey := buildImageGenerateObjectKey(imageName) + if err = uploader.UploadBytes(imageBytes, objectKey, nil); err != nil { + return "", fmt.Errorf("upload image to TOS error: %w", err) + } + return uploader.BuildTOSURL(objectKey), nil +} + +func decodeImageGenerateB64(b64Image string) ([]byte, error) { + encoded := strings.TrimSpace(b64Image) + if encoded == "" { + return nil, fmt.Errorf("image b64_json is empty") + } + + if commaIndex := strings.Index(encoded, ","); commaIndex >= 0 && strings.HasPrefix(strings.ToLower(encoded[:commaIndex]), "data:") { + encoded = encoded[commaIndex+1:] + } + + imageBytes, err := base64.StdEncoding.DecodeString(encoded) + if err == nil { + return imageBytes, nil + } + imageBytes, rawErr := base64.RawStdEncoding.DecodeString(encoded) + if rawErr == nil { + return imageBytes, nil + } + return nil, fmt.Errorf("decode image b64_json error: %w", err) +} + +func buildImageGenerateObjectKey(imageName string) string { + cleanName := strings.TrimSpace(imageName) + if cleanName == "" { + cleanName = "image" + } + cleanName = strings.NewReplacer("/", "_", "\\", "_", ":", "_").Replace(cleanName) + return fmt.Sprintf("%s/%s.png", imageGenerateTOSBucketPath, cleanName) +} diff --git a/tool/builtin_tools/image_generate_test.go b/tool/builtin_tools/image_generate_test.go index 2934202..9bbc572 100644 --- a/tool/builtin_tools/image_generate_test.go +++ b/tool/builtin_tools/image_generate_test.go @@ -15,6 +15,8 @@ package builtin_tools import ( + "encoding/base64" + "errors" "testing" "github.com/stretchr/testify/assert" @@ -163,3 +165,83 @@ func TestImageGenerateToolHandler(t *testing.T) { }) } } + +func TestDecodeImageGenerateB64(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("image-data")) + + decoded, err := decodeImageGenerateB64(encoded) + assert.NoError(t, err) + assert.Equal(t, []byte("image-data"), decoded) + + decoded, err = decodeImageGenerateB64("data:image/png;base64," + encoded) + assert.NoError(t, err) + assert.Equal(t, []byte("image-data"), decoded) + + _, err = decodeImageGenerateB64("not-valid-base64") + assert.Error(t, err) + + _, err = decodeImageGenerateB64("") + assert.Error(t, err) +} + +func TestBuildImageGenerateObjectKey(t *testing.T) { + assert.Equal(t, "image_generate/task_0_image_0.png", buildImageGenerateObjectKey("task_0_image_0")) + assert.Equal(t, "image_generate/task_0_image_0.png", buildImageGenerateObjectKey("task/0:image\\0")) + assert.Equal(t, "image_generate/image.png", buildImageGenerateObjectKey(" ")) +} + +func TestUploadImageGenerateB64Result(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("image-data")) + uploader := &mockImageGenerateUploader{} + + url, err := uploadImageGenerateB64Result(func() (imageGenerateUploader, error) { + return uploader, nil + }, encoded, "task/0:image\\0") + + assert.NoError(t, err) + assert.Equal(t, "tos://image_generate/task_0_image_0.png", url) + assert.Equal(t, []byte("image-data"), uploader.data) + assert.Equal(t, "image_generate/task_0_image_0.png", uploader.objectKey) +} + +func TestUploadImageGenerateB64ResultErrors(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("image-data")) + + _, err := uploadImageGenerateB64Result(func() (imageGenerateUploader, error) { + return nil, errors.New("new uploader failed") + }, encoded, "image") + assert.Error(t, err) + assert.Contains(t, err.Error(), "new TOS client error") + + _, err = uploadImageGenerateB64Result(func() (imageGenerateUploader, error) { + return &mockImageGenerateUploader{err: errors.New("upload failed")}, nil + }, encoded, "image") + assert.Error(t, err) + assert.Contains(t, err.Error(), "upload image to TOS error") + + _, err = uploadImageGenerateB64Result(func() (imageGenerateUploader, error) { + return &mockImageGenerateUploader{}, nil + }, "not-valid-base64", "image") + assert.Error(t, err) + assert.Contains(t, err.Error(), "decode image b64_json error") +} + +type mockImageGenerateUploader struct { + data []byte + objectKey string + err error +} + +func (m *mockImageGenerateUploader) UploadBytes(data []byte, objectKey string, metadata map[string]string) error { + _ = metadata + if m.err != nil { + return m.err + } + m.data = append([]byte(nil), data...) + m.objectKey = objectKey + return nil +} + +func (m *mockImageGenerateUploader) BuildTOSURL(objectKey string) string { + return "tos://" + objectKey +}