diff --git a/AGENTS.md b/AGENTS.md index 20d9c68d..fc7b21e0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,4 +1,4 @@ -This repo is a modern Telegram bot for CSUST built with Go 1.26+, featuring AI chat, message search (MeiliSearch), image generation (Stable Diffusion), gacha systems, and comprehensive permission controls. +This repo is a modern Telegram bot for CSUST built with Go 1.26+, featuring AI chat, message search (MeiliSearch), gacha systems, and comprehensive permission controls. ## Architecture Overview @@ -7,8 +7,8 @@ This repo is a modern Telegram bot for CSUST built with Go 1.26+, featuring AI c - **Bot Framework**: `gopkg.in/telebot.v3` - All commands registered via `bot.Handle()` - **Configuration**: `config.yaml` → structs in `config/` → global `config.BotConfig` - **Data Layer**: `orm/` - Redis-based persistence (NOT a SQL ORM); stores chat state, user lists, caches -- **Queue System**: `store/` - Background task processing (message deletion, SD generation) -- **Feature Packages**: `chat/`, `sd/`, `meili/`, `restrict/`, `base/`, `inline/` +- **Queue System**: `store/` - Background task processing (message deletion) +- **Feature Packages**: `chat/`, `meili/`, `restrict/`, `base/`, `inline/` ### Middleware Pipeline All requests flow through this ordered chain (see `main.go:116-119`): @@ -30,7 +30,6 @@ byeWorldMiddleware → mcMiddleware - `util.PrivateCommand(handler)` - Only in private chats - `util.GroupCommand(handler)` - Only in group chats - `util.GroupCommandCtx(handler)` - Group-only with context tracking -- `whiteMiddleware` - Whitelist enforcement (only for sensitive commands like `/sd`) ### Redis Key Patterns (orm/redis.go) - `wrapKey(key)` - Adds global prefix @@ -41,7 +40,7 @@ byeWorldMiddleware → mcMiddleware ### Async Task Queues (store/) - `TaskQueue[T]` interface: `Push()`, `Cancel()`, `fetch()`, `process()` - Example: `ByeWorldQueue` for delayed message deletion -- Background goroutines: `go sd.Process()`, `store.InitQueues(bot)` +- Background goroutines: `store.InitQueues(bot)` ### Chat Config System (config/chat.go) - Multi-model AI support with templates (Go `text/template`) @@ -174,12 +173,6 @@ The chat module is the core AI conversation system with MCP (Model Context Proto 2. Background: Queue processor pushes to MeiliSearch 3. Search: `/search [-id chatID] [-p page] keyword` with pagination -### Stable Diffusion (sd/) -1. Queue: `ch <- context` (buffered, size 10) -2. Worker: `go sd.Process()` consumes queue -3. HTTP/3: Custom `mixRoundTripper` tries QUIC first, falls back to TCP -4. Rate limiting: `busyUser` map tracks per-user concurrency - ## Pull Request Guidelines 1. **Base branch**: Always create PRs against `dev` (not `master`) 2. **Pre-commit**: Run `make build && make fmt && make test` diff --git a/README.md b/README.md index b1b8ad7c..d120aa47 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,6 @@ A modern Telegram bot for CSUST, developed in Go. - 🤖 AI Chat Conversations (supports multiple models) - 🔍 Message Search (powered by MeiliSearch) -- 🎨 Stable Diffusion Image Generation - 🎲 Gacha System - 🎭 Entertainment Features - 🔧 Flexible Configuration System @@ -153,16 +152,6 @@ gacha - Draw cards according to your configuration getvoice - character= gender= theme= type= ``` -### Stable Diffusion - -``` text -sd - Generate images -sdcfg - Configure SD server -sdcfg - set Set configuration -sdcfg - get Get configuration -sdlast - Get last used prompt -``` - ### Utility Functions ``` text @@ -185,7 +174,6 @@ setiwant - f= vf= sf= Set sticker format - **Database**: Redis - **Search**: MeiliSearch - **AI**: OpenAI API Compatible Interface -- **Image Generation**: Stable Diffusion WebUI - **Containerization**: Docker & Docker Compose ## Development diff --git a/README_zh-CN.md b/README_zh-CN.md index 57daee55..c73c5891 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -21,7 +21,6 @@ - 🤖 AI 聊天对话(支持多种模型) - 🔍 消息搜索(基于 MeiliSearch) -- 🎨 Stable Diffusion 图像生成 - 🎲 抽卡系统 - 🎭 各种娱乐功能 - 🔧 灵活的配置系统 @@ -153,16 +152,6 @@ gacha - 抽卡,按照你的配置 getvoice - 角色= 性别= 主题= 类型= ``` -### Stable Diffusion - -``` text -sd - 生成图片 -sdcfg - 配置SD服务器 -sdcfg - set 设置配置 -sdcfg - get 获取配置 -sdlast - 获取上次使用的prompt -``` - ### 工具功能 ``` text @@ -185,7 +174,6 @@ setiwant - f= vf= sf= 设置我要Sticker - **数据库**: Redis - **搜索**: MeiliSearch - **AI**: OpenAI API 兼容接口 -- **图像生成**: Stable Diffusion WebUI - **容器化**: Docker & Docker Compose ## 开发 diff --git a/go.mod b/go.mod index 8de27254..71c216a6 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/mark3labs/mcp-go v0.49.0 github.com/meilisearch/meilisearch-go v0.36.2 github.com/puzpuzpuz/xsync/v4 v4.5.0 - github.com/quic-go/quic-go v0.59.0 github.com/redis/go-redis/v9 v9.17.3 github.com/sashabaranov/go-openai v1.41.2 github.com/spf13/viper v1.21.0 @@ -49,6 +48,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nikolalohinski/gonja v1.5.3 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/swaggest/jsonschema-go v0.3.78 // indirect @@ -57,10 +57,11 @@ require ( github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yargevad/filepathx v1.0.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.uber.org/mock v0.5.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect - golang.org/x/net v0.47.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) @@ -73,7 +74,6 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/quic-go/qpack v0.6.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/samber/lo v1.53.0 github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect diff --git a/go.sum b/go.sum index 65ad58f4..525edb79 100644 --- a/go.sum +++ b/go.sum @@ -356,6 +356,7 @@ github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -450,10 +451,6 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/puzpuzpuz/xsync/v4 v4.5.0 h1:vOSWu6b57/emh+L/Cw0BeQfvxa/cogFywXHeGUxQxAg= github.com/puzpuzpuz/xsync/v4 v4.5.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo= -github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= -github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= -github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= -github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4= github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -684,8 +681,6 @@ golang.org/x/net v0.0.0-20220325170049-de3da57026de/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220412020605-290c469a71a5/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220520000938-2e3eb7b945c2/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= diff --git a/main.go b/main.go index a6e1bf69..e56aebb5 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,6 @@ import ( "csust-got/chatv2" "csust-got/inline" "csust-got/meili" - "csust-got/sd" "csust-got/store" "csust-got/util/gacha" "encoding/json" @@ -67,17 +66,12 @@ func main() { registerRestrictHandler(bot) registerEventHandler(bot) registerChatConfigHandler(bot) - bot.Handle("/sd", sd.Handler, whiteMiddleware) - bot.Handle("/sdcfg", sd.ConfigHandler) - bot.Handle("/sdlast", sd.LastPromptHandler) // inline mode inline.RegisterInlineHandler(bot, config.BotConfig) meili.InitMeili() - go sd.Process() - base.Init() store.InitQueues(bot) @@ -406,26 +400,6 @@ func rateMiddleware(next HandlerFunc) HandlerFunc { } } -func whiteMiddleware(next HandlerFunc) HandlerFunc { - return func(ctx Context) error { - if !config.BotConfig.WhiteListConfig.Enabled { - return next(ctx) - } - - m := ctx.Message() - // continue with inline query - if m == nil && ctx.Query() != nil { - return next(ctx) - } - - if ctx.Chat() != nil && !config.BotConfig.WhiteListConfig.Check(ctx.Chat().ID) { - log.Info("chat ignore by white list", zap.String("chat", ctx.Chat().Title)) - return nil - } - return next(ctx) - } -} - func noStickerMiddleware(next HandlerFunc) HandlerFunc { return func(ctx Context) error { m := ctx.Message() diff --git a/orm/redis.go b/orm/redis.go index 523966d8..93358784 100644 --- a/orm/redis.go +++ b/orm/redis.go @@ -534,60 +534,6 @@ func GetTargetState(target string) bool { return r == "1" } -// SetSDConfig set stable diffusion config. -func SetSDConfig(userID int64, cfg string) error { - err := rc.Set(context.TODO(), wrapKeyWithUser("stable_diffusion_config", userID), cfg, 0).Err() - if err != nil { - log.Error("set stable diffusion config to redis failed", zap.Int64("user", userID), zap.String("config", cfg), zap.Error(err)) - return err - } - return nil -} - -// GetSDConfig get stable diffusion config. -func GetSDConfig(userID int64) (string, error) { - cfg, err := rc.Get(context.TODO(), wrapKeyWithUser("stable_diffusion_config", userID)).Result() - if err != nil { - if !errors.Is(err, redis.Nil) { - log.Error("get stable diffusion config from redis failed", zap.Int64("user", userID), zap.Error(err)) - } - return "", err - } - return cfg, nil -} - -// SetSDLastPrompt save user's last stable diffusion prompt. -func SetSDLastPrompt(userID int64, lastPrompt string) error { - err := rc.Set(context.TODO(), wrapKeyWithUser("stable_diffusion_last_prompt", userID), lastPrompt, 0).Err() - if err != nil { - log.Error("set stable diffusion last prompt to redis failed", zap.Int64("user", userID), zap.String("lastPrompt", lastPrompt), zap.Error(err)) - return err - } - return nil -} - -// GetSDLastPrompt get user's last stable diffusion prompt. -func GetSDLastPrompt(userID int64) (string, error) { - lastPrompt, err := rc.Get(context.TODO(), wrapKeyWithUser("stable_diffusion_last_prompt", userID)).Result() - if err != nil { - if !errors.Is(err, redis.Nil) { - log.Error("get stable diffusion last prompt from redis failed", zap.Int64("user", userID), zap.Error(err)) - return "", err - } - return "", err - } - return lastPrompt, nil -} - -// GetSDDefaultServer get stable diffusion default server from redis. -func GetSDDefaultServer() string { - defaultServer, err := rc.Get(context.TODO(), wrapKey("stable_diffusion::default_server")).Result() - if err != nil { - return "" - } - return defaultServer -} - // SetChatContext save user's chat context with GPT to redis. func SetChatContext(chatID int64, msgID int, chatContext []openai.ChatCompletionMessage) error { if len(chatContext) == 0 { diff --git a/sd/cfg.go b/sd/cfg.go deleted file mode 100644 index 6dc2f703..00000000 --- a/sd/cfg.go +++ /dev/null @@ -1,368 +0,0 @@ -package sd - -import ( - "csust-got/entities" - "csust-got/orm" - "csust-got/util" - "encoding/json" - "errors" - "fmt" - "strconv" - "strings" - - "github.com/redis/go-redis/v9" - . "gopkg.in/telebot.v3" -) - -// StableDiffusionConfig is the config of stable diffusion. -type StableDiffusionConfig struct { - Server string `json:"server"` - Prompt string `json:"prompt"` - NegativePrompt string `json:"negative_prompt"` - Steps int `json:"steps"` - Scale int `json:"scale"` - Width int `json:"width"` - Height int `json:"height"` - Number int `json:"number"` - Sampler string `json:"sampler"` - - HiResEnabled string `json:"hr"` - DenoisingStrength float64 `json:"denoising_strength"` - HiResScale float64 `json:"hr_scale"` - HiResUpscaler string `json:"hr_upscaler"` - HiResSecondPassSteps int `json:"hr_second_pass_steps"` -} - -// GetValueByKey get value by key. -func (c *StableDiffusionConfig) GetValueByKey(key string) any { - switch key { - case "server": - return "🤫" - case "prompt": - if c.Prompt == "" { - return "masterpiece, best quality" - } - return c.Prompt - case "negative_prompt": - if c.NegativePrompt == "" { - return "nsfw, lowres, bad anatomy, bad hands, (((deformed))), [blurry], (poorly drawn hands), (poorly drawn feet), " + - "text, error, missing fingers, extra digit, " + - "fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" - } - return c.NegativePrompt - case "steps": - if c.Steps == 0 { - return 28 - } - return c.Steps - case "scale": - if c.Scale == 0 { - return 7 - } - return c.Scale - case "width": - if c.Width == 0 { - return 512 - } - return c.Width - case "height": - if c.Height == 0 { - return 512 - } - return c.Height - case "res": - return fmt.Sprintf("%dx%d", c.GetValueByKey("width"), c.GetValueByKey("height")) - case "number": - if c.Number == 0 { - return 1 - } - return c.Number - case "sampler": - if c.Sampler == "" { - return "Euler a" - } - return c.Sampler - case "hr": - if c.HiResEnabled == "" { - return "off" - } - return c.HiResEnabled - case "denoising_strength": - if c.DenoisingStrength == 0 { - return 0.6 - } - return c.DenoisingStrength - case "hr_scale": - if c.HiResScale == 0 { - return 2.0 - } - return c.HiResScale - case "hr_upscaler": - if c.HiResUpscaler == "" { - return "Latent" - } - return c.HiResUpscaler - case "hr_second_pass_steps": - if c.HiResSecondPassSteps == 0 { - return 20 - } - return c.HiResSecondPassSteps - default: - return "key not exists" - } -} - -// SetValueByKey set config value by key. -func (c *StableDiffusionConfig) SetValueByKey(key string, value string) error { - switch key { - case "server": - if value == "*" { - c.Server = "" - } else { - c.Server = strings.TrimSuffix(value, "/") - } - case "prompt": - c.Prompt = value - case "negative_prompt": - c.NegativePrompt = value - case "steps": - steps, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("%w: steps must be a integer", ErrConfigIsInvalid) - } - c.Steps = steps - if c.Steps < 1 || c.Steps > 50 { - return fmt.Errorf("%w: steps too small or too large", ErrConfigIsInvalid) - } - case "scale": - scale, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("%w: scale must be a integer", ErrConfigIsInvalid) - } - c.Scale = scale - if c.Scale < 1 || c.Scale > 20 { - return fmt.Errorf("%w: scale too small or too large", ErrConfigIsInvalid) - } - case "width": - width, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("%w: width must be a integer", ErrConfigIsInvalid) - } - c.Width = width / 64 * 64 - if c.Width < 1 || c.Width > 1024 { - return fmt.Errorf("%w: width too small or too large", ErrConfigIsInvalid) - } - case "height": - height, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("%w: height must be a integer", ErrConfigIsInvalid) - } - c.Height = height / 64 * 64 - if c.Height < 1 || c.Height > 1024 { - return fmt.Errorf("%w: height too small or too large", ErrConfigIsInvalid) - } - case "res": - res := strings.Split(value, "x") - if len(res) != 2 { - return fmt.Errorf("%w: invalid resolution", ErrConfigIsInvalid) - } - if err := c.SetValueByKey("width", res[0]); err != nil { - return err - } - if err := c.SetValueByKey("height", res[1]); err != nil { - return err - } - case "number": - number, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("%w: number must be a integer", ErrConfigIsInvalid) - } - c.Number = number - if c.Number < 1 || c.Number > 4 { - return fmt.Errorf("%w: number too small or too large", ErrConfigIsInvalid) - } - case "sampler": - if value == "*" { - value = "Euler a" - } - c.Sampler = value - case "hr": - if value == "on" { - c.HiResEnabled = "on" - } else { - c.HiResEnabled = "off" - } - case "denoising_strength": - denoisingStrength, err := strconv.ParseFloat(value, 64) - if err != nil { - return fmt.Errorf("%w: denoising_strength must be a float", ErrConfigIsInvalid) - } - c.DenoisingStrength = denoisingStrength - if c.DenoisingStrength < 0 || c.DenoisingStrength > 1 { - return fmt.Errorf("%w: denoising_strength must be between 0 and 1", ErrConfigIsInvalid) - } - case "hr_scale": - hrScale, err := strconv.ParseFloat(value, 64) - if err != nil { - return fmt.Errorf("%w: hr_scale must be a float", ErrConfigIsInvalid) - } - c.HiResScale = hrScale - if c.HiResScale < 1 || c.HiResScale > 4 { - return fmt.Errorf("%w: hr_scale must be between 1 and 4", ErrConfigIsInvalid) - } - case "hr_upscaler": - if value == "*" { - value = "Latent" - } - c.HiResUpscaler = value - case "hr_second_pass_steps": - hrSecondPassSteps, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("%w: hr_second_pass_steps must be a integer", ErrConfigIsInvalid) - } - c.HiResSecondPassSteps = hrSecondPassSteps - if c.HiResSecondPassSteps < 0 || c.HiResSecondPassSteps > 50 { - return fmt.Errorf("%w: hr_second_pass_steps too small or too large", ErrConfigIsInvalid) - } - default: - return fmt.Errorf("%w: invalid key: %s", ErrConfigIsInvalid, key) - } - return nil -} - -// GetServer return server. -func (c *StableDiffusionConfig) GetServer() string { - server := c.Server - if server == "" { - server = orm.GetSDDefaultServer() - } - server = strings.TrimSuffix(server, "/") - return server -} - -// GenStableDiffusionRequest generate stable diffusion request by config. -func (c *StableDiffusionConfig) GenStableDiffusionRequest() *StableDiffusionReq { - req := &StableDiffusionReq{ - Prompt: c.GetValueByKey("prompt").(string), - NegativePrompt: c.GetValueByKey("negative_prompt").(string), - Steps: c.GetValueByKey("steps").(int), - CfgScale: c.GetValueByKey("scale").(int), - Width: c.GetValueByKey("width").(int), - Height: c.GetValueByKey("height").(int), - BatchSize: c.GetValueByKey("number").(int), - SamplerIndex: c.GetValueByKey("sampler").(string), - } - if c.GetValueByKey("hr").(string) == "on" { - req.HiResEnabled = true - req.DenoisingStrength = c.GetValueByKey("denoising_strength").(float64) - req.HiResScale = c.GetValueByKey("hr_scale").(float64) - req.HiResUpscaler = c.GetValueByKey("hr_upscaler").(string) - req.HiResSecondPassSteps = c.GetValueByKey("hr_second_pass_steps").(int) - req.BatchSize = 1 - } - return req -} - -const helpInfo = "sdcfg set \\ \\\n" + - "sdcfg get \\\n" + - "available keys: \n" + - "`server`: your own stable diffusion server address\\(write only\\)\\.\n" + - "`prompt`: your default prompt, will add to your every command call\\.\n" + - "`negative_prompt`: your default negative prompt, will add to your every command call\\.\n" + - "`steps`: steps for stable diffusion\\.\n" + - "`scale`: scale for stable diffusion\\.\n" + - "`res`: resolution __width__x__height__\\.\n" + - "`number`: number of images for once command call\\.\n" + - "`sampler`: sampler for stable diffusion, default is `Euler a`\\.\n" + - "`hr`: high resolution fix `on`/`off`, will force `number` to 1\\.\n" + - "`denoising_strength`: denoising strength for high resolution\\.\n" + - "`hr_scale`: high resolution scale\\.\n" + - "`hr_upscaler`: high resolution upscaler, default is `Latent`\\.\n" + - "`hr_second_pass_steps`: high resolution fix steps\\." - -const ( - sdSubCmdSet = "set" - sdSubCmdGet = "get" -) - -// ConfigHandler handle /sdcfg command. -func ConfigHandler(ctx Context) error { - command := entities.FromMessage(ctx.Message()) - - if command.Argc() == 0 { - _, err := util.ReplyWithError(ctx, util.RawTgText(helpInfo), ModeMarkdownV2) - return err - } - - userID := ctx.Sender().ID - config, err := getConfigByUserID(userID) - if err != nil { - return ctx.Reply("完了,删库跑路了") - } - - var mode, key, value string - switch command.Arg(0) { - case sdSubCmdSet: - if command.Argc() < 3 { - _, err := util.ReplyWithError(ctx, util.RawTgText(helpInfo), ModeMarkdownV2) - return err - } - mode = sdSubCmdSet - key = command.Arg(1) - value = command.ArgAllInOneFrom(2) - case sdSubCmdGet: - if command.Argc() < 2 { - _, err := util.ReplyWithError(ctx, util.RawTgText(helpInfo), ModeMarkdownV2) - return err - } - mode = sdSubCmdGet - key = command.Arg(1) - default: - if command.Argc() == 1 { - mode = sdSubCmdGet - key = command.Arg(0) - } else { - mode = sdSubCmdSet - key = command.Arg(0) - value = command.ArgAllInOneFrom(1) - } - } - - switch mode { - case sdSubCmdSet: - err = config.SetValueByKey(key, value) - if err != nil { - return ctx.Reply(err.Error()) - } - configStr, err := json.MarshalIndent(&config, "", "") - if err != nil { - return ctx.Reply("感觉有点问题") - } - err = orm.SetSDConfig(userID, string(configStr)) - if err != nil { - return ctx.Reply("完了,删库跑路了") - } - return ctx.Reply("配置保存成功") - case sdSubCmdGet: - _, err := util.ReplyWithError(ctx, util.RawTgText(fmt.Sprintf("`%s`", util.EscapeTgMDv2ReservedChars(fmt.Sprint(config.GetValueByKey(key))))), ModeMarkdownV2) - return err - } - - _, err = util.ReplyWithError(ctx, util.RawTgText(helpInfo), ModeMarkdownV2) - return err -} - -func getConfigByUserID(userID int64) (*StableDiffusionConfig, error) { - config := &StableDiffusionConfig{} - configStr, err := orm.GetSDConfig(userID) - if err != nil && !errors.Is(err, redis.Nil) { - return config, err - } - if err == nil { - err = json.Unmarshal([]byte(configStr), &config) - if err != nil { - return config, err - } - } - return config, nil -} diff --git a/sd/context.go b/sd/context.go deleted file mode 100644 index e57ddabc..00000000 --- a/sd/context.go +++ /dev/null @@ -1,10 +0,0 @@ -package sd - -import . "gopkg.in/telebot.v3" - -// StableDiffusionContext is the context of stable diffusion worker. -type StableDiffusionContext struct { - BotContext Context - UserConfig StableDiffusionConfig - Request StableDiffusionReq -} diff --git a/sd/err.go b/sd/err.go deleted file mode 100644 index 9e7864ff..00000000 --- a/sd/err.go +++ /dev/null @@ -1,12 +0,0 @@ -package sd - -import "errors" - -// sd error -var ( - ErrServerNotConfigured = errors.New("server not configured") - ErrServerNotAvailable = errors.New("server not available") - ErrConfigKeyNotSupport = errors.New("config key not support") - ErrConfigIsInvalid = errors.New("config is invalid") - ErrRequestNotOK = errors.New("request not ok") -) diff --git a/sd/sd.go b/sd/sd.go deleted file mode 100644 index 4abd5be8..00000000 --- a/sd/sd.go +++ /dev/null @@ -1,433 +0,0 @@ -package sd - -import ( - "bytes" - "context" - "csust-got/entities" - "csust-got/log" - "csust-got/orm" - "csust-got/util" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "strings" - "sync" - "time" - - "github.com/quic-go/quic-go" - - "github.com/quic-go/quic-go/http3" - "go.uber.org/zap" - . "gopkg.in/telebot.v3" -) - -var ( - mu sync.Mutex - ch = make(chan *StableDiffusionContext, 10) - busyUser = make(map[int64]int) -) - -var httpClient *http.Client - -type mixRoundTripper struct { - TraditionalRoundTripper http.RoundTripper - H3RoundTripper http.RoundTripper -} - -func newMixRoundTripper(t http.RoundTripper, h3 *http3.Transport) *mixRoundTripper { - return &mixRoundTripper{ - TraditionalRoundTripper: t, - H3RoundTripper: h3, - } -} - -func (r *mixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if r.H3RoundTripper != nil { - log.Debug("try h3", zap.String("url", util.ReplaceSpace(req.URL.String()))) - resp, err := r.H3RoundTripper.RoundTrip(req) - if err == nil { - return resp, nil - } - log.Debug("h3 failed", zap.Error(err)) - } - return r.TraditionalRoundTripper.RoundTrip(req) -} - -func init() { - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.IdleConnTimeout = 3 * time.Minute - transport.ResponseHeaderTimeout = 3 * time.Minute - - dialer := net.Dialer{ - KeepAlive: 10 * time.Second, - } - transport.DialContext = dialer.DialContext - - h3RoundTripper := &http3.Transport{ - QUICConfig: &quic.Config{ - MaxIdleTimeout: 3 * time.Minute, - KeepAlivePeriod: 10 * time.Second, - }, - } - - httpClient = &http.Client{ - Transport: newMixRoundTripper(transport, h3RoundTripper), - } -} - -// Handler stable diffusion handler. -func Handler(ctx Context) error { - if !mu.TryLock() { - return ctx.Reply("忙不过来了") - } - defer mu.Unlock() - - command := entities.FromMessage(ctx.Message()) - - userID := ctx.Sender().ID - config, err := getConfigByUserID(userID) - if err != nil { - return ctx.Reply("完了,删库跑路了") - } - - if config.GetServer() == "" { - return ctx.Reply("喂喂喂,你还没有配置服务器好吧。" + - "快使用 /sdcfg 配置一个属于自己的服务器,或者找好心人捐赠一个服务器吧") - } - - prompt := command.ArgAllInOneFrom(0) - prompt = strings.ReplaceAll(prompt, ",", ",") - if prompt == "" { - prompt, _ = orm.GetSDLastPrompt(userID) - } else { - _ = orm.SetSDLastPrompt(userID, prompt) - } - - req := config.GenStableDiffusionRequest() - req.Prompt += ", " + prompt - - if busyUser[userID] >= 3 { - return ctx.Reply("听我说你先别急,你还有3个没画完") - } - - select { - case ch <- &StableDiffusionContext{ - BotContext: ctx, - UserConfig: *config, - Request: *req, - }: - busyUser[userID]++ - msg := "在画了在画了" - if req.HiResEnabled { - msg += ",高清修复已开启,可能会比较慢,耐心等待一下~" - } - return ctx.Reply(msg) - default: - return ctx.Reply("忙不过来了") - } - -} - -// Process is the stable diffusion background worker. -func Process() { - lock := new(sync.Mutex) - inUsedServer := make(map[string]chan *StableDiffusionContext) - maxWorker := make(chan struct{}, 10) - - for ctx := range ch { - select { - case maxWorker <- struct{}{}: - // Do nothing - default: - err := ctx.BotContext.Reply("任务堆积太多,忙不过来了。") - if err != nil { - log.Error("reply error", zap.Error(err)) - } - continue - } - - server := ctx.UserConfig.GetServer() - - lock.Lock() - serverCh, ok := inUsedServer[server] - if !ok { - serverCh = make(chan *StableDiffusionContext, 10) - inUsedServer[ctx.UserConfig.GetServer()] = serverCh - } - lock.Unlock() - serverCh <- ctx - - processFn := func() { - for { - select { - case ctx := <-serverCh: - func() { - ctx := ctx - defer func() { - <-maxWorker - - mu.Lock() - busyUser[ctx.BotContext.Sender().ID]-- - mu.Unlock() - }() - resp, err := requestStableDiffusion(ctx.UserConfig.GetServer(), &ctx.Request) - if err != nil { - err = ctx.BotContext.Reply("寄了") - if err != nil { - log.Error("reply stable diffusion failed", zap.Error(err)) - } - return - } - - photos := Album{} - for _, v := range resp.Images { - var data []byte - data, err = base64.StdEncoding.DecodeString(v) - if err != nil { - log.Error("decode stable diffusion image failed", zap.Error(err)) - continue - } - photos = append(photos, &Photo{ - File: File{FileReader: bytes.NewReader(data)}, - }) - } - - err = ctx.BotContext.SendAlbum(photos) - if err != nil { - log.Error("send stable diffusion album failed", zap.Error(err)) - err = ctx.BotContext.Reply("非常的寄") - if err != nil { - log.Error("reply stable diffusion failed", zap.Error(err)) - } - return - } - }() - default: - lock.Lock() - if len(serverCh) == 0 { - delete(inUsedServer, server) - lock.Unlock() - return - } - lock.Unlock() - } - } - } - - if !ok { - go processFn() - } - } - -} - -/* - { - "enable_hr": false, - "denoising_strength": 0, - "firstphase_width": 0, - "firstphase_height": 0, - "prompt": "", - "styles": [ - "string" - ], - "seed": -1, - "subseed": -1, - "subseed_strength": 0, - "seed_resize_from_h": -1, - "seed_resize_from_w": -1, - "batch_size": 1, - "n_iter": 1, - "steps": 50, - "cfg_scale": 7, - "width": 512, - "height": 512, - "restore_faces": false, - "tiling": false, - "negative_prompt": "string", - "eta": 0, - "s_churn": 0, - "s_tmax": 0, - "s_tmin": 0, - "s_noise": 1, - "override_settings": {}, - "sampler_index": "Euler" - } -*/ - -// StableDiffusionReq is the request body of stable diffusion. -type StableDiffusionReq struct { - Prompt string `json:"prompt"` - NegativePrompt string `json:"negative_prompt"` - Steps int `json:"steps"` - CfgScale int `json:"cfg_scale"` - Width int `json:"width"` - Height int `json:"height"` - BatchSize int `json:"batch_size"` - SamplerIndex string `json:"sampler_index"` - - HiResEnabled bool `json:"enable_hr"` - DenoisingStrength float64 `json:"denoising_strength"` - HiResScale float64 `json:"hr_scale"` - HiResUpscaler string `json:"hr_upscaler"` - HiResSecondPassSteps int `json:"hr_second_pass_steps"` -} - -/* -{ - "images": [ - ], - "parameters": { - "enable_hr": false, - "denoising_strength": 0, - "firstphase_width": 0, - "firstphase_height": 0, - "prompt": "girl", - "styles": null, - "seed": -1, - "subseed": -1, - "subseed_strength": 0, - "seed_resize_from_h": -1, - "seed_resize_from_w": -1, - "batch_size": 1, - "n_iter": 1, - "steps": 50, - "cfg_scale": 7, - "width": 512, - "height": 512, - "restore_faces": false, - "tiling": false, - "negative_prompt": null, - "eta": null, - "s_churn": 0, - "s_tmax": null, - "s_tmin": 0, - "s_noise": 1, - "override_settings": null, - "sampler_index": "Euler" - }, - "info": { - "prompt": "girl", - "all_prompts": [ - "girl" - ], - "negative_prompt": "", - "seed": 327883780, - "all_seeds": [ - 327883780 - ], - "subseed": 887306102, - "all_subseeds": [ - 887306102 - ], - "subseed_strength": 0, - "width": 512, - "height": 512, - "sampler_index": 1, - "sampler": "Euler", - "cfg_scale": 7, - "steps": 50, - "batch_size": 1, - "restore_faces": false, - "face_restoration_model": null, - "sd_model_hash": "e6e8e1fc", - "seed_resize_from_w": -1, - "seed_resize_from_h": -1, - "denoising_strength": 0, - "extra_generation_params": {}, - "index_of_first_image": 0, - "infotexts": [ - "girl\nSteps: 50, Sampler: Euler, CFG scale: 7.0, Seed: 327883780, Size: 512x512, Model hash: e6e8e1fc, - Seed resize from: -1x-1, Denoising strength: 0, Clip skip: 2" - ], - "styles": [], - "job_timestamp": "0", - "clip_skip": 2 - } -} -*/ - -// StableDiffusionResp is the response of stable diffusion -type StableDiffusionResp struct { - Images []string `json:"images"` -} - -func requestStableDiffusion(addr string, req *StableDiffusionReq) (*StableDiffusionResp, error) { - if addr == "" { - return nil, ErrServerNotConfigured - } - - bs, err := json.Marshal(req) - if err != nil { - log.Error("marshal stable diffusion request failed", zap.Error(err)) - return nil, err - } - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) - defer cancel() - httpReq, err := http.NewRequest("POST", addr+"/sdapi/v1/txt2img", bytes.NewReader(bs)) - if err != nil { - log.Error("create stable diffusion request failed", zap.Error(err)) - return nil, err - } - httpReq = httpReq.WithContext(ctx) - httpReq.Header.Set("Content-Type", "application/json") - // httpReq.Header.Set("Expect", "100-continue") - - resp, err := httpClient.Do(httpReq) - if err != nil { - log.Error("request stable diffusion failed", zap.Error(err)) - return nil, fmt.Errorf("request stable diffusion failed: %w", ErrServerNotAvailable) - } - defer func() { _ = resp.Body.Close() }() - - bts, err := io.ReadAll(resp.Body) - if err != nil { - log.Error("read stable diffusion response body failed", zap.Error(err)) - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Error("stable diffusion response status code is not 200", - zap.Int("status code", resp.StatusCode), zap.String("response body", string(bts))) - return nil, fmt.Errorf("%w: request stable diffusion failed, status code: %d, response: %s", - ErrRequestNotOK, resp.StatusCode, string(bts)) - } - - var respData StableDiffusionResp - err = json.Unmarshal(bts, &respData) - if err != nil { - log.Error("unmarshal stable diffusion response failed", zap.Error(err)) - return nil, err - } - - return &respData, nil -} - -// nolint: unused // for some reason -func joinApi(baseUrl, path string) string { - if baseUrl == "" { - return "" - } - baseUrl = strings.TrimSuffix(baseUrl, "/") - return baseUrl + path -} - -// LastPromptHandler is the handler of last prompt. -func LastPromptHandler(ctx Context) error { - prompt, err := orm.GetSDLastPrompt(ctx.Message().Sender.ID) - if err != nil { - log.Error("get last prompt failed", zap.Error(err)) - return ctx.Reply("Maybe I forgot what you said last time.") - } - - if prompt == "" { - return ctx.Reply("You haven't used stable diffusion yet.") - } - - _, err = util.ReplyWithError(ctx, util.RawTgText("Your last prompt is:\n`"+util.EscapeTgMDv2ReservedChars(prompt)+"`"), ModeMarkdownV2) - return err -}