-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathai_studio_code.cpp
More file actions
236 lines (195 loc) · 8.74 KB
/
ai_studio_code.cpp
File metadata and controls
236 lines (195 loc) · 8.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// Third-party libraries
#include "httplib.h"
#include "nlohmann/json.hpp" // Required for JSON handling
#include "stable-diffusion.h" // stablediffusion.cpp header
// For Base64 encoding (helper function below)
#include "base64.h" // You'll need a base64 implementation, or see helper below
// Shorthand for the nlohmann JSON value type used by the HTTP handlers.
using json = nlohmann::json;

// Process-wide Stable Diffusion context. Stays null until main() loads the
// model; run_inference() passes it to sd_txt2img, and main() releases it with
// free_sd_ctx() once the server stops listening.
sd_ctx_t *sd_ctx{nullptr};
// --- Helper: Simple Base64 Encoder (if you don't have a header) ---
// Standard Base64 alphabet (RFC 4648 section 4); the index is the 6-bit value.
static const std::string base64_chars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";

/// @brief Encode a byte buffer as padded, standard-alphabet Base64.
/// @param bytes_to_encode Input bytes; may be null only when in_len is 0.
/// @param in_len Number of input bytes. Taken as std::size_t so that lengths
///        coming from container .size() are never truncated (callers passing
///        an unsigned int still convert implicitly).
/// @return Encoded text of length 4 * ceil(in_len / 3), '='-padded.
std::string base64_encode(const unsigned char* bytes_to_encode, std::size_t in_len) {
    std::string ret;
    ret.reserve(((in_len + 2) / 3) * 4);  // exact output size, one allocation

    std::size_t pos = 0;
    // Each full 3-byte group becomes one 24-bit value -> four 6-bit symbols.
    for (; pos + 3 <= in_len; pos += 3) {
        const std::uint32_t group =
            (static_cast<std::uint32_t>(bytes_to_encode[pos]) << 16) |
            (static_cast<std::uint32_t>(bytes_to_encode[pos + 1]) << 8) |
            static_cast<std::uint32_t>(bytes_to_encode[pos + 2]);
        ret += base64_chars[(group >> 18) & 0x3f];
        ret += base64_chars[(group >> 12) & 0x3f];
        ret += base64_chars[(group >> 6) & 0x3f];
        ret += base64_chars[group & 0x3f];
    }

    // 1 or 2 trailing bytes: zero-fill the missing bytes and pad with '='.
    const std::size_t rest = in_len - pos;
    if (rest > 0) {
        std::uint32_t group = static_cast<std::uint32_t>(bytes_to_encode[pos]) << 16;
        if (rest == 2) {
            group |= static_cast<std::uint32_t>(bytes_to_encode[pos + 1]) << 8;
        }
        ret += base64_chars[(group >> 18) & 0x3f];
        ret += base64_chars[(group >> 12) & 0x3f];
        ret += (rest == 2) ? base64_chars[(group >> 6) & 0x3f] : '=';
        ret += '=';
    }
    return ret;
}
// --- Image Generation Logic ---
/// One generated image, holding an owned copy of the pixels.
/// `data` holds width * height * channels bytes copied from the library's
/// output buffer, one byte per channel sample. NOTE(review): the exact pixel
/// layout (presumably row-major, interleaved) is defined by stable-diffusion.h;
/// confirm it there.
/// All scalar members default to 0, so a default-constructed image is a
/// well-defined "empty" image rather than holding indeterminate values.
struct GeneratedImage {
    int width{0};               ///< Width in pixels.
    int height{0};              ///< Height in pixels.
    int channels{0};            ///< Samples per pixel (e.g. 3 for RGB).
    std::vector<uint8_t> data;  ///< Raw pixel bytes.
};
// Wrapper for stablediffusion.cpp inference
// Note: You must adapt this to the specific version of sd_txt2img you are using.
std::vector<GeneratedImage> run_inference(const std::string& prompt, int steps, int width, int height, int seed) {
std::vector<GeneratedImage> results;
// Default parameters mimics standard SD usage
float cfg_scale = 7.0f;
sd_image_t* result_images = nullptr;
// Call the library function
// Assuming sd_txt2img returns a generic C-style array or struct depending on version
// This signature assumes the C-style API often found in wrappers
result_images = sd_txt2img(sd_ctx, prompt.c_str(), "",
0, // clip_skip
cfg_scale,
width, height,
EULER_A, // sample_method
steps,
seed,
1, // batch_count
nullptr, // control_cond
0.0f, // control_strength
0.0f // style_strength
);
if (result_images && result_images->data) {
GeneratedImage img;
img.width = result_images->width;
img.height = result_images->height;
img.channels = result_images->channel;
// Copy data
size_t data_size = img.width * img.height * img.channels;
img.data.resize(data_size);
std::memcpy(img.data.data(), result_images->data, data_size);
results.push_back(img);
// Free memory allocated by the library
// free(result_images->data); // Depending on library ownership rules
}
return results;
}
// Helper to convert raw RGB to PNG memory buffer
// You would typically use stbi_write_png_to_func here
/// @brief Write callback in the shape expected by stbi_write_png_to_func:
///        appends `size` bytes from `data` to the vector behind `context`.
/// @param context Must point to a std::vector<uint8_t> that receives the bytes.
/// @param data    Bytes to append.
/// @param size    Byte count; calls with a null pointer or size <= 0 are
///                ignored (a negative size would form a reversed range).
void write_func(void *context, void *data, int size) {
    if (context == nullptr || data == nullptr || size <= 0) {
        return;
    }
    auto *buffer = static_cast<std::vector<uint8_t> *>(context);
    const auto *bytes = static_cast<const uint8_t *>(data);
    buffer->insert(buffer->end(), bytes, bytes + size);
}
// --- Main Application ---
int main(int argc, char** argv) {
if (argc < 2) {
std::cerr << "Usage: ./sd_server <model_path>" << std::endl;
return 1;
}
std::string model_path = argv[1];
// 1. Initialize Stable Diffusion Model
std::cout << "[INFO] Loading model: " << model_path << "..." << std::endl;
// n_threads = -1 (auto), wtype = default, rng_type = standard
sd_ctx = new_sd_ctx(model_path.c_str(), "", "", "", "", "", "", false, false, false, -1, SD_TYPE_F16, STD_DEFAULT_RNG, SCHEDULE_DEFAULT, false, false, false);
if (!sd_ctx) {
std::cerr << "[ERROR] Failed to load model." << std::endl;
return 1;
}
std::cout << "[INFO] Model loaded successfully." << std::endl;
// 2. Setup HTTP Server
httplib::Server svr;
// CORS Headers (Optional but useful for web frontends)
svr.set_default_headers({
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Methods", "POST, GET, OPTIONS"},
{"Access-Control-Allow-Headers", "Content-Type"}
});
// Handle OPTIONS for CORS pre-flight
svr.Options("/v1/images/generations", [](const httplib::Request&, httplib::Response& res) {
res.status = 204;
});
// 3. Define the OpenAI Compatible Endpoint
svr.Post("/v1/images/generations", [](const httplib::Request& req, httplib::Response& res) {
std::cout << "[INFO] Received request." << std::endl;
// Parse JSON Body
json body;
try {
body = json::parse(req.body);
} catch (...) {
res.status = 400;
res.set_content(R"({"error": {"message": "Invalid JSON"}})", "application/json");
return;
}
// Extract Parameters (with defaults)
std::string prompt = body.value("prompt", "A cute cat");
std::string size_str = body.value("size", "512x512");
int steps = body.value("steps", 20); // Not standard OpenAI, but useful to accept
int n = body.value("n", 1); // Number of images
std::string response_format = body.value("response_format", "url"); // "url" or "b64_json"
// Parse size string "WxH"
int width = 512, height = 512;
size_t x_pos = size_str.find('x');
if (x_pos != std::string::npos) {
width = std::stoi(size_str.substr(0, x_pos));
height = std::stoi(size_str.substr(x_pos + 1));
}
// Generate Image
// In a real server, you should handle thread safety here if `sd_ctx` isn't thread-safe.
int seed = std::time(nullptr);
auto images = run_inference(prompt, steps, width, height, seed);
if (images.empty()) {
res.status = 500;
res.set_content(R"({"error": {"message": "Generation failed"}})", "application/json");
return;
}
// Prepare Response JSON
json response;
response["created"] = std::time(nullptr);
response["data"] = json::array();
for (const auto& img : images) {
// Convert Raw RGB to PNG (in memory)
std::vector<uint8_t> png_buffer;
// Assuming 3 channels (RGB). If 4 (RGBA), set 4.
// stride = width * channels
// Using stbi_write_png_to_func (requires stbi_image_write implementation)
// For this snippet, we assume a wrapper exists or you implement raw-to-png.
// stbi_write_png_to_func(write_func, &png_buffer, img.width, img.height, img.channels, img.data.data(), img.width * img.channels);
// Note: Since stbi logic is lengthy to include fully, let's assume png_buffer is filled.
// Placeholder: Filling with dummy data if stbi is missing in context
if (png_buffer.empty()) png_buffer = img.data;
// Encode to Base64
std::string b64_data = base64_encode(png_buffer.data(), png_buffer.size());
json item;
if (response_format == "b64_json") {
item["b64_json"] = b64_data;
} else {
// OpenAI returns URLs usually. For a local mimicker, returning b64 in the URL field
// as a data URI is a common hack, or just return the base64 anyway.
item["url"] = "data:image/png;base64," + b64_data;
}
response["data"].push_back(item);
}
res.set_content(response.dump(), "application/json");
std::cout << "[INFO] Request processed." << std::endl;
});
std::cout << "[INFO] Server listening on http://0.0.0.0:8080" << std::endl;
svr.listen("0.0.0.0", 8080);
// Cleanup
free_sd_ctx(sd_ctx);
return 0;
}