Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 89 additions & 5 deletions tool/builtin_tools/image_generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
package builtin_tools

import (
"encoding/base64"
"fmt"
"strings"
"sync"

"github.com/volcengine/veadk-go/auth/veauth"
"github.com/volcengine/veadk-go/common"
"github.com/volcengine/veadk-go/configs"
"github.com/volcengine/veadk-go/integrations/ve_tos"
"github.com/volcengine/veadk-go/log"
"github.com/volcengine/veadk-go/utils"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
Expand Down Expand Up @@ -118,6 +120,7 @@ type ImageGenerateConfig struct {
ModelName string
APIKey string
BaseURL string
TosConfig *ve_tos.Config
}

type ImageGenerateToolRequest struct {
Expand Down Expand Up @@ -155,8 +158,23 @@ type ImageGenerateToolChanelMessage struct {
const (
ImageGenerateSuccessStatus = "success"
ImageGenerateErrorStatus = "error"
imageGenerateTOSBucketPath = "image_generate"
)

type imageGenerateUploader interface {
UploadBytes(data []byte, objectKey string, metadata map[string]string) error
BuildTOSURL(objectKey string) string
}

type imageGenerateUploaderGetter func() (imageGenerateUploader, error)

var newImageGenerateTOSUploader = func(config *ve_tos.Config) (imageGenerateUploader, error) {
if config == nil {
config = &ve_tos.Config{}
}
return ve_tos.New(config)
}

func NewImageGenerateTool(config *ImageGenerateConfig) (tool.Tool, error) {
if config == nil {
config = &ImageGenerateConfig{}
Expand Down Expand Up @@ -188,9 +206,18 @@ func NewImageGenerateTool(config *ImageGenerateConfig) (tool.Tool, error) {
result := &ImageGenerateToolResult{}
var wg sync.WaitGroup
ch := make(chan *ImageGenerateToolChanelMessage)
var uploader imageGenerateUploader
var uploaderErr error
var uploaderOnce sync.Once
getUploader := func() (imageGenerateUploader, error) {
uploaderOnce.Do(func() {
uploader, uploaderErr = newImageGenerateTOSUploader(config.TosConfig)
})
return uploader, uploaderErr
}
for i, task := range toolRequest.Tasks {
wg.Add(1)
go func(req GenerateImagesRequest) {
go func(taskIndex int, req GenerateImagesRequest) {
defer func() {
wg.Done()
if r := recover(); r != nil {
Expand Down Expand Up @@ -241,20 +268,29 @@ func NewImageGenerateTool(config *ImageGenerateConfig) (tool.Tool, error) {
}

for index, imageData := range resp.Data {
imageName := fmt.Sprintf("task_%d_image_%d", i, index)
imageName := fmt.Sprintf("task_%d_image_%d", taskIndex, index)
imageUrl := ""
if imageData == nil {
ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateErrorStatus, ErrorMessage: "image result is empty", Result: &ImageResult{ImageName: imageName}}
continue
}
if imageData.Url != nil {
imageUrl = *imageData.Url
} else if imageData.B64Json != nil {
// 上传到tos
imageUrl, err = uploadImageGenerateB64Result(getUploader, *imageData.B64Json, imageName)
if err != nil {
log.Error("Failed to upload generated image", "error", err)
ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateErrorStatus, ErrorMessage: err.Error(), Result: &ImageResult{ImageName: imageName}}
continue
}
} else {
ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateSuccessStatus, ErrorMessage: "image url or b64_json is empty", Result: &ImageResult{ImageName: imageName}}
ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateErrorStatus, ErrorMessage: "image url or b64_json is empty", Result: &ImageResult{ImageName: imageName}}
continue
}

ch <- &ImageGenerateToolChanelMessage{Status: ImageGenerateSuccessStatus, Result: &ImageResult{ImageName: imageName, Url: imageUrl}}
}
}(task)
}(i, task)
}

go func() {
Expand Down Expand Up @@ -288,3 +324,51 @@ func NewImageGenerateTool(config *ImageGenerateConfig) (tool.Tool, error) {
handler)

}

func uploadImageGenerateB64Result(getUploader imageGenerateUploaderGetter, b64Image string, imageName string) (string, error) {
imageBytes, err := decodeImageGenerateB64(b64Image)
if err != nil {
return "", err
}

uploader, err := getUploader()
if err != nil {
return "", fmt.Errorf("new TOS client error: %w", err)
}

objectKey := buildImageGenerateObjectKey(imageName)
if err = uploader.UploadBytes(imageBytes, objectKey, nil); err != nil {
return "", fmt.Errorf("upload image to TOS error: %w", err)
}
return uploader.BuildTOSURL(objectKey), nil
}

func decodeImageGenerateB64(b64Image string) ([]byte, error) {
encoded := strings.TrimSpace(b64Image)
if encoded == "" {
return nil, fmt.Errorf("image b64_json is empty")
}

if commaIndex := strings.Index(encoded, ","); commaIndex >= 0 && strings.HasPrefix(strings.ToLower(encoded[:commaIndex]), "data:") {
encoded = encoded[commaIndex+1:]
}

imageBytes, err := base64.StdEncoding.DecodeString(encoded)
if err == nil {
return imageBytes, nil
}
imageBytes, rawErr := base64.RawStdEncoding.DecodeString(encoded)
if rawErr == nil {
return imageBytes, nil
}
return nil, fmt.Errorf("decode image b64_json error: %w", err)
}

func buildImageGenerateObjectKey(imageName string) string {
cleanName := strings.TrimSpace(imageName)
if cleanName == "" {
cleanName = "image"
}
cleanName = strings.NewReplacer("/", "_", "\\", "_", ":", "_").Replace(cleanName)
return fmt.Sprintf("%s/%s.png", imageGenerateTOSBucketPath, cleanName)
}
82 changes: 82 additions & 0 deletions tool/builtin_tools/image_generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package builtin_tools

import (
"encoding/base64"
"errors"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -163,3 +165,83 @@ func TestImageGenerateToolHandler(t *testing.T) {
})
}
}

func TestDecodeImageGenerateB64(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("image-data"))

decoded, err := decodeImageGenerateB64(encoded)
assert.NoError(t, err)
assert.Equal(t, []byte("image-data"), decoded)

decoded, err = decodeImageGenerateB64("data:image/png;base64," + encoded)
assert.NoError(t, err)
assert.Equal(t, []byte("image-data"), decoded)

_, err = decodeImageGenerateB64("not-valid-base64")
assert.Error(t, err)

_, err = decodeImageGenerateB64("")
assert.Error(t, err)
}

func TestBuildImageGenerateObjectKey(t *testing.T) {
assert.Equal(t, "image_generate/task_0_image_0.png", buildImageGenerateObjectKey("task_0_image_0"))
assert.Equal(t, "image_generate/task_0_image_0.png", buildImageGenerateObjectKey("task/0:image\\0"))
assert.Equal(t, "image_generate/image.png", buildImageGenerateObjectKey(" "))
}

func TestUploadImageGenerateB64Result(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("image-data"))
uploader := &mockImageGenerateUploader{}

url, err := uploadImageGenerateB64Result(func() (imageGenerateUploader, error) {
return uploader, nil
}, encoded, "task/0:image\\0")

assert.NoError(t, err)
assert.Equal(t, "tos://image_generate/task_0_image_0.png", url)
assert.Equal(t, []byte("image-data"), uploader.data)
assert.Equal(t, "image_generate/task_0_image_0.png", uploader.objectKey)
}

