diff --git a/common/consts.go b/common/consts.go index 21c9ef4..c2a9569 100644 --- a/common/consts.go +++ b/common/consts.go @@ -25,6 +25,11 @@ const ( MODEL_IMAGE_API_BASE = "MODEL_IMAGE_API_BASE" MODEL_IMAGE_API_KEY = "MODEL_IMAGE_API_KEY" + // Edit + MODEL_EDIT_NAME = "MODEL_EDIT_NAME" + MODEL_EDIT_API_BASE = "MODEL_EDIT_API_BASE" + MODEL_EDIT_API_KEY = "MODEL_EDIT_API_KEY" + // Video MODEL_VIDEO_NAME = "MODEL_VIDEO_NAME" MODEL_VIDEO_API_BASE = "MODEL_VIDEO_API_BASE" diff --git a/common/defaults.go b/common/defaults.go index 724f218..e2ee171 100644 --- a/common/defaults.go +++ b/common/defaults.go @@ -25,6 +25,10 @@ const ( DEFAULT_MODEL_IMAGE_NAME = "doubao-seedream-4-5-251128" DEFAULT_MODEL_IMAGE_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/" + // Edit + DEFAULT_MODEL_EDIT_NAME = "doubao-seededit-3-0-i2i-250628" + DEFAULT_MODEL_EDIT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/" + // Video DEFAULT_MODEL_VIDEO_NAME = "doubao-seedance-1-0-pro-250528" DEFAULT_MODEL_VIDEO_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/" diff --git a/configs/config_test.go b/configs/config_test.go index b95b576..cb3b2a5 100644 --- a/configs/config_test.go +++ b/configs/config_test.go @@ -93,6 +93,20 @@ func TestSetupVeADKConfig(t *testing.T) { assert.Equal(t, "doubao-seed-1-6-250615", os.Getenv(common.MODEL_AGENT_NAME)) } +func TestModelConfig_EditEnvMapping(t *testing.T) { + t.Setenv(common.MODEL_EDIT_NAME, "edit-model") + t.Setenv(common.MODEL_EDIT_API_BASE, "https://edit.example.com") + t.Setenv(common.MODEL_EDIT_API_KEY, "edit-key") + + config := &ModelConfig{} + config.MapEnvToConfig() + + assert.NotNil(t, config.Edit) + assert.Equal(t, "edit-model", config.Edit.Name) + assert.Equal(t, "https://edit.example.com", config.Edit.ApiBase) + assert.Equal(t, "edit-key", config.Edit.ApiKey) +} + func TestObservabilityConfig_YamlMapping(t *testing.T) { yamlData := ` opentelemetry: diff --git a/configs/configs.go b/configs/configs.go index 095624f..111e285 100644 
// ModelConfig groups the model settings, one section per capability.
// MapEnvToConfig allocates any section that is still nil before filling it
// from environment variables.
type ModelConfig struct {
	// Agent holds the agent model settings (name, provider, ...).
	Agent *AgentConfig
	// Image holds the image model settings: name, API base and API key.
	Image *CommonModelConfig
	// Edit holds the image-editing model settings: name, API base and API key
	// (MODEL_EDIT_NAME, MODEL_EDIT_API_BASE, MODEL_EDIT_API_KEY).
	Edit *CommonModelConfig
	// Video holds the video model settings (name, API base, ...).
	Video *CommonModelConfig
	// Embedding holds the embedding model settings.
	Embedding *EmbeddingModelConfig
}
const (
	// defaultImageEditResponseFormat is the response_format sent to the model
	// when a request leaves ResponseFormat empty.
	defaultImageEditResponseFormat = "url"
	// defaultImageEditGuidanceScale is the guidance_scale (prompt adherence
	// strength) sent when a request does not set GuidanceScale.
	defaultImageEditGuidanceScale = 2.5
	// defaultImageEditSeed is the seed sent when a request does not set Seed.
	// NOTE(review): -1 presumably asks the service to pick a random seed — confirm
	// against the model API documentation.
	defaultImageEditSeed = -1

	// ImageEditSuccessStatus marks a successfully edited image, and is the
	// overall tool status when at least one image succeeded.
	ImageEditSuccessStatus = "success"
	// ImageEditErrorStatus marks a failed image, and is the overall tool
	// status when no image succeeded.
	ImageEditErrorStatus = "error"
)
+ - watermark (bool): whether to add watermark. Defaults to true. + - seed (int): random seed. Defaults to -1. + + Returns: + { + "status": "success", + "success_list": [{"image_name": "edited", "url": "..."}], + "error_list": [] + } +` + +type ImageEditConfig struct { + ModelName string + APIKey string + BaseURL string +} + +type ImageEditToolRequest struct { + Params []EditImagesRequest `json:"params"` +} + +type EditImagesRequest struct { + ImageName string `json:"image_name,omitempty"` + OriginImage string `json:"origin_image"` + Prompt string `json:"prompt"` + ResponseFormat string `json:"response_format,omitempty"` + GuidanceScale *float64 `json:"guidance_scale,omitempty"` + Watermark *bool `json:"watermark,omitempty"` + Seed *int64 `json:"seed,omitempty"` +} + +type ImageEditToolResult struct { + SuccessList []*ImageEditResult `json:"success_list,omitempty"` + ErrorList []*ImageEditResult `json:"error_list,omitempty"` + Status string `json:"status"` +} + +type ImageEditResult struct { + ImageName string `json:"image_name"` + Url string `json:"url,omitempty"` + B64Json string `json:"b64_json,omitempty"` + Error string `json:"error,omitempty"` +} + +type ImageEditToolChannelMessage struct { + Status string + Result *ImageEditResult +} + +func NewImageEditTool(config *ImageEditConfig) (tool.Tool, error) { + if config == nil { + config = &ImageEditConfig{} + } + if config.ModelName == "" { + config.ModelName = utils.GetEnvWithDefault(common.MODEL_EDIT_NAME, configs.GetGlobalConfig().Model.Edit.Name, common.DEFAULT_MODEL_EDIT_NAME) + } + if config.APIKey == "" { + config.APIKey = resolveImageEditAPIKey() + } + if config.BaseURL == "" { + config.BaseURL = utils.GetEnvWithDefault(common.MODEL_EDIT_API_BASE, configs.GetGlobalConfig().Model.Edit.ApiBase, common.DEFAULT_MODEL_EDIT_API_BASE) + } + + log.Debug("Initializing image edit tool", "model", config.ModelName, "base_url", config.BaseURL) + + handler := func(ctx tool.Context, toolRequest ImageEditToolRequest) 
(*ImageEditToolResult, error) { + client := arkruntime.NewClientWithApiKey( + config.APIKey, + arkruntime.WithBaseUrl(config.BaseURL), + ) + + result := &ImageEditToolResult{} + var wg sync.WaitGroup + ch := make(chan *ImageEditToolChannelMessage) + for i, task := range toolRequest.Params { + wg.Add(1) + go func(index int, req EditImagesRequest) { + defer func() { + wg.Done() + if r := recover(); r != nil { + log.Error("Image edit task panic", "recover", r, "prompt", req.Prompt) + ch <- &ImageEditToolChannelMessage{ + Status: ImageEditErrorStatus, + Result: newImageEditErrorResult(defaultImageEditName(index), fmt.Sprintf("task panic: %v", r)), + } + } + }() + + imageName := imageEditName(req, index) + if req.Prompt == "" { + ch <- &ImageEditToolChannelMessage{Status: ImageEditErrorStatus, Result: newImageEditErrorResult(imageName, "prompt is required")} + return + } + if req.OriginImage == "" { + ch <- &ImageEditToolChannelMessage{Status: ImageEditErrorStatus, Result: newImageEditErrorResult(imageName, "origin_image is required")} + return + } + + resp, err := client.GenerateImages(ctx, buildImageEditModelRequest(config.ModelName, req)) + if err != nil { + log.Error("Failed to edit image", "error", err) + ch <- &ImageEditToolChannelMessage{Status: ImageEditErrorStatus, Result: newImageEditErrorResult(imageName, err.Error())} + return + } + if resp.Error != nil { + ch <- &ImageEditToolChannelMessage{Status: ImageEditErrorStatus, Result: newImageEditErrorResult(imageName, resp.Error.Message)} + return + } + + messages := imageEditResultsFromResponse(imageName, resp.Data) + for _, message := range messages { + ch <- message + } + }(i, task) + } + + go func() { + wg.Wait() + close(ch) + }() + + for res := range ch { + switch res.Status { + case ImageEditSuccessStatus: + result.SuccessList = append(result.SuccessList, res.Result) + case ImageEditErrorStatus: + result.ErrorList = append(result.ErrorList, res.Result) + } + } + + if len(result.SuccessList) == 0 { + 
result.Status = ImageEditErrorStatus + } else { + result.Status = ImageEditSuccessStatus + } + return result, nil + } + + return functiontool.New( + functiontool.Config{ + Name: "image_edit", + Description: imageEditToolDescription, + }, + handler) +} + +func buildImageEditModelRequest(modelName string, req EditImagesRequest) model.GenerateImagesRequest { + responseFormat := req.ResponseFormat + if responseFormat == "" { + responseFormat = defaultImageEditResponseFormat + } + + guidanceScale := defaultImageEditGuidanceScale + if req.GuidanceScale != nil { + guidanceScale = *req.GuidanceScale + } + + watermark := true + if req.Watermark != nil { + watermark = *req.Watermark + } + + seed := int64(defaultImageEditSeed) + if req.Seed != nil { + seed = *req.Seed + } + + return model.GenerateImagesRequest{ + Model: modelName, + Prompt: req.Prompt, + Image: req.OriginImage, + ResponseFormat: &responseFormat, + GuidanceScale: &guidanceScale, + Watermark: &watermark, + Seed: &seed, + } +} + +func resolveImageEditAPIKey() string { + if key := utils.GetEnvWithDefault(common.MODEL_EDIT_API_KEY, configs.GetGlobalConfig().Model.Edit.ApiKey); key != "" { + return key + } + if key := utils.GetEnvWithDefault(common.MODEL_AGENT_API_KEY, configs.GetGlobalConfig().Model.Agent.ApiKey); key != "" { + return key + } + return utils.Must(veauth.GetArkToken(common.DEFAULT_MODEL_REGION)) +} + +func imageEditResultsFromResponse(imageName string, images []*model.Image) []*ImageEditToolChannelMessage { + if len(images) == 0 { + return []*ImageEditToolChannelMessage{{ + Status: ImageEditErrorStatus, + Result: newImageEditErrorResult(imageName, "no images returned"), + }} + } + + messages := make([]*ImageEditToolChannelMessage, 0, len(images)) + for i, image := range images { + resultName := imageName + if len(images) > 1 { + resultName = fmt.Sprintf("%s_%d", imageName, i) + } + + switch { + case image == nil: + messages = append(messages, &ImageEditToolChannelMessage{ + Status: 
// defaultImageEditName returns the fallback output name for the task at the
// given position in the request: "generated_image_" followed by the decimal
// index.
func defaultImageEditName(index int) string {
	const prefix = "generated_image_"
	return prefix + fmt.Sprint(index)
}
+ +package builtin_tools + +import ( + "testing" + + "github.com/stretchr/testify/assert" + arkmodel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" + "github.com/volcengine/volcengine-go-sdk/volcengine" +) + +func TestNewImageEditTool(t *testing.T) { + tool, err := NewImageEditTool(&ImageEditConfig{ + ModelName: "doubao-seededit-3-0-i2i-250628", + APIKey: "test-api-key", + BaseURL: "https://test-api.com", + }) + + assert.NoError(t, err) + assert.NotNil(t, tool) +} + +func TestBuildImageEditModelRequest_Defaults(t *testing.T) { + req := buildImageEditModelRequest("edit-model", EditImagesRequest{ + OriginImage: "https://example.com/input.png", + Prompt: "make the sky blue", + }) + + assert.Equal(t, "edit-model", req.Model) + assert.Equal(t, "make the sky blue", req.Prompt) + assert.Equal(t, "https://example.com/input.png", req.Image) + assert.NotNil(t, req.ResponseFormat) + assert.Equal(t, defaultImageEditResponseFormat, *req.ResponseFormat) + assert.NotNil(t, req.GuidanceScale) + assert.Equal(t, defaultImageEditGuidanceScale, *req.GuidanceScale) + assert.NotNil(t, req.Watermark) + assert.True(t, *req.Watermark) + assert.NotNil(t, req.Seed) + assert.Equal(t, int64(defaultImageEditSeed), *req.Seed) +} + +func TestBuildImageEditModelRequest_Overrides(t *testing.T) { + guidanceScale := 7.5 + watermark := false + seed := int64(42) + + req := buildImageEditModelRequest("edit-model", EditImagesRequest{ + OriginImage: "data:image/png;base64,abc", + Prompt: "remove the logo", + ResponseFormat: "b64_json", + GuidanceScale: &guidanceScale, + Watermark: &watermark, + Seed: &seed, + }) + + assert.Equal(t, "data:image/png;base64,abc", req.Image) + assert.Equal(t, "b64_json", *req.ResponseFormat) + assert.Equal(t, guidanceScale, *req.GuidanceScale) + assert.Equal(t, watermark, *req.Watermark) + assert.Equal(t, seed, *req.Seed) +} + +func TestImageEditResultsFromResponse(t *testing.T) { + url := "https://example.com/output.png" + b64 := "aW1hZ2U=" + + messages 
:= imageEditResultsFromResponse("edited", []*arkmodel.Image{ + {Url: volcengine.String(url)}, + {B64Json: volcengine.String(b64)}, + nil, + {}, + }) + + assert.Len(t, messages, 4) + assert.Equal(t, ImageEditSuccessStatus, messages[0].Status) + assert.Equal(t, "edited_0", messages[0].Result.ImageName) + assert.Equal(t, url, messages[0].Result.Url) + assert.Equal(t, ImageEditSuccessStatus, messages[1].Status) + assert.Equal(t, "edited_1", messages[1].Result.ImageName) + assert.Equal(t, b64, messages[1].Result.B64Json) + assert.Equal(t, ImageEditErrorStatus, messages[2].Status) + assert.Equal(t, "edited_2", messages[2].Result.ImageName) + assert.Contains(t, messages[2].Result.Error, "empty image result") + assert.Equal(t, ImageEditErrorStatus, messages[3].Status) + assert.Equal(t, "edited_3", messages[3].Result.ImageName) + assert.Contains(t, messages[3].Result.Error, "image url or b64_json is empty") +} + +func TestImageEditResultsFromEmptyResponse(t *testing.T) { + messages := imageEditResultsFromResponse("edited", nil) + + assert.Len(t, messages, 1) + assert.Equal(t, ImageEditErrorStatus, messages[0].Status) + assert.Equal(t, "edited", messages[0].Result.ImageName) + assert.Contains(t, messages[0].Result.Error, "no images returned") +} + +func TestImageEditName(t *testing.T) { + assert.Equal(t, "custom", imageEditName(EditImagesRequest{ImageName: "custom"}, 3)) + assert.Equal(t, "generated_image_3", imageEditName(EditImagesRequest{}, 3)) +}