Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ const (
MODEL_IMAGE_API_BASE = "MODEL_IMAGE_API_BASE"
MODEL_IMAGE_API_KEY = "MODEL_IMAGE_API_KEY"

// Edit
MODEL_EDIT_NAME = "MODEL_EDIT_NAME"
MODEL_EDIT_API_BASE = "MODEL_EDIT_API_BASE"
MODEL_EDIT_API_KEY = "MODEL_EDIT_API_KEY"

// Video
MODEL_VIDEO_NAME = "MODEL_VIDEO_NAME"
MODEL_VIDEO_API_BASE = "MODEL_VIDEO_API_BASE"
Expand Down
4 changes: 4 additions & 0 deletions common/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ const (
DEFAULT_MODEL_IMAGE_NAME = "doubao-seedream-4-5-251128"
DEFAULT_MODEL_IMAGE_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"

// Edit
DEFAULT_MODEL_EDIT_NAME = "doubao-seededit-3-0-i2i-250628"
DEFAULT_MODEL_EDIT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"

// Video
DEFAULT_MODEL_VIDEO_NAME = "doubao-seedance-1-0-pro-250528"
DEFAULT_MODEL_VIDEO_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"
Expand Down
14 changes: 14 additions & 0 deletions configs/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ func TestSetupVeADKConfig(t *testing.T) {
assert.Equal(t, "doubao-seed-1-6-250615", os.Getenv(common.MODEL_AGENT_NAME))
}

// TestModelConfig_EditEnvMapping checks that MapEnvToConfig, called on an
// empty ModelConfig, allocates the Edit section and fills it from the
// MODEL_EDIT_* environment variables.
func TestModelConfig_EditEnvMapping(t *testing.T) {
	const (
		wantName    = "edit-model"
		wantAPIBase = "https://edit.example.com"
		wantAPIKey  = "edit-key"
	)

	// t.Setenv restores the previous values when the test finishes.
	env := map[string]string{
		common.MODEL_EDIT_NAME:     wantName,
		common.MODEL_EDIT_API_BASE: wantAPIBase,
		common.MODEL_EDIT_API_KEY:  wantAPIKey,
	}
	for key, value := range env {
		t.Setenv(key, value)
	}

	cfg := new(ModelConfig)
	cfg.MapEnvToConfig()

	assert.NotNil(t, cfg.Edit)
	assert.Equal(t, wantName, cfg.Edit.Name)
	assert.Equal(t, wantAPIBase, cfg.Edit.ApiBase)
	assert.Equal(t, wantAPIKey, cfg.Edit.ApiKey)
}

func TestObservabilityConfig_YamlMapping(t *testing.T) {
yamlData := `
opentelemetry:
Expand Down
1 change: 1 addition & 0 deletions configs/configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func SetupVeADKConfig() error {
Model: &ModelConfig{
Agent: &AgentConfig{},
Image: &CommonModelConfig{},
Edit: &CommonModelConfig{},
Video: &CommonModelConfig{},
Embedding: &EmbeddingModelConfig{},
},
Expand Down
22 changes: 22 additions & 0 deletions configs/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,28 @@ type EmbeddingModelConfig struct {
// ModelConfig groups the configuration sections for each model role.
// MapEnvToConfig allocates any nil section before populating it.
type ModelConfig struct {
// Agent holds the settings read from the MODEL_AGENT_* variables.
Agent *AgentConfig
// Image holds the settings read from the MODEL_IMAGE_* variables.
Image *CommonModelConfig
// Edit holds the settings read from the MODEL_EDIT_* variables.
Edit *CommonModelConfig
// Video holds the settings read from the MODEL_VIDEO_* variables.
Video *CommonModelConfig
// Embedding holds the embedding model settings.
Embedding *EmbeddingModelConfig
}

func (c *ModelConfig) MapEnvToConfig() {
if c.Agent == nil {
c.Agent = &AgentConfig{}
}
if c.Image == nil {
c.Image = &CommonModelConfig{}
}
if c.Edit == nil {
c.Edit = &CommonModelConfig{}
}
if c.Video == nil {
c.Video = &CommonModelConfig{}
}
if c.Embedding == nil {
c.Embedding = &EmbeddingModelConfig{}
}

// Agent
c.Agent.Name = utils.GetEnvWithDefault(common.MODEL_AGENT_NAME, common.DEFAULT_MODEL_AGENT_NAME)
c.Agent.Provider = utils.GetEnvWithDefault(common.MODEL_AGENT_PROVIDER, common.DEFAULT_MODEL_AGENT_PROVIDER)
Expand All @@ -56,6 +73,11 @@ func (c *ModelConfig) MapEnvToConfig() {
c.Image.ApiBase = utils.GetEnvWithDefault(common.MODEL_IMAGE_API_BASE, common.DEFAULT_MODEL_IMAGE_API_BASE)
c.Image.ApiKey = utils.GetEnvWithDefault(common.MODEL_IMAGE_API_KEY)

// Edit
c.Edit.Name = utils.GetEnvWithDefault(common.MODEL_EDIT_NAME, common.DEFAULT_MODEL_EDIT_NAME)
c.Edit.ApiBase = utils.GetEnvWithDefault(common.MODEL_EDIT_API_BASE, common.DEFAULT_MODEL_EDIT_API_BASE)
c.Edit.ApiKey = utils.GetEnvWithDefault(common.MODEL_EDIT_API_KEY)

// Video
c.Video.Name = utils.GetEnvWithDefault(common.MODEL_VIDEO_NAME, common.DEFAULT_MODEL_VIDEO_NAME)
c.Video.ApiBase = utils.GetEnvWithDefault(common.MODEL_VIDEO_API_BASE, common.DEFAULT_MODEL_VIDEO_API_BASE)
Expand Down
298 changes: 298 additions & 0 deletions tool/builtin_tools/image_edit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package builtin_tools

import (
"fmt"
"sync"

"github.com/volcengine/veadk-go/auth/veauth"
"github.com/volcengine/veadk-go/common"
"github.com/volcengine/veadk-go/configs"
"github.com/volcengine/veadk-go/log"
"github.com/volcengine/veadk-go/utils"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"google.golang.org/adk/tool"
"google.golang.org/adk/tool/functiontool"
)

const (
// Defaults that buildImageEditModelRequest applies when a request leaves
// the corresponding field unset.
defaultImageEditResponseFormat = "url"
defaultImageEditGuidanceScale = 2.5
defaultImageEditSeed = -1

// Status values used for each per-image channel message and for the
// overall ImageEditToolResult.Status.
ImageEditSuccessStatus = "success"
ImageEditErrorStatus = "error"
)

// imageEditToolDescription is the text passed as the Description of the
// "image_edit" function tool. It documents the request schema and the shape
// of the result.
var imageEditToolDescription = `
Edit images in batch according to prompts and optional generation settings.

Args:
params (list of EditImagesRequest)

Per-item schema (EditImagesRequest)
Required:
- origin_image (str): source image URL or Base64 data URL.
- prompt (str): image editing instruction.

Optional:
- image_name (str): output image name. Defaults to generated_image_<index>.
- response_format (str): "url" (default) or "b64_json".
- guidance_scale (float): prompt adherence strength. Defaults to 2.5.
- watermark (bool): whether to add watermark. Defaults to true.
- seed (int): random seed. Defaults to -1.

Returns:
{
"status": "success",
"success_list": [{"image_name": "edited", "url": "..."}],
"error_list": []
}
`

// ImageEditConfig carries the settings NewImageEditTool uses to reach the
// image edit model. NewImageEditTool fills in any field left empty.
type ImageEditConfig struct {
// ModelName is the model identifier sent with every edit request.
ModelName string
// APIKey is the key handed to the ark runtime client.
APIKey string
// BaseURL is the API base URL handed to the ark runtime client.
BaseURL string
}

// ImageEditToolRequest is the argument object of the image_edit tool.
type ImageEditToolRequest struct {
// Params lists the edit tasks; the handler runs one goroutine per entry.
Params []EditImagesRequest `json:"params"`
}

// EditImagesRequest describes a single image edit task. OriginImage and
// Prompt are required; the handler reports an error for the task when either
// is empty. The pointer fields are optional: nil selects the default.
type EditImagesRequest struct {
// ImageName labels the result; when empty a name is derived from the task index.
ImageName string `json:"image_name,omitempty"`
// OriginImage is the source image (URL or Base64 data URL per the tool description).
OriginImage string `json:"origin_image"`
// Prompt is the editing instruction.
Prompt string `json:"prompt"`
// ResponseFormat defaults to "url" when empty.
ResponseFormat string `json:"response_format,omitempty"`
// GuidanceScale defaults to 2.5 when nil.
GuidanceScale *float64 `json:"guidance_scale,omitempty"`
// Watermark defaults to true when nil.
Watermark *bool `json:"watermark,omitempty"`
// Seed defaults to -1 when nil.
Seed *int64 `json:"seed,omitempty"`
}

// ImageEditToolResult is the aggregated outcome of one image_edit call.
type ImageEditToolResult struct {
// SuccessList collects every image that came back with a URL or Base64 payload.
SuccessList []*ImageEditResult `json:"success_list,omitempty"`
// ErrorList collects every failed task or unusable image.
ErrorList []*ImageEditResult `json:"error_list,omitempty"`
// Status is "success" when SuccessList is non-empty and "error" otherwise.
Status string `json:"status"`
}

// ImageEditResult is one entry of the success or error list. A success entry
// carries either Url or B64Json; an error entry carries Error.
type ImageEditResult struct {
// ImageName identifies the task (and image index, for multi-image responses).
ImageName string `json:"image_name"`
// Url is the image URL returned by the model, when present.
Url string `json:"url,omitempty"`
// B64Json is the Base64 image payload returned by the model, when present.
B64Json string `json:"b64_json,omitempty"`
// Error is the failure message for an error entry.
Error string `json:"error,omitempty"`
}

// ImageEditToolChannelMessage is what the per-task goroutines send back to
// the tool handler over the result channel.
type ImageEditToolChannelMessage struct {
// Status is ImageEditSuccessStatus or ImageEditErrorStatus and selects
// which list of the final result receives Result.
Status string
// Result is the per-image outcome.
Result *ImageEditResult
}

// NewImageEditTool builds the "image_edit" function tool.
//
// A nil config is treated as an empty one. Every empty field of config is
// resolved through utils.GetEnvWithDefault from the matching MODEL_EDIT_*
// environment variable, the global configuration and the built-in default
// (API key: see resolveImageEditAPIKey). The resolved values are written back
// into the supplied config.
//
// The returned tool runs all requested edits concurrently, one goroutine per
// entry of the request's Params. Per-task failures are reported in the
// result's ErrorList rather than as a Go error; the overall Status is
// "success" as soon as at least one image was produced. Results arrive in
// completion order, not request order.
func NewImageEditTool(config *ImageEditConfig) (tool.Tool, error) {
	if config == nil {
		config = &ImageEditConfig{}
	}
	if config.ModelName == "" {
		config.ModelName = utils.GetEnvWithDefault(common.MODEL_EDIT_NAME, configs.GetGlobalConfig().Model.Edit.Name, common.DEFAULT_MODEL_EDIT_NAME)
	}
	if config.APIKey == "" {
		config.APIKey = resolveImageEditAPIKey()
	}
	if config.BaseURL == "" {
		config.BaseURL = utils.GetEnvWithDefault(common.MODEL_EDIT_API_BASE, configs.GetGlobalConfig().Model.Edit.ApiBase, common.DEFAULT_MODEL_EDIT_API_BASE)
	}

	log.Debug("Initializing image edit tool", "model", config.ModelName, "base_url", config.BaseURL)

	handler := func(ctx tool.Context, toolRequest ImageEditToolRequest) (*ImageEditToolResult, error) {
		client := arkruntime.NewClientWithApiKey(
			config.APIKey,
			arkruntime.WithBaseUrl(config.BaseURL),
		)

		result := &ImageEditToolResult{}
		var wg sync.WaitGroup
		ch := make(chan *ImageEditToolChannelMessage)
		for i, task := range toolRequest.Params {
			wg.Add(1)
			go func(index int, req EditImagesRequest) {
				// Deferred calls run last-in-first-out, so wg.Done executes
				// only after the recovery handler below has finished sending.
				// Calling Done before that send would let the closer goroutine
				// close ch first, turning the send into an unrecoverable
				// "send on closed channel" panic.
				defer wg.Done()
				defer func() {
					if r := recover(); r != nil {
						log.Error("Image edit task panic", "recover", r, "prompt", req.Prompt)
						ch <- &ImageEditToolChannelMessage{
							Status: ImageEditErrorStatus,
							// Keep the caller-supplied name so the failure can
							// be matched to its request, like every other
							// error path of this task.
							Result: newImageEditErrorResult(imageEditName(req, index), fmt.Sprintf("task panic: %v", r)),
						}
					}
				}()

				imageName := imageEditName(req, index)
				if req.Prompt == "" {
					ch <- &ImageEditToolChannelMessage{Status: ImageEditErrorStatus, Result: newImageEditErrorResult(imageName, "prompt is required")}
					return
				}
				if req.OriginImage == "" {
					ch <- &ImageEditToolChannelMessage{Status: ImageEditErrorStatus, Result: newImageEditErrorResult(imageName, "origin_image is required")}
					return
				}

				resp, err := client.GenerateImages(ctx, buildImageEditModelRequest(config.ModelName, req))
				if err != nil {
					log.Error("Failed to edit image", "error", err)
					ch <- &ImageEditToolChannelMessage{Status: ImageEditErrorStatus, Result: newImageEditErrorResult(imageName, err.Error())}
					return
				}
				if resp.Error != nil {
					ch <- &ImageEditToolChannelMessage{Status: ImageEditErrorStatus, Result: newImageEditErrorResult(imageName, resp.Error.Message)}
					return
				}

				for _, message := range imageEditResultsFromResponse(imageName, resp.Data) {
					ch <- message
				}
			}(i, task)
		}

		// Close the channel once every task has finished so the collection
		// loop below terminates.
		go func() {
			wg.Wait()
			close(ch)
		}()

		for res := range ch {
			switch res.Status {
			case ImageEditSuccessStatus:
				result.SuccessList = append(result.SuccessList, res.Result)
			case ImageEditErrorStatus:
				result.ErrorList = append(result.ErrorList, res.Result)
			}
		}

		if len(result.SuccessList) == 0 {
			result.Status = ImageEditErrorStatus
		} else {
			result.Status = ImageEditSuccessStatus
		}
		return result, nil
	}

	return functiontool.New(
		functiontool.Config{
			Name:        "image_edit",
			Description: imageEditToolDescription,
		},
		handler)
}

// buildImageEditModelRequest converts one edit task into the ark runtime
// image request for modelName. Unset optional fields fall back to the package
// defaults: response format "url", guidance scale 2.5, watermark enabled and
// seed -1. Optional values are copied into fresh locals, so the returned
// request never points into req.
func buildImageEditModelRequest(modelName string, req EditImagesRequest) model.GenerateImagesRequest {
	format := req.ResponseFormat
	if format == "" {
		format = defaultImageEditResponseFormat
	}

	scale := defaultImageEditGuidanceScale
	if custom := req.GuidanceScale; custom != nil {
		scale = *custom
	}

	addWatermark := true
	if custom := req.Watermark; custom != nil {
		addWatermark = *custom
	}

	seedValue := int64(defaultImageEditSeed)
	if custom := req.Seed; custom != nil {
		seedValue = *custom
	}

	out := model.GenerateImagesRequest{
		Model:  modelName,
		Prompt: req.Prompt,
		Image:  req.OriginImage,
	}
	out.ResponseFormat = &format
	out.GuidanceScale = &scale
	out.Watermark = &addWatermark
	out.Seed = &seedValue
	return out
}

// resolveImageEditAPIKey picks the API key for the image edit client. It
// tries, in order:
//  1. the edit-specific key (MODEL_EDIT_API_KEY / global Model.Edit.ApiKey),
//  2. the agent key (MODEL_AGENT_API_KEY / global Model.Agent.ApiKey),
//  3. an Ark token fetched through veauth for the default model region.
//
// The last step goes through utils.Must, so a failure there presumably
// panics rather than returning an empty key (confirm against utils.Must).
func resolveImageEditAPIKey() string {
	editKey := utils.GetEnvWithDefault(common.MODEL_EDIT_API_KEY, configs.GetGlobalConfig().Model.Edit.ApiKey)
	if editKey != "" {
		return editKey
	}

	agentKey := utils.GetEnvWithDefault(common.MODEL_AGENT_API_KEY, configs.GetGlobalConfig().Model.Agent.ApiKey)
	if agentKey != "" {
		return agentKey
	}

	return utils.Must(veauth.GetArkToken(common.DEFAULT_MODEL_REGION))
}

// imageEditResultsFromResponse turns the images of one model response into
// channel messages, one per image.
//
// An empty response yields a single "no images returned" error. When the
// response holds more than one image, each result name is imageName suffixed
// with "_<index>"; a single image keeps imageName unchanged. For every image a
// URL takes precedence over a Base64 payload; a nil entry or an entry with
// neither is reported as an error.
func imageEditResultsFromResponse(imageName string, images []*model.Image) []*ImageEditToolChannelMessage {
	failure := func(name, reason string) *ImageEditToolChannelMessage {
		return &ImageEditToolChannelMessage{
			Status: ImageEditErrorStatus,
			Result: newImageEditErrorResult(name, reason),
		}
	}
	success := func(res *ImageEditResult) *ImageEditToolChannelMessage {
		return &ImageEditToolChannelMessage{Status: ImageEditSuccessStatus, Result: res}
	}

	total := len(images)
	if total == 0 {
		return []*ImageEditToolChannelMessage{failure(imageName, "no images returned")}
	}

	out := make([]*ImageEditToolChannelMessage, 0, total)
	for idx, img := range images {
		name := imageName
		if total > 1 {
			name = fmt.Sprintf("%s_%d", imageName, idx)
		}

		if img == nil {
			out = append(out, failure(name, "empty image result"))
			continue
		}
		if img.Url != nil {
			out = append(out, success(&ImageEditResult{ImageName: name, Url: *img.Url}))
			continue
		}
		if img.B64Json != nil {
			out = append(out, success(&ImageEditResult{ImageName: name, B64Json: *img.B64Json}))
			continue
		}
		out = append(out, failure(name, "image url or b64_json is empty"))
	}
	return out
}

// imageEditName returns the name used to label the results of req: the
// caller-supplied ImageName when set, otherwise the default name derived from
// the task's index.
func imageEditName(req EditImagesRequest, index int) string {
	name := req.ImageName
	if name == "" {
		name = defaultImageEditName(index)
	}
	return name
}

// defaultImageEditName returns the fallback result name for the task at the
// given index, in the form "generated_image_<index>".
func defaultImageEditName(index int) string {
	const prefix = "generated_image_"
	return fmt.Sprintf("%s%d", prefix, index)
}

// newImageEditErrorResult builds an error entry for imageName carrying errMsg;
// the Url and B64Json fields are left empty.
func newImageEditErrorResult(imageName, errMsg string) *ImageEditResult {
	failed := new(ImageEditResult)
	failed.ImageName = imageName
	failed.Error = errMsg
	return failed
}
Loading
Loading