func TestUploadImageGenerateB64ResultErrors(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("image-data"))

_, err := uploadImageGenerateB64Result(func() (imageGenerateUploader, error) {
return nil, errors.New("new uploader failed")
}, encoded, "image")
assert.Error(t, err)
assert.Contains(t, err.Error(), "new TOS client error")

_, err = uploadImageGenerateB64Result(func() (imageGenerateUploader, error) {
return &mockImageGenerateUploader{err: errors.New("upload failed")}, nil
}, encoded, "image")
assert.Error(t, err)
assert.Contains(t, err.Error(), "upload image to TOS error")

_, err = uploadImageGenerateB64Result(func() (imageGenerateUploader, error) {
return &mockImageGenerateUploader{}, nil
}, "not-valid-base64", "image")
assert.Error(t, err)
assert.Contains(t, err.Error(), "decode image b64_json error")
}

type mockImageGenerateUploader struct {
data []byte
objectKey string
err error
}

func (m *mockImageGenerateUploader) UploadBytes(data []byte, objectKey string, metadata map[string]string) error {
_ = metadata
if m.err != nil {
return m.err
}
m.data = append([]byte(nil), data...)
m.objectKey = objectKey
return nil
}

func (m *mockImageGenerateUploader) BuildTOSURL(objectKey string) string {
return "tos://" + objectKey
}
Loading