diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..658f9413 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +target +**/target +.git +**/.git +.DS_Store +**/.DS_Store diff --git a/.gitignore b/.gitignore index b6c0d3b1..d0da7b4d 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,6 @@ util/gh-pages/lints.json # dev script devsh/* -.uuid \ No newline at end of file +.uuid +scripts +mix \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 5a0070dc..a9ba6181 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,9 +2,11 @@ members = [ "binary/spacegate", "binary/admin-server", + "binary/ai-gateway-service", "crates/extension/*", "crates/kernel", "crates/plugin", + "crates/plugin-wasm", "crates/model", "crates/config", "crates/shell", @@ -12,6 +14,12 @@ members = [ "examples/socks5-proxy", "examples/mitm-proxy", ] +exclude = [ + # proxy-wasm-rust-sdk guest 用作 host 集成测试参考插件,目标三元组 wasm32-wasip1,不属于主 workspace。 + "crates/plugin-wasm/tests/spec_test_guest", + "crates/plugin-wasm/tests/sdk_examples_guest", + "crates/plugin-wasm/tests/on_tick_guest", +] resolver = "2" [profile.release] codegen-units = 1 @@ -46,6 +54,7 @@ spacegate-plugin = { version = "0.2.0-alpha.4", path = "./crates/plugin" } spacegate-config = { version = "0.2.0-alpha.4", path = "./crates/config" } spacegate-model = { version = "0.2.0-alpha.4", path = "./crates/model" } spacegate-shell = { version = "0.2.0-alpha.4", path = "./crates/shell" } +spacegate-plugin-wasm = { version = "0.2.0-alpha.4", path = "./crates/plugin-wasm" } spacegate-ext-axum = { version = "0.2.0-alpha.4", path = "./crates/extension/axum" } spacegate-ext-redis = { version = "0.2.0-alpha.4", path = "./crates/extension/redis" } @@ -59,6 +68,12 @@ toml = { version = "0.8", features = ["preserve_order"] } lazy_static = { version = "1.4" } tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = { version = "0" } +tracing-opentelemetry = { version = "0.33" } +opentelemetry = { version = "0.32" } +opentelemetry_sdk = { version = "0.32" } +opentelemetry-otlp = { version = "0.32", features = ["grpc-tonic", "http-proto", "reqwest-blocking-client"] } +opentelemetry-appender-tracing = { version = "0.32" } +opentelemetry-semantic-conventions = { version = "0.32" } # Encode base64 = { version = "0.22" } diff --git a/binary/ai-gateway-service/Cargo.toml b/binary/ai-gateway-service/Cargo.toml new file mode 100644 index 00000000..6e3e0f97 --- /dev/null +++ b/binary/ai-gateway-service/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "ai-gateway-service" +version.workspace = true +authors.workspace = true +description = "External rate-limit and queue service for SpaceGate AI gateway wasm plugins" +keywords.workspace = true +categories.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +license.workspace = true +edition.workspace = true +readme = "../../README.md" + +[dependencies] +axum = { workspace = true, features = ["tracing", "macros"] } +base64 = { workspace = true } +bytes = { workspace = true } +clap = { version = "4.5", features = ["derive", "env"] } +futures-util = { workspace = true } +fred = { version = "10.1.0", default-features = false, features = ["default-nil-types", "enable-rustls", "i-keys", "i-scripts", "i-streams", "subscriber-client", "transactions"] } +http = "1" +reqwest = { workspace = true, features = ["json"] } +schemars = "0.8" +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +toml = { workspace = true } +tokio = { workspace = true, features = ["full"] } +tower-http = { version = "0.6", features = ["cors", "trace"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +ulid = "1.1" + +[features] +test-support = [] + +[lib] +name = "ai_gateway_service" +path = "src/lib.rs" + +[[test]] +name = "integration" +path = "tests/integration/mod.rs" +required-features = ["test-support"] + +[dev-dependencies] +tempfile = "3" +serde_json = { workspace = true } +reqwest = { workspace = true, features = ["json"] } diff --git a/binary/ai-gateway-service/README.md b/binary/ai-gateway-service/README.md new file mode 100644 index 00000000..16c7b664 --- /dev/null +++ b/binary/ai-gateway-service/README.md @@ -0,0 +1,223 @@ +# AI Gateway Service + +External Redis-backed service used by the `ai-gateway-queue` Proxy-Wasm plugin. + +It keeps Redis, worker execution, Pub/Sub waiting, callback delivery, and result storage outside the wasm sandbox. + +## Endpoints + +- `POST /v1/ratelimit/check` + - Reads `X-Tenant-Id`, optional `X-Model`, `X-Original-Path`, and `X-RateLimit-Policy`. + - Runs a Redis Lua token bucket keyed by **tenant only** (`ai:ratelimit:{tenant}:tokens/ts`). + - Per-tenant overrides via Admin API or Redis keys under `ai:tenant:ratelimit:{tenant}[:model:...][:path:...][:policy:...]`. + - Returns `{ "allowed": bool, "retry_after_ms": number }`. Wasm calls this for **all** policies before enqueue or upstream passthrough. +- `POST /v1/queue/enqueue` + - Requires `X-Callback-URL` by default. + - Streams the request body, then stores either inline base64 body or an object-store reference in Redis Stream. + - Returns `202 Accepted` with `X-Job-Id`. +- `POST /v1/queue/enqueue-and-wait` + - Enqueues the job and waits for the worker result via Redis Pub/Sub. + - Returns the upstream response with `X-Job-Id` and `X-Queue-Wait-Ms`, or `504`. +- `GET /v1/jobs/{job_id}` / `GET /jobs/{job_id}/status` + - When the job is completed, returns the **raw upstream HTTP response** (status, headers, body) with `X-Job-Id`. + - While pending or on error, returns JSON status metadata. +- `GET /metrics` + - Returns Prometheus text metrics for queue depth, PEL size, DLQ depth, enqueue latency, body size, waits, limits, callbacks, retries, object offload, and worker counters. + +## Run + +```bash +cargo run -p ai-gateway-service -- \ + --redis-url redis://127.0.0.1/ \ + --upstream-base-url http://127.0.0.1:9000 +``` + +Or use a TOML config file (recommended for local / deployment): + +```bash +cargo run -p ai-gateway-service -- --config config/ai-gateway-service.toml +``` + +If `--config` / `AI_GATEWAY_CONFIG` is omitted, the service looks for `ai-gateway-service.toml` in the **same directory as the executable**. For deployment, place the binary and config file together: + +```text +/opt/ai-gateway/ + ai-gateway-service # binary + ai-gateway-service.toml # auto-loaded +``` + +Example configs live under `config/`: + +- `config/ai-gateway-service.example.toml` — full reference with all sections +- `config/ai-gateway-service.toml` — minimal local dev template + +Precedence: explicit CLI flags / environment variables > config file > built-in defaults. + +Default config discovery order: + +1. `--config` or `AI_GATEWAY_CONFIG` +2. `{executable_dir}/ai-gateway-service.toml` (if the file exists) +3. Built-in defaults only + +Set the config path via environment variable: + +```bash +AI_GATEWAY_CONFIG=config/ai-gateway-service.toml cargo run -p ai-gateway-service +``` + +Useful environment variables: + +```bash +REDIS_URL=redis://127.0.0.1/ +AI_UPSTREAM_BASE_URL=http://127.0.0.1:9000 +AI_RATE_LIMIT_RPS=100 +AI_RATE_LIMIT_BURST=200 +AI_RATE_LIMIT_COST=1 +AI_WAIT_TIMEOUT_SECS=60 +AI_WORKER_CONCURRENCY=4 +AI_MAX_BODY_BYTES=33554432 +AI_INLINE_THRESHOLD=131072 +AI_QUEUE_MAX_LEN=100000 +AI_ENABLE_PRIORITY_STREAMS=true +AI_QUEUE_DEFAULT_PRIORITY=normal +AI_QUEUE_HIGH_MODELS=gpt-4o,qwen-max +AI_QUEUE_LOW_TENANTS=free +AI_QUEUE_HIGH_WEIGHT=3 +AI_QUEUE_NORMAL_WEIGHT=1 +AI_QUEUE_LOW_WEIGHT=1 +AI_RECLAIM_INTERVAL_SECS=30 +AI_RECLAIM_MIN_IDLE_SECS=30 +AI_JOB_PROCESS_LEASE_SECS=120 +AI_JOB_MAX_DELIVERY_ATTEMPTS=5 +AI_REQUIRE_HTTPS_CALLBACK=true +AI_CALLBACK_MAX_RETRY_ATTEMPTS=5 +AI_CALLBACK_RETRY_INITIAL_DELAY_MS=1000 +AI_CALLBACK_RETRY_MAX_DELAY_MS=60000 +AI_CALLBACK_RETRY_RECLAIM_IDLE_SECS=60 +``` + +Optional object offload variables: + +```bash +AI_OBJECT_STORE_ENDPOINT=http://127.0.0.1:9000 +AI_OBJECT_STORE_BUCKET=ai-gateway-body +AI_OBJECT_STORE_PREFIX=bodies +AI_OBJECT_MULTIPART_PART_SIZE=5242880 +AI_OBJECT_STORE_AUTH_HEADER='Authorization: Bearer token' +``` + +Request body reading is streaming. The service accumulates only the inline buffer until `AI_INLINE_THRESHOLD`; after that it starts multipart upload and flushes parts as `AI_OBJECT_MULTIPART_PART_SIZE` chunks become available. `AI_MAX_BODY_BYTES` is enforced while reading the stream. + +When `AI_OBJECT_STORE_ENDPOINT` is set and the body is larger than `AI_INLINE_THRESHOLD`, the service uses the S3-compatible multipart flow: + +```text +CreateMultipartUpload -> UploadPart* -> CompleteMultipartUpload +``` + +If any part upload or completion fails, the service sends `AbortMultipartUpload` before returning the enqueue error. The current implementation expects a MinIO/S3-compatible endpoint that accepts either unsigned requests or the configured static auth header. + +Tenant rate-limit overrides (Admin API + Redis): + +```text +GET/PUT/DELETE /v1/admin/tenant-rate-limits +``` + +Redis key patterns (most specific match wins; token bucket remains tenant-scoped): + +```text +ai:tenant:ratelimit:{tenant}:model:{model}:path:{path}:policy:{policy} +ai:tenant:ratelimit:{tenant}:model:{model}:path:{path} +ai:tenant:ratelimit:{tenant}:model:{model}:policy:{policy} +ai:tenant:ratelimit:{tenant}:path:{path}:policy:{policy} +ai:tenant:ratelimit:{tenant}:model:{model} +ai:tenant:ratelimit:{tenant}:path:{path} +ai:tenant:ratelimit:{tenant}:policy:{policy} +ai:tenant:ratelimit:{tenant} +``` + +JSON value: + +```json +{"rps": 20, "burst": 40, "cost": 1} +``` + +CSV value: + +```text +20,40,1 +``` + +The old per-tenant keys are still supported as fallback: `ai:tenant:ratelimit:{tenant}:rps`, `:burst`, and `:cost`. + +Global defaults when no tenant rule matches: + +```bash +AI_RATE_LIMIT_RPS=100 +AI_RATE_LIMIT_BURST=200 +AI_RATE_LIMIT_COST=1 +``` + +The Wasm plugin invokes `/v1/ratelimit/check` for **abandon**, **queue**, and **wait** before passthrough or enqueue. + +Priority streams are **enabled by default** (`AI_ENABLE_PRIORITY_STREAMS=true`). Send `X-Queue-Priority: high|normal|low` to route jobs to separate streams, or configure model/tenant defaults: + +```bash +AI_ENABLE_PRIORITY_STREAMS=true +AI_QUEUE_HIGH_STREAM=ai:jobs:high +AI_QUEUE_LOW_STREAM=ai:jobs:low +AI_QUEUE_HIGH_MODELS=gpt-4o,qwen-max +AI_QUEUE_LOW_TENANTS=free +``` + +Workers consume streams in weighted order. `AI_QUEUE_HIGH_WEIGHT`, `AI_QUEUE_NORMAL_WEIGHT`, and `AI_QUEUE_LOW_WEIGHT` control how often each priority stream is checked per loop. + +Callback failures are written to `AI_CALLBACK_RETRY_STREAM` with `attempt`, `next_attempt_at_ms`, and `last_error`. The retry worker uses exponential backoff capped by `AI_CALLBACK_RETRY_MAX_DELAY_MS`, ACKs each retry record after handling it, and moves exhausted callbacks to `AI_CALLBACK_DLQ_STREAM`. Pending Redis Stream jobs are reclaimed with `XAUTOCLAIM` according to the reclaim settings. + +For job processing, each entry acquires a Redis lease key before upstream execution. Reclaimed entries that are already leased are skipped instead of being reprocessed, and jobs exceeding `AI_JOB_MAX_DELIVERY_ATTEMPTS` are moved to `AI_JOB_DLQ_STREAM`. + +`/metrics` includes the core signals needed to operate the queue: + +- `queue_depth`, `queue_depth{priority="high|low"}` for stream backlog. +- `pel_size`, `pel_size{priority="high|low"}`, and `callback_retry_pel_size` for unacked pending entries. +- `job_dlq_depth` and `callback_dlq_depth` for exhausted jobs and callbacks. +- `enqueue_latency_ms_*`, `enqueue_body_size_bytes_*`, `wait_total`, and `wait_timeout_total` for ingress and wait-mode health. +- `worker_processing_time_ms_*`, `worker_completed_total`, `worker_failed_total`, `reclaimed_total`, `lease_skip_total`, and `job_dlq_total` for worker health. +- `object_offload_total` and `object_multipart_abort_total` for large-body offload. + +## Body offload tests + +Unit tests (mock S3 multipart server, no Docker): + +```bash +cargo test -p ai-gateway-service store_body_ +``` + +## 测试规格与集成测试 + +完整用例规格见 [`spacegate/docs/ai-gateway-queue-test-spec.md`](../../docs/ai-gateway-queue-test-spec.md)(TC-* 编号,映射设计文档章节)。 + +```bash +# 单元测试(无需 Redis) +cd spacegate && cargo test -p ai-gateway-service + +# Rust 集成测试(需 Redis 7+) +./spacegate/binary/ai-gateway-service/scripts/run-integration-tests.sh + +# Hurl 黑盒(需 hurl + Redis + 编译 release binary) +./spacegate/binary/ai-gateway-service/scripts/run-hurl-tests.sh + +# Wasm 策略纯逻辑(host 侧) +./spacegate/binary/ai-gateway-service/scripts/run-wasm-policy-tests.sh +``` + +MinIO end-to-end (Docker + worker roundtrip): + +```bash +# 需要:Redis、mock 上游 :9000、Docker +./tests/queue-object-store-e2e.sh +``` + +The script starts MinIO on `:9001` by default (avoids clashing with the mock upstream on `:9000`), launches a dedicated `ai-gateway-service` on `:18081` with `AI_OBJECT_STORE_ENDPOINT`, and verifies: + +- inline body below `AI_INLINE_THRESHOLD` does not increment `object_offload_total` +- larger body is stored in MinIO and the worker completes after `load_body()` fetches it diff --git a/binary/ai-gateway-service/config/ai-gateway-service.example.toml b/binary/ai-gateway-service/config/ai-gateway-service.example.toml new file mode 100644 index 00000000..bdcc59cc --- /dev/null +++ b/binary/ai-gateway-service/config/ai-gateway-service.example.toml @@ -0,0 +1,113 @@ +# ai-gateway-service 配置文件示例 +# 用法: +# cargo run -p ai-gateway-service -- --config config/ai-gateway-service.toml +# AI_GATEWAY_CONFIG=config/ai-gateway-service.toml cargo run -p ai-gateway-service +# +# 优先级:显式 CLI 参数 / 环境变量 > 本配置文件 > 内置默认值 + +# --------------------------------------------------------------------------- +# 服务监听 +# --------------------------------------------------------------------------- +[server] +host = "0.0.0.0" +port = 18080 + +# --------------------------------------------------------------------------- +# Redis(队列、限流、结果存储的核心依赖) +# --------------------------------------------------------------------------- +[redis] +# 单机示例 +url = "redis://127.0.0.1/" +# 带密码 / 指定 DB 示例: +# url = "redis://:your-password@redis.example.com:6379/0" + +# --------------------------------------------------------------------------- +# 上游 AI 服务(Worker 消费队列后转发到此地址) +# --------------------------------------------------------------------------- +[upstream] +base_url = "http://127.0.0.1:9000" + +# --------------------------------------------------------------------------- +# 队列 Stream 与优先级 +# --------------------------------------------------------------------------- +[queue] +stream = "ai:jobs" +high_stream = "ai:jobs:high" +low_stream = "ai:jobs:low" +enable_priority_streams = true +default_priority = "normal" +high_models = ["gpt-4o", "qwen-max"] +low_tenants = ["free"] +high_weight = 3 +normal_weight = 1 +low_weight = 1 +max_len = 100000 +group = "ai-gateway-workers" +consumer = "ai-gateway-service" +job_dlq_stream = "ai:job-dlq" + +# --------------------------------------------------------------------------- +# V1 全局限流(仅 abandon 路径生效;租户差异化配额为 V2,见 README) +# --------------------------------------------------------------------------- +[rate_limit] +rps = 100 +burst = 200 +cost = 1 +tenant_prefix = "ai:tenant:ratelimit:" + +# --------------------------------------------------------------------------- +# Worker 与任务回收 +# --------------------------------------------------------------------------- +[worker] +concurrency = 10 +wait_timeout_secs = 60 +reclaim_interval_secs = 30 +reclaim_min_idle_secs = 30 +job_process_lease_secs = 120 +job_max_delivery_attempts = 5 + +# --------------------------------------------------------------------------- +# 回调(queue 模式) +# --------------------------------------------------------------------------- +[callback] +require_https = true +max_retry_attempts = 5 +retry_initial_delay_ms = 1000 +retry_max_delay_ms = 60000 +retry_reclaim_idle_secs = 60 +retry_stream = "ai:callback-retry" +retry_group = "ai-gateway-callbacks" +dlq_stream = "ai:callback-dlq" + +# --------------------------------------------------------------------------- +# 结果缓存(wait 模式 / 轮询) +# --------------------------------------------------------------------------- +[result] +key_prefix = "result:" +channel_prefix = "result:" +ttl_secs = 120 + +# --------------------------------------------------------------------------- +# 请求体大小限制 +# --------------------------------------------------------------------------- +[body] +max_bytes = 33554432 # 32 MiB +inline_threshold = 131072 # 128 KiB 以下 inline 存 Redis +read_concurrency = 200 + +# --------------------------------------------------------------------------- +# 大 Body 对象存储(可选,S3 / MinIO 兼容) +# --------------------------------------------------------------------------- +[object_store] +# endpoint = "http://127.0.0.1:9000" +bucket = "ai-gateway-body" +prefix = "bodies" +multipart_part_size = 5242880 +# auth_header = "Authorization: Bearer your-token" + +# --------------------------------------------------------------------------- +# Admin API CORS(本地开发可留空,表示 permissive) +# --------------------------------------------------------------------------- +[admin] +cors_origins = [] +# cors_origins = ["http://localhost:5173", "http://127.0.0.1:5173"] diff --git a/binary/ai-gateway-service/config/ai-gateway-service.toml b/binary/ai-gateway-service/config/ai-gateway-service.toml new file mode 100644 index 00000000..30f691e0 --- /dev/null +++ b/binary/ai-gateway-service/config/ai-gateway-service.toml @@ -0,0 +1,30 @@ +# 本地开发配置:按需修改 Redis / 上游 / 对象存储地址 + +[server] +host = "0.0.0.0" +port = 18080 + +[redis] +url = "redis://127.0.0.1/" + +[upstream] +base_url = "http://127.0.0.1:9000" + +[callback] +require_https = false + +# 请求体大小:超过 inline_threshold 且配置了 object_store.endpoint 时走 MinIO/S3 +[body] +max_bytes = 33554432 # 32 MiB,与 Wasm 插件 limits.max_body_bytes 保持一致 +inline_threshold = 131072 # 128 KiB 以下 inline 存 Redis +read_concurrency = 200 + +# 大 Body 对象存储(S3 / MinIO 兼容) +# 本地 MinIO 可参考 tests/queue-object-store-e2e.sh,默认端口 9001(避免与 mock 上游 :9000 冲突) +[object_store] +endpoint = "http://127.0.0.1:9001" +bucket = "ai-gateway-body" +prefix = "bodies" +multipart_part_size = 5242880 +# 无鉴权 MinIO 可留空;需要鉴权时使用 "Header-Name: value" 格式 +# auth_header = "Authorization: Bearer your-token" diff --git a/binary/ai-gateway-service/scripts/queue-object-store-e2e.sh b/binary/ai-gateway-service/scripts/queue-object-store-e2e.sh new file mode 100755 index 00000000..879510b9 --- /dev/null +++ b/binary/ai-gateway-service/scripts/queue-object-store-e2e.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash +# MinIO + 大 body E2E(TC-BODY-02 smoke) +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +WORKSPACE="$(cd "$ROOT/../.." && pwd)" +cd "$WORKSPACE" + +REDIS_URL="${REDIS_URL:-redis://127.0.0.1/}" +MINIO_PORT="${MINIO_PORT:-9001}" +SVC_PORT="${E2E_SERVICE_PORT:-18081}" + +if ! command -v docker >/dev/null 2>&1; then + echo "SKIP: docker not available" >&2 + exit 0 +fi + +docker rm -f ai-gateway-minio-e2e 2>/dev/null || true +docker run -d --name ai-gateway-minio-e2e \ + -p "${MINIO_PORT}:9000" \ + -e MINIO_ROOT_USER=minioadmin \ + -e MINIO_ROOT_PASSWORD=minioadmin \ + minio/minio server /data >/dev/null + +cleanup() { + docker rm -f ai-gateway-minio-e2e 2>/dev/null || true + kill $SVC_PID $UP_PID 2>/dev/null || true +} +trap cleanup EXIT + +sleep 2 + +# MinIO 需先创建 bucket 并设为 public(服务使用无 SigV4 的直传 HTTP) +if docker run --rm --network host --entrypoint /bin/sh minio/mc -c \ + "mc alias set local http://127.0.0.1:${MINIO_PORT} minioadmin minioadmin && \ + mc mb --ignore-existing local/ai-gateway-body && \ + mc anonymous set public local/ai-gateway-body" >/dev/null 2>&1; then + echo "MinIO bucket ai-gateway-body ready (public)." +else + echo "ERROR: MinIO bucket bootstrap failed" >&2 + exit 1 +fi + +python3 - < SpaceGate(:9993) Wasm -> ai-gateway-service -> MinIO + Redis +# +# 前置: +# docker compose -f docker-compose.yml -f docker-compose.queue.yml --profile queue up -d +# ./scripts/sync-wasm-plugin-to-docker-config.sh && 重启 spacegate +# +# 用法: +# ./spacegate/binary/ai-gateway-service/scripts/run-gateway-large-body-e2e.sh +# ENSURE_STACK=1 ./spacegate/.../run-gateway-large-body-e2e.sh # 自动 compose up +set -euo pipefail + +ROOT="$(cd "$(dirname "$0")/../../../.." && pwd)" +cd "$ROOT" + +GATEWAY="${GATEWAY_URL:-http://127.0.0.1:9993}" +SERVICE="${SERVICE_URL:-http://127.0.0.1:18080}" +MINIO_HOST="${MINIO_HOST:-http://127.0.0.1:9010}" +MINIO_USER="${MINIO_ROOT_USER:-minioadmin}" +MINIO_PASS="${MINIO_ROOT_PASSWORD:-minioadmin}" +BUCKET="${AI_OBJECT_STORE_BUCKET:-ai-gateway-body}" +# 200 KiB,超过默认 inline 阈值 128 KiB +BODY_SIZE="${LARGE_BODY_BYTES:-204800}" +TENANT="gw-large-body-${RANDOM}" +COMPOSE=(docker compose -f docker-compose.yml -f docker-compose.queue.yml --profile queue) + +die() { echo "ERROR: $*" >&2; exit 1; } + +metric_value() { + local name="$1" + curl -sf "$SERVICE/metrics" | awk -v n="$name" '$1 == n { print $2; exit }' +} + +wait_http() { + local url="$1" retries="${2:-30}" + for _ in $(seq 1 "$retries"); do + curl -sf "$url" >/dev/null 2>&1 && return 0 + sleep 1 + done + return 1 +} + +if [[ "${ENSURE_STACK:-0}" == "1" ]]; then + echo "==> 启动 / 更新 Docker 栈(queue profile)" + export DOCKER_BUILDKIT=1 + if [[ -x ./scripts/sync-wasm-plugin-to-docker-config.sh ]]; then + ./scripts/sync-wasm-plugin-to-docker-config.sh + fi + "${COMPOSE[@]}" up -d --build minio minio-init ai-gateway-service spacegate admin-server +fi + +echo "==> 前置检查" +wait_http "$SERVICE/healthz" || die "ai-gateway-service 不可达: $SERVICE" +wait_http "http://127.0.0.1:19880/health" || die "SpaceGate 不可达" +curl -sf "$MINIO_HOST/minio/health/live" >/dev/null || die "MinIO 不可达: $MINIO_HOST" + +echo "==> 配置租户限流(burst=1,便于触发 queue 超额入队)" +curl -sf -X PUT "$SERVICE/v1/admin/tenant-rate-limits" \ + -H 'Content-Type: application/json' \ + -d "{\"tenant\":\"$TENANT\",\"rps\":1,\"burst\":1}" >/dev/null + +BASELINE="$(metric_value object_offload_total || echo 0)" + +echo "==> 消耗令牌(abandon 配额内 1 次)" +code=$(curl -s -o /dev/null -w '%{http_code}' -X POST "$GATEWAY/v1/chat/completions" \ + -H 'X-RateLimit-Policy: abandon' \ + -H "X-Tenant-Id: $TENANT" \ + -H 'Content-Type: application/json' \ + -d '{"warmup":true}') +[[ "$code" == "200" ]] || die "预热请求期望 200,实际 $code" + +echo "==> 经网关发送大 body(queue 策略,应 202 入队)" +LARGE="$(python3 -c "print('x'*${BODY_SIZE})")" +# 回调走 Docker 内 mock-upstream,service 容器可直接访问 +CALLBACK="${CALLBACK_URL:-http://mock-upstream:9000/callback}" + +http=$(curl -sS -o /tmp/gw-large-body.json -w '%{http_code}' \ + -X POST "$GATEWAY/v1/chat/completions" \ + -H "X-Tenant-Id: $TENANT" \ + -H 'X-RateLimit-Policy: queue' \ + -H "X-Callback-URL: $CALLBACK" \ + -H 'Content-Type: application/octet-stream' \ + --data-binary "$LARGE") +[[ "$http" == "202" ]] || die "大 body 入队期望 202,实际 $http body=$(cat /tmp/gw-large-body.json)" + +echo "==> 等待 object_offload_total 递增" +found=0 +for _ in $(seq 1 45); do + now="$(metric_value object_offload_total || echo 0)" + if awk -v a="$BASELINE" -v b="$now" 'BEGIN{exit !(b>a)}'; then + echo "object_offload_total: $BASELINE -> $now" + found=1 + break + fi + sleep 1 +done +[[ "$found" == "1" ]] || die "object_offload_total 未递增(baseline=$BASELINE)" + +echo "==> 验证 MinIO bucket 内有对象" +obj_count=$(docker run --rm --network ai-gateway-net --entrypoint /bin/sh minio/mc:latest \ + -c " + mc alias set local http://minio:9000 '$MINIO_USER' '$MINIO_PASS' >/dev/null && + mc ls -r local/$BUCKET/bodies 2>/dev/null | wc -l | tr -d ' ' + ") +[[ "${obj_count:-0}" -gt 0 ]] || die "MinIO bucket/$BUCKET 下未发现 bodies/ 对象" + +echo "==> 全栈大 body E2E 通过(tenant=$TENANT, body=${BODY_SIZE}B, minio_objects>=1)" diff --git a/binary/ai-gateway-service/scripts/run-hurl-tests.sh b/binary/ai-gateway-service/scripts/run-hurl-tests.sh new file mode 100755 index 00000000..32e57f86 --- /dev/null +++ b/binary/ai-gateway-service/scripts/run-hurl-tests.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# Hurl 黑盒测试:启动 mock 上游/回调 + ai-gateway-service +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +WORKSPACE="$(cd "$ROOT/../.." && pwd)" +cd "$WORKSPACE" + +REDIS_URL="${REDIS_URL:-redis://127.0.0.1/}" +export REDIS_URL + +if ! command -v hurl >/dev/null 2>&1; then + echo "ERROR: hurl not installed. See https://hurl.dev" >&2 + exit 1 +fi + +redis_ok=false +if command -v redis-cli >/dev/null 2>&1 && redis-cli -u "$REDIS_URL" PING >/dev/null 2>&1; then + redis_ok=true +elif docker exec ai-gateway-redis redis-cli PING >/dev/null 2>&1; then + redis_ok=true +elif nc -z 127.0.0.1 6379 2>/dev/null; then + redis_ok=true +fi +if [[ "$redis_ok" != true ]]; then + echo "ERROR: Redis not reachable at $REDIS_URL" >&2 + exit 1 +fi + +# mock upstream :9000 +python3 - <<'PY' & +import json +from http.server import BaseHTTPRequestHandler, HTTPServer + +class H(BaseHTTPRequestHandler): + def do_POST(self): + body = json.dumps({"upstream": True, "hurl": True}).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + def log_message(self, *args): pass + +HTTPServer(("127.0.0.1", 9000), H).serve_forever() +PY +UP_PID=$! + +# mock callback +python3 - <<'PY' & +from http.server import BaseHTTPRequestHandler, HTTPServer + +class H(BaseHTTPRequestHandler): + def do_POST(self): + n = int(self.headers.get("Content-Length", 0)) + self.rfile.read(n) + self.send_response(200) + self.end_headers() + def log_message(self, *args): pass + +HTTPServer(("127.0.0.1", 9002), H).serve_forever() +PY +CB_PID=$! + +cleanup() { + kill $UP_PID $CB_PID $SVC_PID 2>/dev/null || true +} +trap cleanup EXIT + +cargo build -q -p ai-gateway-service --release +SVC="$WORKSPACE/target/release/ai-gateway-service" +PORT="${HURL_SERVICE_PORT:-18090}" +CALLBACK="http://127.0.0.1:9002/cb" + +"$SVC" \ + --redis-url "$REDIS_URL" \ + --port "$PORT" \ + --host 127.0.0.1 \ + --upstream-base-url http://127.0.0.1:9000 \ + & +SVC_PID=$! +sleep 1 + +export service_url="http://127.0.0.1:${PORT}" +export callback_url="$CALLBACK" + +hurl --test \ + --variable service_url="$service_url" \ + --variable callback_url="$callback_url" \ + --file-root "$ROOT/tests/fixtures" \ + "$ROOT/tests/hurl/ratelimit.hurl" \ + "$ROOT/tests/hurl/queue.hurl" \ + "$ROOT/tests/hurl/wait.hurl" \ + "$ROOT/tests/hurl/metrics.hurl" \ + "$ROOT/tests/hurl/admin.hurl" + +echo "Hurl tests passed." diff --git a/binary/ai-gateway-service/scripts/run-integration-tests.sh b/binary/ai-gateway-service/scripts/run-integration-tests.sh new file mode 100755 index 00000000..5b69c10b --- /dev/null +++ b/binary/ai-gateway-service/scripts/run-integration-tests.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +# 运行 Rust 集成测试(需 Redis 7+) +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT/../.." + +REDIS_URL="${REDIS_URL:-redis://127.0.0.1/}" +export REDIS_URL + +echo "Checking Redis at $REDIS_URL ..." +redis_ok=false +if command -v redis-cli >/dev/null 2>&1; then + if redis-cli -u "$REDIS_URL" INFO server 2>/dev/null | grep -qE 'redis_version:(7|[89])'; then + redis_ok=true + fi +elif docker exec ai-gateway-redis redis-cli INFO server 2>/dev/null | grep -qE 'redis_version:(7|[89])'; then + redis_ok=true +elif nc -z 127.0.0.1 6379 2>/dev/null; then + redis_ok=true +fi +if [[ "$redis_ok" != true ]]; then + echo "ERROR: Redis 7+ required. Start redis or set REDIS_URL." >&2 + exit 1 +fi + +cargo test -p ai-gateway-service --features test-support --test integration "$@" diff --git a/binary/ai-gateway-service/scripts/run-wasm-policy-tests.sh b/binary/ai-gateway-service/scripts/run-wasm-policy-tests.sh new file mode 100755 index 00000000..04d47754 --- /dev/null +++ b/binary/ai-gateway-service/scripts/run-wasm-policy-tests.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +# Wasm 策略纯逻辑 host 侧单测(绕过 wasm32 默认 target) +set -euo pipefail +DIR="$(cd "$(dirname "$0")/../../../plugins/wasm/ai-gateway-queue" && pwd)" +HOST=$(rustc -vV | sed -n 's/host: //p') +cd "$DIR" +cargo test --lib --target "$HOST" "$@" diff --git a/binary/ai-gateway-service/src/app.rs b/binary/ai-gateway-service/src/app.rs new file mode 100644 index 00000000..5b424929 --- /dev/null +++ b/binary/ai-gateway-service/src/app.rs @@ -0,0 +1,47 @@ +use std::collections::HashMap; +use std::net::{IpAddr, SocketAddr}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use axum::body::Body; +use axum::extract::{DefaultBodyLimit, Path, Query, State}; +use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri}; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use base64::Engine; +use clap::Parser; +use fred::clients::{Client as FredClient, SubscriberClient}; +use fred::prelude::*; +use fred::types::InfoKind; +use fred::types::streams::XReadResponse; +use fred::types::ExpireOptions; +use futures_util::StreamExt; +use schemars::{schema_for, JsonSchema}; +use serde::{Deserialize, Serialize}; +use std::sync::Mutex; +use tokio::sync::{oneshot, Semaphore}; +use tower_http::cors::{Any, CorsLayer}; +use tower_http::trace::TraceLayer; + +include!("app/types.rs"); +include!("app/config.rs"); +include!("app/runtime.rs"); +include!("app/handlers.rs"); +include!("app/queue.rs"); +include!("app/callback.rs"); +include!("app/result_store.rs"); +include!("app/object_store.rs"); +include!("app/metrics.rs"); +include!("app/ratelimit.rs"); +include!("app/admin.rs"); +include!("app/wait_subscriber.rs"); +include!("app/util.rs"); + +#[cfg(feature = "test-support")] +pub mod test_support { + include!("app/test_support.rs"); +} + +include!("app/tests.rs"); diff --git a/binary/ai-gateway-service/src/app/admin.rs b/binary/ai-gateway-service/src/app/admin.rs new file mode 100644 index 00000000..1ae18460 --- /dev/null +++ b/binary/ai-gateway-service/src/app/admin.rs @@ -0,0 +1,277 @@ +const AI_GATEWAY_QUEUE_PLUGIN: &str = "ai-gateway-queue"; +const AI_GATEWAY_QUEUE_README: &str = include_str!("../../../../plugins/wasm/ai-gateway-queue/README.md"); + +async fn admin_plugin_schema(Path(plugin): Path) -> Result, ServiceError> { + if plugin != AI_GATEWAY_QUEUE_PLUGIN { + return Err(ServiceError::bad_request(format!("unsupported plugin `{plugin}`"))); + } + + let schema = schema_for!(AiGatewayQueuePluginConfig); + let mut value = serde_json::to_value(&schema).map_err(|e| ServiceError::internal(format!("serialize schema: {e}")))?; + add_ai_gateway_queue_schema_extensions(&mut value); + Ok(Json(value)) +} + +async fn admin_plugin_readme(Path(plugin): Path) -> Result { + if plugin != AI_GATEWAY_QUEUE_PLUGIN { + return Err(ServiceError::bad_request(format!("unsupported plugin `{plugin}`"))); + } + Ok((StatusCode::OK, [("content-type", "text/markdown; charset=utf-8")], AI_GATEWAY_QUEUE_README).into_response()) +} + + +async fn admin_list_tenant_rate_limits(State(state): State, Query(filters): Query>) -> Result>, ServiceError> { + Ok(Json(list_tenant_rate_limit_rules(&state, &filters).await?)) +} + +async fn admin_upsert_tenant_rate_limit(State(state): State, Json(rule): Json) -> Result, ServiceError> { + Ok(Json(upsert_tenant_rate_limit_rule(&state, rule).await?)) +} + +async fn admin_delete_tenant_rate_limit(State(state): State, Json(rule): Json) -> Result, ServiceError> { + let removed = delete_tenant_rate_limit_rule(&state, rule).await?; + Ok(Json(serde_json::json!({ "removed": removed }))) +} + +fn add_ai_gateway_queue_schema_extensions(value: &mut serde_json::Value) { + let example = serde_json::to_string_pretty(&AiGatewayQueuePluginConfig::default()).unwrap_or_default(); + if let Some(object) = value.as_object_mut() { + object.insert("x-example-raw".to_string(), serde_json::Value::String(example)); + object.insert( + "x-title-i18n".to_string(), + serde_json::json!({ + "zh-CN": "AI 请求队列网关", + "en": "AI Request Queue Gateway" + }), + ); + object.insert( + "x-description-i18n".to_string(), + serde_json::json!({ + "zh-CN": "配置 AI 请求队列网关:队列后端接入、入队接口路径、请求头映射、队列模式与优先级路由。", + "en": "Configure the AI request queue gateway: queue backend access, enqueue paths, header mapping, queue mode and priority routing." + }), + ); + } + + let Some(definitions) = value.get_mut("definitions").and_then(|v| v.as_object_mut()) else { return }; + + // 子配置卡片自身的标题/说明:会被 SchemaForm 的 el-card 标题区使用。 + annotate_schema_meta( + definitions.get_mut("AiGatewayServiceConfig"), + "队列后端接入", + "Queue Backend Access", + "Wasm 插件调用外部队列后端时使用的 cluster、authority 和超时设置。", + "Cluster, authority and timeout the wasm plugin uses to call the external queue backend.", + ); + annotate_schema_meta( + definitions.get_mut("AiGatewayPathsConfig"), + "接口路径", + "Paths", + "队列后端暴露的准入判定、入队、入队并等待三类 HTTP 路径。", + "HTTP paths exposed by the queue backend for admission check, enqueue, and enqueue-and-wait.", + ); + annotate_schema_meta( + definitions.get_mut("AiGatewayHeadersConfig"), + "请求头映射", + "Headers", + "客户端实际使用的 Header 名称;插件会把它们统一转成队列后端期望的标准 Header。", + "Header names used by clients; the plugin remaps them to the standard headers the queue backend expects.", + ); + annotate_schema_meta( + definitions.get_mut("AiGatewayPoliciesConfig"), + "队列模式", + "Queue Mode", + "控制 X-RateLimit-Policy 请求头是否必填,以及未携带时使用的默认队列模式(abandon / queue / wait)。", + "Controls whether the X-RateLimit-Policy header is required, and the default queue mode used when it is missing (abandon / queue / wait).", + ); + annotate_schema_meta( + definitions.get_mut("AiGatewayPriorityConfig"), + "优先级路由", + "Priority Routing", + "队列优先级的开关、默认值,以及按模型 / 租户自动选择高 / 普通 / 低优先级队列的规则。", + "Queue priority switch, default value, and per-model / per-tenant rules that route requests into high / normal / low priority streams.", + ); + + set_field_descriptions( + definitions.get_mut("AiGatewayServiceConfig"), + &[ + ( + "cluster", + "队列后端 Cluster", + "Queue Backend Cluster", + "SpaceGate 中指向队列后端的 cluster 名称,对应 SpaceGate 配置里的 clusters 键。", + "Name of the SpaceGate cluster pointing to the queue backend; matches the key under the clusters field.", + ), + ( + "authority", + "队列后端 Authority", + "Queue Backend Authority", + "Wasm 插件 dispatch HTTP call 时使用的 :authority,通常和 cluster 同名。", + "The :authority used by the wasm dispatch_http_call; usually the same as the cluster name.", + ), + ( + "timeout_ms", + "调用超时(毫秒)", + "Timeout (ms)", + "调用队列后端的超时时间;wait 模式需要留足同步等待时间,建议 60000 ms 以上。", + "Timeout for calling the queue backend. Keep it above 60000 ms when wait mode is used.", + ), + ], + ); + + set_field_descriptions( + definitions.get_mut("AiGatewayPathsConfig"), + &[ + ( + "rate_limit", + "准入判定路径", + "Admission Check Path", + "队列后端用于判断请求是否需要入队的准入接口,默认 /v1/ratelimit/check。", + "Backend path that decides whether a request should be enqueued. Default: /v1/ratelimit/check.", + ), + ( + "enqueue", + "入队路径", + "Enqueue Path", + "queue 模式使用的异步入队接口,默认 /v1/queue/enqueue。", + "Endpoint used by the queue (async) mode. Default: /v1/queue/enqueue.", + ), + ( + "wait", + "入队并等待路径", + "Enqueue-and-Wait Path", + "wait 模式使用的入队并同步等待结果接口,默认 /v1/queue/enqueue-and-wait。", + "Endpoint used by the wait (sync) mode. Default: /v1/queue/enqueue-and-wait.", + ), + ], + ); + + set_field_descriptions( + definitions.get_mut("AiGatewayHeadersConfig"), + &[ + ( + "policy", + "队列模式 Header", + "Queue Mode Header", + "客户端用于声明队列模式(abandon / queue / wait)的 Header,插件会转成后端使用的 x-ratelimit-policy。", + "Header the client uses to declare the queue mode (abandon / queue / wait); remapped to x-ratelimit-policy.", + ), + ( + "tenant", + "租户 Header", + "Tenant Header", + "客户端表示租户身份的 Header,插件会转成队列后端使用的 x-tenant-id。", + "Header carrying tenant identity; remapped to x-tenant-id for the queue backend.", + ), + ( + "model", + "模型 Header", + "Model Header", + "客户端声明目标模型的 Header,会被透传为队列后端使用的 x-model。", + "Header that names the target model; remapped to x-model for the queue backend.", + ), + ( + "priority", + "优先级 Header", + "Priority Header", + "客户端可选的队列优先级 Header,启用优先级时会被转为 x-queue-priority(取值 high/normal/low)。", + "Optional header for queue priority, remapped to x-queue-priority (values high/normal/low).", + ), + ], + ); + + set_field_descriptions( + definitions.get_mut("AiGatewayPoliciesConfig"), + &[ + ( + "require", + "强制要求队列模式 Header", + "Require Queue Mode Header", + "为 true 时,请求未携带队列模式 Header 会直接返回 400;关闭后会回退到默认队列模式。", + "When true, requests without the queue-mode header are rejected with 400; otherwise falls back to the default mode.", + ), + ( + "default", + "默认队列模式", + "Default Queue Mode", + "未携带队列模式 Header 且 require 为 false 时使用的默认模式,可选 abandon / queue / wait。", + "Default queue mode when require is false and the request omits the header. One of abandon / queue / wait.", + ), + ], + ); + + set_field_descriptions( + definitions.get_mut("AiGatewayPriorityConfig"), + &[ + ( + "enabled", + "启用优先级路由", + "Enable Priority Routing", + "总开关:关闭后所有请求都进入 normal 优先级队列,不再读取模型/租户规则。", + "Master switch; when disabled, all requests go to the normal-priority queue and per-model / per-tenant rules are ignored.", + ), + ( + "default", + "默认队列优先级", + "Default Queue Priority", + "命中不到任何规则时使用的默认队列优先级,可选 high / normal / low。", + "Default queue priority used when no rule matches. One of high / normal / low.", + ), + ( + "high_models", + "高优队列模型列表", + "High Priority Models", + "命中后自动路由到高优队列的模型名列表(精确匹配,区分大小写)。", + "Models that are routed to the high-priority queue (exact, case-sensitive match).", + ), + ( + "low_models", + "低优队列模型列表", + "Low Priority Models", + "命中后自动路由到低优队列的模型名列表。", + "Models that are routed to the low-priority queue.", + ), + ( + "high_tenants", + "高优队列租户列表", + "High Priority Tenants", + "命中后自动路由到高优队列的租户 ID 列表。", + "Tenant IDs that are routed to the high-priority queue.", + ), + ( + "low_tenants", + "低优队列租户列表", + "Low Priority Tenants", + "命中后自动路由到低优队列的租户 ID 列表,常用于免费 / 试用租户。", + "Tenant IDs that are routed to the low-priority queue, typically used for free or trial tenants.", + ), + ], + ); +} + +fn annotate_schema_meta(schema: Option<&mut serde_json::Value>, zh_title: &str, en_title: &str, zh_desc: &str, en_desc: &str) { + let Some(object) = schema.and_then(|v| v.as_object_mut()) else { return }; + object.insert( + "x-title-i18n".to_string(), + serde_json::json!({ "zh-CN": zh_title, "en": en_title }), + ); + object.insert( + "x-description-i18n".to_string(), + serde_json::json!({ "zh-CN": zh_desc, "en": en_desc }), + ); +} + +fn set_field_descriptions(schema: Option<&mut serde_json::Value>, items: &[(&str, &str, &str, &str, &str)]) { + let Some(properties) = schema.and_then(|v| v.get_mut("properties")).and_then(|v| v.as_object_mut()) else { return }; + for (key, zh_title, en_title, zh_desc, en_desc) in items { + let Some(field) = properties.get_mut(*key).and_then(|v| v.as_object_mut()) else { continue }; + field.insert( + "x-title-i18n".to_string(), + serde_json::json!({ "zh-CN": zh_title, "en": en_title }), + ); + field.insert( + "x-description-i18n".to_string(), + serde_json::json!({ "zh-CN": zh_desc, "en": en_desc }), + ); + } +} diff --git a/binary/ai-gateway-service/src/app/callback.rs b/binary/ai-gateway-service/src/app/callback.rs new file mode 100644 index 00000000..0836ad19 --- /dev/null +++ b/binary/ai-gateway-service/src/app/callback.rs @@ -0,0 +1,187 @@ +fn callback_body(result: &StoredResult) -> serde_json::Value { + // 设计文档回调 JSON:job_id / status / result / completed_at + serde_json::json!({ + "job_id": result.job_id, + "status": result.status, + "result": decode_callback_result(&result.body_base64), + "completed_at": format_completed_at_rfc3339(result.completed_at_ms), + }) +} + +async fn post_callback(state: &AppState, callback_url: &str, job_id: &str, body: &serde_json::Value) -> Result<(), ServiceError> { + state.http.post(callback_url).header("x-gateway-job-id", job_id).json(body).send().await?.error_for_status()?; + Ok(()) +} + +async fn enqueue_callback_retry(state: &AppState, callback_url: &str, job_id: &str, body: &serde_json::Value, last_error: &str) -> Result<(), ServiceError> { + let body = serde_json::to_string(body).map_err(|e| ServiceError::internal(format!("serialize callback retry: {e}")))?; + enqueue_callback_retry_raw( + state, + callback_url, + job_id, + &body, + 1, + now_ms().saturating_add(state.cfg.callback_retry_initial_delay_ms), + last_error, + ) + .await +} + +async fn enqueue_callback_retry_raw( + state: &AppState, + callback_url: &str, + job_id: &str, + body: &str, + attempt: u32, + next_attempt_at_ms: u64, + last_error: &str, +) -> Result<(), ServiceError> { + let _: String = state + .redis + .xadd( + state.cfg.callback_retry_stream.as_str(), + false, + None::<()>, + "*", + vec![ + ("job_id", Value::String(job_id.to_string().into())), + ("callback_url", Value::String(callback_url.to_string().into())), + ("body", Value::String(body.to_string().into())), + ("attempt", Value::Integer(attempt as i64)), + ("next_attempt_at_ms", Value::Integer(next_attempt_at_ms as i64)), + ("last_error", Value::String(last_error.to_string().into())), + ("created_at", Value::Integer(now_ms() as i64)), + ], + ) + .await?; + trim_stream(state, &state.cfg.callback_retry_stream).await?; + state.metrics.callback_retry_total.fetch_add(1, Ordering::Relaxed); + Ok(()) +} + +fn spawn_callback_retry_worker(state: AppState) { + tokio::spawn(async move { + loop { + if let Err(e) = callback_retry_once(&state).await { + tracing::warn!(error = %e.message, "callback retry loop failed"); + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + }); +} + +async fn callback_retry_once(state: &AppState) -> Result<(), ServiceError> { + reclaim_callback_retries(state).await?; + let reply = xreadgroup_map_or_empty( + &state.worker_redis, + state.cfg.callback_retry_group.as_str(), + state.cfg.consumer_name.as_str(), + Some(5), + Some(1000), + false, + vec![state.cfg.callback_retry_stream.as_str()], + vec![">"], + ) + .await?; + + for (_stream, entries) in reply { + for (entry_id, fields) in entries { + process_callback_retry_entry(state, entry_id.as_str(), &fields).await?; + } + } + Ok(()) +} + +async fn reclaim_callback_retries(state: &AppState) -> Result<(), ServiceError> { + let consumer = format!("{}-callback-reclaimer", state.cfg.consumer_name); + let min_idle_ms = state.cfg.callback_retry_reclaim_idle_secs.saturating_mul(1000); + let (_cursor, entries): (String, Vec<(String, HashMap)>) = state + .worker_redis + .xautoclaim_values( + state.cfg.callback_retry_stream.as_str(), + state.cfg.callback_retry_group.as_str(), + consumer.as_str(), + min_idle_ms, + "0-0", + Some(10), + false, + ) + .await?; + for (entry_id, fields) in entries { + process_callback_retry_entry(state, entry_id.as_str(), &fields).await?; + } + Ok(()) +} + +async fn process_callback_retry_entry(state: &AppState, entry_id: &str, fields: &HashMap) -> Result<(), ServiceError> { + let job_id = field_string(fields, "job_id").unwrap_or_default(); + let callback_url = field_string(fields, "callback_url").unwrap_or_default(); + let body = field_string(fields, "body").unwrap_or_else(|| "{}".to_string()); + let attempt = field_u32(fields, "attempt").unwrap_or(1); + let next_attempt_at_ms = field_u64(fields, "next_attempt_at_ms").unwrap_or(0); + let now = now_ms(); + + if next_attempt_at_ms > now { + let last_error = field_string(fields, "last_error").unwrap_or_default(); + enqueue_callback_retry_raw(state, &callback_url, &job_id, &body, attempt, next_attempt_at_ms, &last_error).await?; + ack_callback_retry(state, entry_id).await?; + return Ok(()); + } + + let parsed = serde_json::from_str::(&body).unwrap_or_else(|_| serde_json::json!({ "body": body })); + match post_callback(state, &callback_url, &job_id, &parsed).await { + Ok(()) => { + ack_callback_retry(state, entry_id).await?; + state.metrics.callback_retry_success_total.fetch_add(1, Ordering::Relaxed); + } + Err(e) => { + tracing::warn!(job_id = %job_id, attempt, error = %e.message, "callback retry failed"); + if attempt >= state.cfg.callback_max_retry_attempts { + enqueue_callback_dlq(state, &callback_url, &job_id, &parsed, attempt, &e.message).await?; + ack_callback_retry(state, entry_id).await?; + } else { + let next_attempt = attempt.saturating_add(1); + let delay_ms = callback_retry_delay_ms(state.cfg.callback_retry_initial_delay_ms, state.cfg.callback_retry_max_delay_ms, next_attempt); + let retry_body = serde_json::to_string(&parsed).unwrap_or_else(|_| "{}".to_string()); + enqueue_callback_retry_raw(state, &callback_url, &job_id, &retry_body, next_attempt, now.saturating_add(delay_ms), &e.message).await?; + ack_callback_retry(state, entry_id).await?; + } + } + } + Ok(()) +} + +async fn ack_callback_retry(state: &AppState, entry_id: &str) -> Result<(), ServiceError> { + let _: i64 = state.redis.xack(state.cfg.callback_retry_stream.as_str(), state.cfg.callback_retry_group.as_str(), vec![entry_id]).await?; + Ok(()) +} + +async fn enqueue_callback_dlq(state: &AppState, callback_url: &str, job_id: &str, body: &serde_json::Value, attempts: u32, final_error: &str) -> Result<(), ServiceError> { + let body = serde_json::to_string(body).map_err(|e| ServiceError::internal(format!("serialize callback dlq: {e}")))?; + let _: String = state + .redis + .xadd( + state.cfg.callback_dlq_stream.as_str(), + false, + None::<()>, + "*", + vec![ + ("job_id", Value::String(job_id.to_string().into())), + ("callback_url", Value::String(callback_url.to_string().into())), + ("body", Value::String(body.into())), + ("attempts", Value::Integer(attempts as i64)), + ("final_error", Value::String(final_error.to_string().into())), + ("failed_at", Value::Integer(now_ms() as i64)), + ], + ) + .await?; + trim_stream(state, &state.cfg.callback_dlq_stream).await?; + state.metrics.callback_retry_dlq_total.fetch_add(1, Ordering::Relaxed); + Ok(()) +} + +fn callback_retry_delay_ms(initial_delay_ms: u64, max_delay_ms: u64, attempt: u32) -> u64 { + let exponent = attempt.saturating_sub(1).min(16); + let multiplier = 1u64.checked_shl(exponent).unwrap_or(u64::MAX); + initial_delay_ms.saturating_mul(multiplier).min(max_delay_ms) +} diff --git a/binary/ai-gateway-service/src/app/config.rs b/binary/ai-gateway-service/src/app/config.rs new file mode 100644 index 00000000..b602ee8e --- /dev/null +++ b/binary/ai-gateway-service/src/app/config.rs @@ -0,0 +1,492 @@ +use std::path::{Path as ConfigPath, PathBuf}; + +use clap::{parser::ValueSource, ArgMatches, CommandFactory, FromArgMatches}; + +/// 默认可执行文件同目录下的配置文件名。 +const DEFAULT_CONFIG_FILE_NAME: &str = "ai-gateway-service.toml"; + +/// CLI 包装层:配置文件路径 + 原有 Args。 +#[derive(Debug, Parser)] +#[command(version, about = "External Redis-backed rate-limit and queue service for SpaceGate AI gateway")] +struct Cli { + /// TOML 配置文件路径;未指定时尝试读取可执行文件同目录下的 ai-gateway-service.toml。 + #[arg(long, env = "AI_GATEWAY_CONFIG", value_name = "FILE")] + config: Option, + #[command(flatten)] + args: Args, +} + +/// 解析最终使用的配置文件路径:显式参数 > 可执行文件同目录默认文件。 +fn resolve_config_path(explicit: Option) -> Option { + if let Some(path) = explicit { + return Some(path); + } + default_config_path_beside_executable() +} + +/// 可执行文件所在目录下的默认配置文件(存在才返回)。 +fn default_config_path_beside_executable() -> Option { + std::env::current_exe().ok().and_then(|exe| default_config_path_in_dir(&exe)) +} + +/// 给定可执行文件路径,返回同目录下默认配置文件路径(存在才返回)。 +fn default_config_path_in_dir(exe_path: &ConfigPath) -> Option { + let dir = exe_path.parent()?; + let path = dir.join(DEFAULT_CONFIG_FILE_NAME); + path.is_file().then_some(path) +} + +/// 从 CLI、环境变量和可选 TOML 配置文件合并出最终运行参数。 +fn load_args() -> Result> { + let matches = Cli::command().get_matches(); + let explicit_config = matches.get_one::("config").cloned(); + let config_path = resolve_config_path(explicit_config); + let cli = Cli::from_arg_matches(&matches).expect("cli args"); + + let file_args = match config_path.as_deref() { + Some(path) => { + tracing::info!(path = %path.display(), "loading config file"); + Some(ServiceConfigFile::load(path)?.into_args()) + } + None => None, + }; + + Ok(merge_args(file_args, cli.args, &matches)) +} + +/// TOML 配置文件根结构;各 section 均可选,便于按需扩展。 +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct ServiceConfigFile { + server: ServerSection, + redis: RedisSection, + upstream: UpstreamSection, + queue: QueueSection, + rate_limit: RateLimitSection, + worker: WorkerSection, + callback: CallbackSection, + result: ResultSection, + body: BodySection, + object_store: ObjectStoreSection, + admin: AdminSection, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct ServerSection { + host: Option, + port: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct RedisSection { + /// Redis 连接 URL,例如 redis://127.0.0.1/ 或 redis://:password@host:6379/0 + url: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct UpstreamSection { + /// 上游 AI 服务地址;未配置时只入队,不启动 worker。 + base_url: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct QueueSection { + stream: Option, + high_stream: Option, + low_stream: Option, + enable_priority_streams: Option, + default_priority: Option, + high_models: Option>, + low_models: Option>, + high_tenants: Option>, + low_tenants: Option>, + high_weight: Option, + normal_weight: Option, + low_weight: Option, + max_len: Option, + group: Option, + consumer: Option, + job_dlq_stream: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct RateLimitSection { + rps: Option, + burst: Option, + cost: Option, + tenant_prefix: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct WorkerSection { + concurrency: Option, + wait_timeout_secs: Option, + reclaim_interval_secs: Option, + reclaim_min_idle_secs: Option, + job_process_lease_secs: Option, + job_max_delivery_attempts: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct CallbackSection { + require_https: Option, + max_retry_attempts: Option, + retry_initial_delay_ms: Option, + retry_max_delay_ms: Option, + retry_reclaim_idle_secs: Option, + retry_stream: Option, + retry_group: Option, + dlq_stream: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct ResultSection { + key_prefix: Option, + channel_prefix: Option, + ttl_secs: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct BodySection { + max_bytes: Option, + inline_threshold: Option, + read_concurrency: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct ObjectStoreSection { + endpoint: Option, + bucket: Option, + prefix: Option, + multipart_part_size: Option, + auth_header: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct AdminSection { + cors_origins: Option>, +} + +impl ServiceConfigFile { + fn load(path: &ConfigPath) -> Result> { + let raw = std::fs::read_to_string(path).map_err(|e| format!("read config `{}`: {e}", path.display()))?; + let cfg: Self = toml::from_str(&raw).map_err(|e| format!("parse config `{}`: {e}", path.display()))?; + Ok(cfg) + } + + fn into_args(self) -> Args { + let mut args = Args::default(); + if let Some(host) = self.server.host { + args.host = host.parse().unwrap_or(args.host); + } + if let Some(port) = self.server.port { + args.port = port; + } + if let Some(url) = self.redis.url { + args.redis_url = url; + } + args.upstream_base_url = self.upstream.base_url; + + if let Some(stream) = self.queue.stream { + args.stream_key = stream; + } + if let Some(stream) = self.queue.high_stream { + args.high_priority_stream_key = stream; + } + if let Some(stream) = self.queue.low_stream { + args.low_priority_stream_key = stream; + } + if let Some(value) = self.queue.enable_priority_streams { + args.enable_priority_streams = value; + } + if let Some(value) = self.queue.default_priority { + args.queue_default_priority = value; + } + if let Some(values) = self.queue.high_models { + args.queue_high_models = join_csv(&values); + } + if let Some(values) = self.queue.low_models { + args.queue_low_models = join_csv(&values); + } + if let Some(values) = self.queue.high_tenants { + args.queue_high_tenants = join_csv(&values); + } + if let Some(values) = self.queue.low_tenants { + args.queue_low_tenants = join_csv(&values); + } + if let Some(value) = self.queue.high_weight { + args.queue_high_weight = value; + } + if let Some(value) = self.queue.normal_weight { + args.queue_normal_weight = value; + } + if let Some(value) = self.queue.low_weight { + args.queue_low_weight = value; + } + if let Some(value) = self.queue.max_len { + args.stream_max_len = value; + } + if let Some(value) = self.queue.group { + args.consumer_group = value; + } + if let Some(value) = self.queue.consumer { + args.consumer_name = value; + } + if let Some(value) = self.queue.job_dlq_stream { + args.job_dlq_stream = value; + } + + if let Some(value) = self.rate_limit.rps { + args.rate_limit_rps = value; + } + if let Some(value) = self.rate_limit.burst { + args.rate_limit_burst = value; + } + if let Some(value) = self.rate_limit.cost { + args.rate_limit_cost = value; + } + if let Some(value) = self.rate_limit.tenant_prefix { + args.tenant_rate_limit_prefix = value; + } + + if let Some(value) = self.worker.concurrency { + args.worker_concurrency = value; + } + if let Some(value) = self.worker.wait_timeout_secs { + args.wait_timeout_secs = value; + } + if let Some(value) = self.worker.reclaim_interval_secs { + args.reclaim_interval_secs = value; + } + if let Some(value) = self.worker.reclaim_min_idle_secs { + args.reclaim_min_idle_secs = value; + } + if let Some(value) = self.worker.job_process_lease_secs { + args.job_process_lease_secs = value; + } + if let Some(value) = self.worker.job_max_delivery_attempts { + args.job_max_delivery_attempts = value; + } + + if let Some(value) = self.callback.require_https { + args.require_https_callback = value; + } + if let Some(value) = self.callback.max_retry_attempts { + args.callback_max_retry_attempts = value; + } + if let Some(value) = self.callback.retry_initial_delay_ms { + args.callback_retry_initial_delay_ms = value; + } + if let Some(value) = self.callback.retry_max_delay_ms { + args.callback_retry_max_delay_ms = value; + } + if let Some(value) = self.callback.retry_reclaim_idle_secs { + args.callback_retry_reclaim_idle_secs = value; + } + if let Some(value) = self.callback.retry_stream { + args.callback_retry_stream = value; + } + if let Some(value) = self.callback.retry_group { + args.callback_retry_group = value; + } + if let Some(value) = self.callback.dlq_stream { + args.callback_dlq_stream = value; + } + + if let Some(value) = self.result.key_prefix { + args.result_key_prefix = value; + } + if let Some(value) = self.result.channel_prefix { + args.result_channel_prefix = value; + } + if let Some(value) = self.result.ttl_secs { + args.result_ttl_secs = value; + } + + if let Some(value) = self.body.max_bytes { + args.max_body_bytes = value; + } + if let Some(value) = self.body.inline_threshold { + args.inline_threshold = value; + } + if let Some(value) = self.body.read_concurrency { + args.body_read_concurrency = value; + } + + args.object_store_endpoint = self.object_store.endpoint; + if let Some(value) = self.object_store.bucket { + args.object_store_bucket = value; + } + if let Some(value) = self.object_store.prefix { + args.object_store_prefix = value; + } + if let Some(value) = self.object_store.multipart_part_size { + args.object_multipart_part_size = value; + } + args.object_store_auth_header = self.object_store.auth_header; + + if let Some(values) = self.admin.cors_origins { + args.admin_cors_origins = join_csv(&values); + } + + args + } +} + +/// 合并优先级:显式 CLI / 环境变量 > 配置文件 > 内置默认值。 +fn merge_args(file_args: Option, cli_args: Args, matches: &ArgMatches) -> Args { + let file = file_args.unwrap_or_else(Args::default); + let mut out = file; + + macro_rules! pick { + ($field:ident, $id:expr) => { + if is_explicit(matches, $id) { + out.$field = cli_args.$field; + } + }; + ($field:ident, $id:expr, clone) => { + if is_explicit(matches, $id) { + out.$field = cli_args.$field.clone(); + } + }; + } + + pick!(host, "host"); + pick!(port, "port"); + pick!(redis_url, "redis_url", clone); + pick!(stream_key, "stream_key", clone); + pick!(high_priority_stream_key, "high_priority_stream_key", clone); + pick!(low_priority_stream_key, "low_priority_stream_key", clone); + pick!(enable_priority_streams, "enable_priority_streams"); + pick!(queue_default_priority, "queue_default_priority", clone); + pick!(queue_high_models, "queue_high_models", clone); + pick!(queue_low_models, "queue_low_models", clone); + pick!(queue_high_tenants, "queue_high_tenants", clone); + pick!(queue_low_tenants, "queue_low_tenants", clone); + pick!(queue_high_weight, "queue_high_weight"); + pick!(queue_normal_weight, "queue_normal_weight"); + pick!(queue_low_weight, "queue_low_weight"); + pick!(stream_max_len, "stream_max_len"); + pick!(consumer_group, "consumer_group", clone); + pick!(consumer_name, "consumer_name", clone); + pick!(job_dlq_stream, "job_dlq_stream", clone); + pick!(callback_retry_stream, "callback_retry_stream", clone); + pick!(callback_retry_group, "callback_retry_group", clone); + pick!(callback_dlq_stream, "callback_dlq_stream", clone); + pick!(callback_max_retry_attempts, "callback_max_retry_attempts"); + pick!(callback_retry_initial_delay_ms, "callback_retry_initial_delay_ms"); + pick!(callback_retry_max_delay_ms, "callback_retry_max_delay_ms"); + pick!(callback_retry_reclaim_idle_secs, "callback_retry_reclaim_idle_secs"); + pick!(result_key_prefix, "result_key_prefix", clone); + pick!(result_channel_prefix, "result_channel_prefix", clone); + pick!(result_ttl_secs, "result_ttl_secs"); + pick!(rate_limit_rps, "rate_limit_rps"); + pick!(rate_limit_burst, "rate_limit_burst"); + pick!(rate_limit_cost, "rate_limit_cost"); + pick!(tenant_rate_limit_prefix, "tenant_rate_limit_prefix", clone); + pick!(wait_timeout_secs, "wait_timeout_secs"); + pick!(worker_concurrency, "worker_concurrency"); + pick!(admin_cors_origins, "admin_cors_origins", clone); + pick!(max_body_bytes, "max_body_bytes"); + pick!(inline_threshold, "inline_threshold"); + pick!(body_read_concurrency, "body_read_concurrency"); + pick!(reclaim_interval_secs, "reclaim_interval_secs"); + pick!(reclaim_min_idle_secs, "reclaim_min_idle_secs"); + pick!(job_process_lease_secs, "job_process_lease_secs"); + pick!(job_max_delivery_attempts, "job_max_delivery_attempts"); + pick!(require_https_callback, "require_https_callback"); + pick!(object_store_bucket, "object_store_bucket", clone); + pick!(object_store_prefix, "object_store_prefix", clone); + pick!(object_multipart_part_size, "object_multipart_part_size"); + + if is_explicit(matches, "upstream_base_url") { + out.upstream_base_url = cli_args.upstream_base_url.clone(); + } + if is_explicit(matches, "object_store_endpoint") { + out.object_store_endpoint = cli_args.object_store_endpoint.clone(); + } + if is_explicit(matches, "object_store_auth_header") { + out.object_store_auth_header = cli_args.object_store_auth_header.clone(); + } + + out +} + +fn is_explicit(matches: &ArgMatches, id: &str) -> bool { + matches + .value_source(id) + .is_some_and(|source| matches!(source, ValueSource::CommandLine | ValueSource::EnvVariable)) +} + +fn join_csv(values: &[String]) -> String { + values.iter().map(String::as_str).collect::>().join(",") +} + +#[cfg(test)] +mod config_tests { + use super::*; + use std::io::Write; + + #[test] + fn loads_redis_and_upstream_from_toml() { + let mut file = tempfile::NamedTempFile::new().expect("temp file"); + write!( + file, + r#" +[redis] +url = "redis://redis.example:6379/0" + +[upstream] +base_url = "http://upstream.example:9000" + +[server] +port = 19080 +"# + ) + .expect("write temp config"); + + let cfg = ServiceConfigFile::load(file.path()).expect("load config"); + let args = cfg.into_args(); + assert_eq!(args.redis_url, "redis://redis.example:6379/0"); + assert_eq!(args.upstream_base_url.as_deref(), Some("http://upstream.example:9000")); + assert_eq!(args.port, 19080); + } + + #[test] + fn resolve_config_path_prefers_explicit() { + let explicit = PathBuf::from("/tmp/custom.toml"); + assert_eq!(resolve_config_path(Some(explicit.clone())), Some(explicit)); + } + + #[test] + fn default_config_path_in_dir_finds_sibling_file() { + let dir = tempfile::tempdir().expect("temp dir"); + let config = dir.path().join(DEFAULT_CONFIG_FILE_NAME); + std::fs::write(&config, "[redis]\nurl = \"redis://127.0.0.1/\"").expect("write config"); + + let fake_exe = dir.path().join("ai-gateway-service"); + std::fs::write(&fake_exe, b"").expect("write fake exe"); + + assert_eq!(default_config_path_in_dir(&fake_exe), Some(config)); + } + + #[test] + fn default_config_path_in_dir_returns_none_when_missing() { + let dir = tempfile::tempdir().expect("temp dir"); + let fake_exe = dir.path().join("ai-gateway-service"); + std::fs::write(&fake_exe, b"").expect("write fake exe"); + + assert_eq!(default_config_path_in_dir(&fake_exe), None); + } +} diff --git a/binary/ai-gateway-service/src/app/handlers.rs b/binary/ai-gateway-service/src/app/handlers.rs new file mode 100644 index 00000000..8aaa62f0 --- /dev/null +++ b/binary/ai-gateway-service/src/app/handlers.rs @@ -0,0 +1,240 @@ +async fn check_rate_limit(State(state): State, headers: HeaderMap, uri: Uri) -> Result, ServiceError> { + let tenant = required_header(&headers, "x-tenant-id")?; + let model = optional_header(&headers, "x-model").unwrap_or_else(|| "default".to_string()); + let path = optional_header(&headers, "x-original-path").unwrap_or_else(|| uri.path().to_string()); + let policy = optional_header(&headers, "x-ratelimit-policy").unwrap_or_else(|| "abandon".to_string()); + let rate_limit = resolve_rate_limit(&state, &tenant, &model, &path, &policy).await?; + let (tokens_key, ts_key) = tenant_rate_limit_keys(&tenant); + let now = now_ms(); + + let out: Vec = state + .redis + .eval( + TOKEN_BUCKET_LUA, + vec![tokens_key, ts_key], + vec![rate_limit.rps.to_string(), rate_limit.burst.to_string(), now.to_string(), rate_limit.cost.to_string()], + ) + .await?; + + let allowed = out.first().copied().unwrap_or(0) == 1; + if !allowed { + record_rate_limited(&state.metrics, &policy, &tenant); + } + Ok(Json(RateLimitResponse { + allowed, + remaining_tokens_milli: out.get(1).copied().unwrap_or(0), + retry_after_ms: out.get(2).copied().unwrap_or(0), + })) +} + +async fn enqueue(State(state): State, method: Method, uri: Uri, headers: HeaderMap, body: Body) -> Result { + let accepted = enqueue_job(&state, QueuePolicy::Queue, method, uri, headers, body).await?; + let mut resp = (StatusCode::ACCEPTED, Json(&accepted.response)).into_response(); + resp.headers_mut().insert("x-job-id", header_value(&accepted.response.job_id)?); + resp.headers_mut().insert("location", header_value(&accepted.response.poll_url)?); + Ok(resp) +} + +async fn enqueue_and_wait(State(state): State, method: Method, uri: Uri, headers: HeaderMap, body: Body) -> Result { + let timeout_secs = optional_header(&headers, "x-request-timeout").and_then(|v| v.parse::().ok()).unwrap_or(state.cfg.wait_timeout_secs); + state.metrics.wait_total.fetch_add(1, Ordering::Relaxed); + let accepted = enqueue_job(&state, QueuePolicy::Wait, method, uri, headers, body).await?; + let channel = result_channel(&state, &accepted.response.job_id); + + if let Some(result) = load_result(&state, &accepted.response.job_id).await? { + return Ok(result_to_response(result, accepted.created_at_ms)?); + } + + let wait = state + .wait_subscriber + .wait_for_channel(channel.as_str(), Duration::from_secs(timeout_secs)) + .await; + + match wait { + Ok(()) => { + if let Some(result) = load_result(&state, &accepted.response.job_id).await? { + Ok(result_to_response(result, accepted.created_at_ms)?) + } else { + Err(ServiceError::gateway_timeout(format!( + "job {} completed notification received but result is missing", + accepted.response.job_id + ))) + } + } + Err(e) if e.status == StatusCode::GATEWAY_TIMEOUT => { + state.metrics.wait_timeout_total.fetch_add(1, Ordering::Relaxed); + let waited_ms = now_ms().saturating_sub(accepted.created_at_ms); + let body = Json(serde_json::json!({ + "error": "timeout", + "job_id": accepted.response.job_id, + "poll_url": accepted.response.poll_url, + "waited_ms": waited_ms, + "message": "Job is still processing. Switch to queue mode with a callback for long tasks." + })); + Ok((StatusCode::GATEWAY_TIMEOUT, body).into_response()) + } + Err(e) => Err(e), + } +} + +async fn get_job(State(state): State, Path(job_id): Path) -> Result { + match load_result(&state, &job_id).await? { + Some(result) if result.status == "completed" => poll_result_to_response(result), + Some(result) => Ok(Json(serde_json::json!({ + "job_id": result.job_id, + "status": result.status, + "http_status": result.http_status, + "error": result.error, + "completed_at": format_completed_at_rfc3339(result.completed_at_ms), + })) + .into_response()), + None => Ok((StatusCode::NOT_FOUND, Json(serde_json::json!({ "error": "not_found", "job_id": job_id }))).into_response()), + } +} + +async fn healthz() -> &'static str { + "ok" +} + +async fn metrics(State(state): State) -> Result { + let queue_depth: i64 = state.redis.xlen(state.cfg.stream_key.as_str()).await.unwrap_or_default(); + let high_queue_depth: i64 = state.redis.xlen(state.cfg.high_priority_stream_key.as_str()).await.unwrap_or_default(); + let low_queue_depth: i64 = state.redis.xlen(state.cfg.low_priority_stream_key.as_str()).await.unwrap_or_default(); + let job_dlq_depth: i64 = state.redis.xlen(state.cfg.job_dlq_stream.as_str()).await.unwrap_or_default(); + let callback_retry_depth: i64 = state.redis.xlen(state.cfg.callback_retry_stream.as_str()).await.unwrap_or_default(); + let callback_dlq_depth: i64 = state.redis.xlen(state.cfg.callback_dlq_stream.as_str()).await.unwrap_or_default(); + let pel_size = pending_size(&state, &state.cfg.stream_key).await; + let high_pel_size = pending_size(&state, &state.cfg.high_priority_stream_key).await; + let low_pel_size = pending_size(&state, &state.cfg.low_priority_stream_key).await; + let callback_retry_pel_size = pending_size_for_group(&state, &state.cfg.callback_retry_stream, &state.cfg.callback_retry_group).await; + + let wait_total = state.metrics.wait_total.load(Ordering::Relaxed); + let wait_timeout_total = state.metrics.wait_timeout_total.load(Ordering::Relaxed); + let callback_failure_total = state.metrics.callback_failure_total.load(Ordering::Relaxed); + let worker_completed_total = state.metrics.worker_completed_total.load(Ordering::Relaxed); + let wait_timeout_rate = if wait_total > 0 { + wait_timeout_total as f64 / wait_total as f64 + } else { + 0.0 + }; + let callback_failure_rate = if worker_completed_total > 0 { + callback_failure_total as f64 / worker_completed_total as f64 + } else { + 0.0 + }; + let labeled_lines = format_labeled_lines(&state.metrics); + + let body = format!( + "\ +rate_limited_total {}\n\ +enqueue_total {}\n\ +enqueue_total{{policy=\"queue\"}} {}\n\ +enqueue_total{{policy=\"wait\"}} {}\n\ +enqueue_total{{priority=\"high\"}} {}\n\ +enqueue_total{{priority=\"normal\"}} {}\n\ +enqueue_total{{priority=\"low\"}} {}\n\ +enqueue_latency_ms_count {}\n\ +enqueue_latency_ms_sum {}\n\ +enqueue_latency_ms_bucket{{le=\"100\"}} {}\n\ +enqueue_latency_ms_bucket{{le=\"500\"}} {}\n\ +enqueue_latency_ms_bucket{{le=\"1000\"}} {}\n\ +enqueue_latency_ms_bucket{{le=\"+Inf\"}} {}\n\ +enqueue_body_size_bytes_count {}\n\ +enqueue_body_size_bytes_sum {}\n\ +enqueue_body_size_bytes_bucket{{le=\"10240\"}} {}\n\ +enqueue_body_size_bytes_bucket{{le=\"131072\"}} {}\n\ +enqueue_body_size_bytes_bucket{{le=\"5242880\"}} {}\n\ +enqueue_body_size_bytes_bucket{{le=\"+Inf\"}} {}\n\ +wait_total {}\n\ +wait_timeout_total {}\n\ +wait_timeout_rate {:.6}\n\ +callback_failure_total {}\n\ +callback_failure_rate {:.6}\n\ +callback_retry_total {}\n\ +callback_retry_success_total {}\n\ +callback_retry_dlq_total {}\n\ +worker_completed_total {}\n\ +worker_failed_total {}\n\ +worker_processing_time_ms_count {}\n\ +worker_processing_time_ms_sum {}\n\ +worker_processing_time_ms_bucket{{le=\"1000\"}} {}\n\ +worker_processing_time_ms_bucket{{le=\"5000\"}} {}\n\ +worker_processing_time_ms_bucket{{le=\"30000\"}} {}\n\ +worker_processing_time_ms_bucket{{le=\"+Inf\"}} {}\n\ +reclaimed_total {}\n\ +job_dlq_total {}\n\ +lease_skip_total {}\n\ +object_offload_total {}\n\ +object_multipart_abort_total {}\n\ +queue_depth {}\n\ +queue_depth{{priority=\"normal\"}} {}\n\ +queue_depth{{priority=\"high\"}} {}\n\ +queue_depth{{priority=\"low\"}} {}\n\ +pel_size {}\n\ +pel_size{{priority=\"normal\"}} {}\n\ +pel_size{{priority=\"high\"}} {}\n\ +pel_size{{priority=\"low\"}} {}\n\ +job_dlq_depth {}\n\ +callback_retry_depth {}\n\ +callback_retry_pel_size {}\n\ +callback_dlq_depth {}\n\ +{labeled_lines}\n", + state.metrics.rate_limited_total.load(Ordering::Relaxed), + state.metrics.enqueue_total.load(Ordering::Relaxed), + state.metrics.enqueue_queue_total.load(Ordering::Relaxed), + state.metrics.enqueue_wait_total.load(Ordering::Relaxed), + state.metrics.enqueue_priority_high_total.load(Ordering::Relaxed), + state.metrics.enqueue_priority_normal_total.load(Ordering::Relaxed), + state.metrics.enqueue_priority_low_total.load(Ordering::Relaxed), + state.metrics.enqueue_latency_count.load(Ordering::Relaxed), + state.metrics.enqueue_latency_sum_ms.load(Ordering::Relaxed), + state.metrics.enqueue_latency_le_100_ms.load(Ordering::Relaxed), + state.metrics.enqueue_latency_le_100_ms.load(Ordering::Relaxed) + state.metrics.enqueue_latency_le_500_ms.load(Ordering::Relaxed), + state.metrics.enqueue_latency_le_100_ms.load(Ordering::Relaxed) + + state.metrics.enqueue_latency_le_500_ms.load(Ordering::Relaxed) + + state.metrics.enqueue_latency_le_1000_ms.load(Ordering::Relaxed), + state.metrics.enqueue_latency_count.load(Ordering::Relaxed), + state.metrics.body_size_count.load(Ordering::Relaxed), + state.metrics.body_size_sum_bytes.load(Ordering::Relaxed), + state.metrics.body_size_le_10kb.load(Ordering::Relaxed), + state.metrics.body_size_le_10kb.load(Ordering::Relaxed) + state.metrics.body_size_le_128kb.load(Ordering::Relaxed), + state.metrics.body_size_le_10kb.load(Ordering::Relaxed) + state.metrics.body_size_le_128kb.load(Ordering::Relaxed) + state.metrics.body_size_le_5mb.load(Ordering::Relaxed), + state.metrics.body_size_count.load(Ordering::Relaxed), + state.metrics.wait_total.load(Ordering::Relaxed), + state.metrics.wait_timeout_total.load(Ordering::Relaxed), + wait_timeout_rate, + state.metrics.callback_failure_total.load(Ordering::Relaxed), + callback_failure_rate, + state.metrics.callback_retry_total.load(Ordering::Relaxed), + state.metrics.callback_retry_success_total.load(Ordering::Relaxed), + state.metrics.callback_retry_dlq_total.load(Ordering::Relaxed), + state.metrics.worker_completed_total.load(Ordering::Relaxed), + state.metrics.worker_failed_total.load(Ordering::Relaxed), + state.metrics.worker_processing_count.load(Ordering::Relaxed), + state.metrics.worker_processing_sum_ms.load(Ordering::Relaxed), + state.metrics.worker_processing_le_1000_ms.load(Ordering::Relaxed), + state.metrics.worker_processing_le_1000_ms.load(Ordering::Relaxed) + state.metrics.worker_processing_le_5000_ms.load(Ordering::Relaxed), + state.metrics.worker_processing_le_1000_ms.load(Ordering::Relaxed) + + state.metrics.worker_processing_le_5000_ms.load(Ordering::Relaxed) + + state.metrics.worker_processing_le_30000_ms.load(Ordering::Relaxed), + state.metrics.worker_processing_count.load(Ordering::Relaxed), + state.metrics.reclaimed_total.load(Ordering::Relaxed), + state.metrics.job_dlq_total.load(Ordering::Relaxed), + state.metrics.lease_skip_total.load(Ordering::Relaxed), + state.metrics.object_offload_total.load(Ordering::Relaxed), + state.metrics.object_multipart_abort_total.load(Ordering::Relaxed), + queue_depth, + queue_depth, + high_queue_depth, + low_queue_depth, + pel_size, + pel_size, + high_pel_size, + low_pel_size, + job_dlq_depth, + callback_retry_depth, + callback_retry_pel_size, + callback_dlq_depth, + ); + Ok((StatusCode::OK, [("content-type", "text/plain; version=0.0.4")], body).into_response()) +} \ No newline at end of file diff --git a/binary/ai-gateway-service/src/app/metrics.rs b/binary/ai-gateway-service/src/app/metrics.rs new file mode 100644 index 00000000..18ed9f12 --- /dev/null +++ b/binary/ai-gateway-service/src/app/metrics.rs @@ -0,0 +1,116 @@ +async fn trim_stream(state: &AppState, stream: &str) -> Result<(), ServiceError> { + if state.cfg.stream_max_len > 0 { + let _: i64 = state.redis.xtrim(stream, ("MAXLEN", "~", state.cfg.stream_max_len as i64)).await?; + } + Ok(()) +} + +async fn pending_size(state: &AppState, stream: &str) -> i64 { + pending_size_for_group(state, stream, state.cfg.consumer_group.as_str()).await +} + +async fn pending_size_for_group(state: &AppState, stream: &str, group: &str) -> i64 { + let raw: FredResult = state.redis.xpending(stream, group, ()).await; + match raw { + Ok(value) => pending_count_from_value(&value), + Err(e) => { + tracing::debug!(stream = %stream, group = %group, error = %e, "read stream pending size failed"); + 0 + } + } +} + +fn pending_count_from_value(value: &Value) -> i64 { + match value { + Value::Integer(value) => (*value).max(0), + Value::String(value) => value.parse::().unwrap_or(0).max(0), + Value::Bytes(value) => std::str::from_utf8(value).ok().and_then(|value| value.parse::().ok()).unwrap_or(0).max(0), + Value::Array(values) => values.first().map(pending_count_from_value).unwrap_or(0), + Value::Map(values) => values + .iter() + .find_map(|(key, value)| { + let key = key.as_str()?; + if key.eq_ignore_ascii_case("pending") || key.eq_ignore_ascii_case("count") { + Some(pending_count_from_value(value)) + } else { + None + } + }) + .unwrap_or(0), + _ => 0, + } +} + +fn inc_labeled(metrics: &Metrics, key: impl Into) { + let mut map = metrics.labeled.lock().unwrap_or_else(|error| error.into_inner()); + *map.entry(key.into()).or_insert(0) += 1; +} + +fn format_labeled_lines(metrics: &Metrics) -> String { + let map = metrics.labeled.lock().unwrap_or_else(|error| error.into_inner()); + let mut keys: Vec<_> = map.keys().cloned().collect(); + keys.sort(); + keys.into_iter() + .filter_map(|key| map.get(&key).copied().map(|value| format!("{key} {value}"))) + .collect::>() + .join("\n") +} + +fn observe_enqueue_latency(metrics: &Metrics, elapsed_ms: u64, policy: &str, size_bucket: &str) { + metrics.enqueue_latency_count.fetch_add(1, Ordering::Relaxed); + metrics.enqueue_latency_sum_ms.fetch_add(elapsed_ms, Ordering::Relaxed); + let le = if elapsed_ms <= 100 { + metrics.enqueue_latency_le_100_ms.fetch_add(1, Ordering::Relaxed); + "100" + } else if elapsed_ms <= 500 { + metrics.enqueue_latency_le_500_ms.fetch_add(1, Ordering::Relaxed); + "500" + } else if elapsed_ms <= 1000 { + metrics.enqueue_latency_le_1000_ms.fetch_add(1, Ordering::Relaxed); + "1000" + } else { + metrics.enqueue_latency_gt_1000_ms.fetch_add(1, Ordering::Relaxed); + "+Inf" + }; + inc_labeled( + metrics, + format!(r#"enqueue_latency_ms_bucket{{policy="{policy}",size_bucket="{size_bucket}",le="{le}"}}"#), + ); +} + +fn observe_body_size(metrics: &Metrics, size: usize) { + metrics.body_size_count.fetch_add(1, Ordering::Relaxed); + metrics.body_size_sum_bytes.fetch_add(size as u64, Ordering::Relaxed); + if size <= 10 * 1024 { + metrics.body_size_le_10kb.fetch_add(1, Ordering::Relaxed); + } else if size <= 128 * 1024 { + metrics.body_size_le_128kb.fetch_add(1, Ordering::Relaxed); + } else if size <= 5 * 1024 * 1024 { + metrics.body_size_le_5mb.fetch_add(1, Ordering::Relaxed); + } else { + metrics.body_size_gt_5mb.fetch_add(1, Ordering::Relaxed); + } +} + +fn observe_worker_processing(metrics: &Metrics, elapsed_ms: u64, model: &str) { + metrics.worker_processing_count.fetch_add(1, Ordering::Relaxed); + metrics.worker_processing_sum_ms.fetch_add(elapsed_ms, Ordering::Relaxed); + let model = metrics_label(model); + let le = if elapsed_ms <= 1000 { + metrics.worker_processing_le_1000_ms.fetch_add(1, Ordering::Relaxed); + "1000" + } else if elapsed_ms <= 5000 { + metrics.worker_processing_le_5000_ms.fetch_add(1, Ordering::Relaxed); + "5000" + } else if elapsed_ms <= 30_000 { + metrics.worker_processing_le_30000_ms.fetch_add(1, Ordering::Relaxed); + "30000" + } else { + metrics.worker_processing_gt_30000_ms.fetch_add(1, Ordering::Relaxed); + "+Inf" + }; + inc_labeled( + metrics, + format!(r#"worker_processing_time_ms_bucket{{model="{model}",le="{le}"}}"#), + ); +} diff --git a/binary/ai-gateway-service/src/app/object_store.rs b/binary/ai-gateway-service/src/app/object_store.rs new file mode 100644 index 00000000..f5969f06 --- /dev/null +++ b/binary/ai-gateway-service/src/app/object_store.rs @@ -0,0 +1,240 @@ +async fn store_body(state: &AppState, job_id: &str, body: Body) -> Result { + let object_ref = format!("{}/{}/body.bin", state.cfg.object_store_prefix.trim_matches('/'), sanitize_key(job_id)); + let mut stream = body.into_data_stream(); + let mut pending = Vec::new(); + let mut total_size = 0usize; + let part_size = state.cfg.object_multipart_part_size.max(5 * 1024 * 1024); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| ServiceError::bad_request(format!("read request body: {e}")))?; + total_size = total_size.checked_add(chunk.len()).ok_or_else(|| ServiceError::payload_too_large("request body is too large"))?; + if total_size > state.cfg.max_body_bytes { + return Err(ServiceError::payload_too_large(format!("request body exceeds max size {}", state.cfg.max_body_bytes))); + } + + if state.cfg.object_store_endpoint.is_none() && pending.len() + chunk.len() > state.cfg.inline_threshold { + return Err(ServiceError::payload_too_large(format!( + "request body exceeds inline threshold {} and object store is not configured", + state.cfg.inline_threshold + ))); + } + + if state.cfg.object_store_endpoint.is_some() && pending.len() + chunk.len() > state.cfg.inline_threshold { + pending.extend_from_slice(&chunk); + let upload_id = initiate_multipart_upload(state, &object_ref).await?; + let state = state.clone(); + let object_ref_for_task = object_ref.clone(); + let handle = tokio::spawn(async move { + finish_offload_upload(state, object_ref_for_task, upload_id, pending, stream, part_size, total_size).await + }); + return Ok(BodyStoreOutcome { + location: BodyLocation { + body_base64: String::new(), + object_ref, + size: total_size, + storage: "object", + }, + pending_upload: Some(handle), + }); + } + + pending.extend_from_slice(&chunk); + } + + Ok(BodyStoreOutcome { + location: BodyLocation { + body_base64: base64::engine::general_purpose::STANDARD.encode(&pending), + object_ref: String::new(), + size: total_size, + storage: "inline", + }, + pending_upload: None, + }) +} + +async fn finish_offload_upload( + state: AppState, + object_ref: String, + upload_id: String, + mut pending: Vec, + mut stream: impl futures_util::Stream> + Unpin, + part_size: usize, + mut total_size: usize, +) -> Result<(), ServiceError> { + let mut parts = Vec::new(); + let upload_result = async { + while pending.len() >= part_size { + let part_body = pending.drain(..part_size).collect::>(); + parts.push(upload_multipart_part(&state, &object_ref, &upload_id, parts.len() + 1, part_body).await?); + } + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| ServiceError::bad_request(format!("read request body: {e}")))?; + total_size = total_size.checked_add(chunk.len()).ok_or_else(|| ServiceError::payload_too_large("request body is too large"))?; + if total_size > state.cfg.max_body_bytes { + abort_upload_if_needed(&state, &object_ref, Some(&upload_id)).await; + return Err(ServiceError::payload_too_large(format!("request body exceeds max size {}", state.cfg.max_body_bytes))); + } + pending.extend_from_slice(&chunk); + while pending.len() >= part_size { + let part_body = pending.drain(..part_size).collect::>(); + parts.push(upload_multipart_part(&state, &object_ref, &upload_id, parts.len() + 1, part_body).await?); + } + } + if !pending.is_empty() || parts.is_empty() { + parts.push(upload_multipart_part(&state, &object_ref, &upload_id, parts.len() + 1, pending).await?); + } + complete_multipart_upload(&state, &object_ref, &upload_id, &parts).await?; + state.metrics.object_offload_total.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + .await; + + if upload_result.is_err() { + abort_upload_if_needed(&state, &object_ref, Some(&upload_id)).await; + } + upload_result +} + +async fn load_body(state: &AppState, fields: &HashMap) -> Result, ServiceError> { + let storage = field_string(fields, "storage").unwrap_or_else(|| "inline".to_string()); + if storage == "object" { + let object_ref = field_string(fields, "ref").ok_or_else(|| ServiceError::bad_request("job body is missing object ref"))?; + let url = object_url(state, &object_ref); + let mut req = state.http.get(url); + if let Some((name, value)) = object_auth_header(&state.cfg.object_store_auth_header)? { + req = req.header(name, value); + } + return Ok(req.send().await?.error_for_status()?.bytes().await?.to_vec()); + } + + if let Some(body_base64) = field_string(fields, "body") { + return base64::engine::general_purpose::STANDARD.decode(body_base64).map_err(|e| ServiceError::bad_request(format!("decode job body: {e}"))); + } + Ok(field_bytes(fields, "body").unwrap_or_default()) +} + +async fn initiate_multipart_upload(state: &AppState, object_ref: &str) -> Result { + let url = object_url_with_query(state, object_ref, "uploads"); + let mut req = state.http.post(url); + if let Some((name, value)) = object_auth_header(&state.cfg.object_store_auth_header)? { + req = req.header(name, value); + } + let body = req.send().await?.error_for_status()?.text().await?; + extract_xml_tag(&body, "UploadId").ok_or_else(|| ServiceError::internal("multipart initiate response missing UploadId")) +} + +async fn upload_multipart_part(state: &AppState, object_ref: &str, upload_id: &str, part_number: usize, body: Vec) -> Result { + let query = format!("partNumber={part_number}&uploadId={}", encode_query_component(upload_id)); + let url = object_url_with_query(state, object_ref, &query); + let mut req = state.http.put(url).body(body); + if let Some((name, value)) = object_auth_header(&state.cfg.object_store_auth_header)? { + req = req.header(name, value); + } + let resp = req.send().await?.error_for_status()?; + let etag = resp + .headers() + .get("etag") + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) + .ok_or_else(|| ServiceError::internal("multipart upload part response missing ETag"))?; + Ok(CompletedPart { part_number, etag }) +} + +async fn complete_multipart_upload(state: &AppState, object_ref: &str, upload_id: &str, parts: &[CompletedPart]) -> Result<(), ServiceError> { + let query = format!("uploadId={}", encode_query_component(upload_id)); + let url = object_url_with_query(state, object_ref, &query); + let body = complete_multipart_xml(parts); + let mut req = state.http.post(url).header("content-type", "application/xml").body(body); + if let Some((name, value)) = object_auth_header(&state.cfg.object_store_auth_header)? { + req = req.header(name, value); + } + req.send().await?.error_for_status()?; + Ok(()) +} + +async fn abort_multipart_upload(state: &AppState, object_ref: &str, upload_id: &str) -> Result<(), ServiceError> { + let query = format!("uploadId={}", encode_query_component(upload_id)); + let url = object_url_with_query(state, object_ref, &query); + let mut req = state.http.delete(url); + if let Some((name, value)) = object_auth_header(&state.cfg.object_store_auth_header)? { + req = req.header(name, value); + } + req.send().await?.error_for_status()?; + Ok(()) +} + +async fn abort_upload_if_needed(state: &AppState, object_ref: &str, upload_id: Option<&str>) { + let Some(upload_id) = upload_id else { + return; + }; + state.metrics.object_multipart_abort_total.fetch_add(1, Ordering::Relaxed); + if let Err(abort_err) = abort_multipart_upload(state, object_ref, upload_id).await { + tracing::warn!(object_ref = %object_ref, upload_id = %upload_id, error = %abort_err.message, "multipart upload abort failed"); + } +} + +fn complete_multipart_xml(parts: &[CompletedPart]) -> String { + let mut out = String::from(""); + for part in parts { + out.push_str(""); + out.push_str(""); + out.push_str(&part.part_number.to_string()); + out.push_str(""); + out.push_str(""); + out.push_str(&xml_escape(&part.etag)); + out.push_str(""); + out.push_str(""); + } + out.push_str(""); + out +} + +fn object_url(state: &AppState, object_ref: &str) -> String { + format!( + "{}/{}/{}", + state.cfg.object_store_endpoint.as_deref().unwrap_or_default().trim_end_matches('/'), + state.cfg.object_store_bucket.trim_matches('/'), + object_ref.trim_start_matches('/') + ) +} + +fn object_url_with_query(state: &AppState, object_ref: &str, query: &str) -> String { + format!("{}?{}", object_url(state, object_ref), query) +} + +fn object_auth_header(raw: &Option) -> Result, ServiceError> { + let Some(raw) = raw.as_deref() else { + return Ok(None); + }; + let Some((name, value)) = raw.split_once(':') else { + return Err(ServiceError::bad_request("AI_OBJECT_STORE_AUTH_HEADER must be `Header-Name: value`")); + }; + if HeaderName::try_from(name.trim()).is_err() || HeaderValue::from_str(value.trim()).is_err() { + return Err(ServiceError::bad_request("invalid object auth header")); + } + Ok(Some((name.trim().to_string(), value.trim().to_string()))) +} + +fn extract_xml_tag(xml: &str, tag: &str) -> Option { + let start_tag = format!("<{tag}>"); + let end_tag = format!(""); + let start = xml.find(&start_tag)? + start_tag.len(); + let end = xml[start..].find(&end_tag)? + start; + Some(xml[start..end].trim().to_string()) +} + +fn encode_query_component(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + for byte in input.bytes() { + if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') { + out.push(byte as char); + } else { + out.push_str(&format!("%{byte:02X}")); + } + } + out +} + +fn xml_escape(input: &str) -> String { + input.replace('&', "&").replace('<', "<").replace('>', ">").replace('"', """).replace('\'', "'") +} diff --git a/binary/ai-gateway-service/src/app/queue.rs b/binary/ai-gateway-service/src/app/queue.rs new file mode 100644 index 00000000..39f5dcb0 --- /dev/null +++ b/binary/ai-gateway-service/src/app/queue.rs @@ -0,0 +1,479 @@ +async fn enqueue_job(state: &AppState, policy: QueuePolicy, _method: Method, uri: Uri, headers: HeaderMap, body: Body) -> Result { + let enqueue_started_at = now_ms(); + let _permit = state.body_permits.acquire().await.map_err(|_| ServiceError::internal("body semaphore closed"))?; + let job_id = new_job_id(); + let tenant_id = required_header(&headers, "x-tenant-id")?; + let model = optional_header(&headers, "x-model").unwrap_or_else(|| "default".to_string()); + let callback_url = optional_header(&headers, "x-callback-url").unwrap_or_default(); + validate_callback_url(state, policy, &callback_url)?; + let original_method = optional_header(&headers, "x-original-method").unwrap_or_else(|| "POST".to_string()); + let original_path = optional_header(&headers, "x-original-path").unwrap_or_else(|| uri.path().to_string()); + let request_headers = headers_to_json(&headers)?; + let created_at = now_ms(); + let body_outcome = store_body(state, &job_id, body).await?; + let body_ref = body_outcome.location; + let body_size = body_ref.size; + let body_storage = body_ref.storage; + let (stream_key, priority) = stream_for_request(state, &headers, &tenant_id, &model); + + let xadd_future = async { + let stream_id: String = state + .redis + .xadd( + stream_key.as_str(), + false, + None::<()>, + "*", + vec![ + ("job_id", Value::String(job_id.clone().into())), + ("tenant_id", Value::String(tenant_id.into())), + ("policy", Value::String(policy.as_str().into())), + ("model", Value::String(model.into())), + ("priority", Value::String(priority.as_str().into())), + ("method", Value::String(original_method.into())), + ("path", Value::String(original_path.into())), + ("headers", Value::String(request_headers.into())), + ("body", Value::String(body_ref.body_base64.into())), + ("ref", Value::String(body_ref.object_ref.into())), + ("size", Value::Integer(body_ref.size as i64)), + ("storage", Value::String(body_ref.storage.into())), + ("callback_url", Value::String(callback_url.into())), + ("created_at", Value::Integer(created_at as i64)), + ], + ) + .await?; + trim_stream(state, &stream_key).await?; + Ok::(stream_id) + }; + + let stream_id = if let Some(upload) = body_outcome.pending_upload { + let (upload_join, stream_id_result) = tokio::join!(upload, xadd_future); + let upload_result = upload_join.map_err(|e| ServiceError::internal(format!("body upload task failed: {e}")))?; + upload_result?; + stream_id_result? + } else { + xadd_future.await? + }; + + state.metrics.enqueue_total.fetch_add(1, Ordering::Relaxed); + observe_enqueue_latency( + &state.metrics, + now_ms().saturating_sub(enqueue_started_at), + policy.as_str(), + body_size_bucket(body_size, body_storage), + ); + observe_body_size(&state.metrics, body_size); + match priority { + QueuePriority::High => { + state.metrics.enqueue_priority_high_total.fetch_add(1, Ordering::Relaxed); + } + QueuePriority::Normal => { + state.metrics.enqueue_priority_normal_total.fetch_add(1, Ordering::Relaxed); + } + QueuePriority::Low => { + state.metrics.enqueue_priority_low_total.fetch_add(1, Ordering::Relaxed); + } + } + match policy { + QueuePolicy::Queue => { + state.metrics.enqueue_queue_total.fetch_add(1, Ordering::Relaxed); + } + QueuePolicy::Wait => { + state.metrics.enqueue_wait_total.fetch_add(1, Ordering::Relaxed); + } + } + + Ok(AcceptedJob { + response: EnqueueResponse { + job_id: job_id.clone(), + stream_id, + stream_key, + status: "queued", + poll_url: job_poll_url(&job_id), + status_url: job_status_url_legacy(&job_id), + }, + created_at_ms: created_at, + }) +} + +fn spawn_workers(state: AppState) { + for idx in 0..state.cfg.worker_concurrency.max(1) { + let state = state.clone(); + tokio::spawn(async move { + let consumer = format!("{}-{idx}", state.cfg.consumer_name); + loop { + if let Err(e) = worker_once(&state, &consumer).await { + tracing::warn!(error = %e.message, "worker loop failed"); + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + }); + } +} + +async fn worker_once(state: &AppState, consumer: &str) -> Result<(), ServiceError> { + let streams = worker_stream_order(state); + for (idx, stream) in streams.iter().enumerate() { + let block = if idx + 1 == streams.len() { 1000 } else { 10 }; + let processed = read_worker_stream(state, consumer, stream, block).await?; + if processed > 0 { + return Ok(()); + } + } + Ok(()) +} + +async fn read_worker_stream(state: &AppState, consumer: &str, stream: &str, block_ms: u64) -> Result { + let reply = xreadgroup_map_or_empty( + &state.worker_redis, + state.cfg.consumer_group.as_str(), + consumer, + Some(5), + Some(block_ms), + false, + vec![stream], + vec![">"], + ) + .await?; + + let mut tasks = Vec::new(); + for (_stream, entries) in reply { + for (entry_id, fields) in entries { + let state = state.clone(); + let stream = stream.to_string(); + tasks.push(tokio::spawn(async move { + process_stream_entry(&state, stream.as_str(), entry_id.as_str(), &fields).await + })); + } + } + + let mut processed = 0; + for task in tasks { + match task.await { + Ok(Ok(true)) => processed += 1, + Ok(Ok(false)) => {} + Ok(Err(e)) => { + tracing::warn!(error = %e.message, "job processing failed"); + state.metrics.worker_failed_total.fetch_add(1, Ordering::Relaxed); + } + Err(e) => { + tracing::warn!(error = %e, "worker task join failed"); + state.metrics.worker_failed_total.fetch_add(1, Ordering::Relaxed); + } + } + } + Ok(processed) +} + +async fn process_stream_entry(state: &AppState, stream: &str, entry_id: &str, fields: &HashMap) -> Result { + let job_id = field_string(fields, "job_id").ok_or_else(|| ServiceError::bad_request("job missing job_id"))?; + let lease_owner = format!("{}:{stream}:{entry_id}:{}", state.cfg.consumer_name, now_ms()); + + if !acquire_job_lease(state, &job_id, &lease_owner).await? { + state.metrics.lease_skip_total.fetch_add(1, Ordering::Relaxed); + tracing::info!(job_id = %job_id, stream = %stream, entry_id = %entry_id, "job is already leased; skip reclaimed duplicate"); + return Ok(false); + } + + let attempt = increment_job_delivery_attempt(state, &job_id).await?; + if attempt > state.cfg.job_max_delivery_attempts { + enqueue_job_dlq(state, stream, entry_id, fields, attempt, "max_delivery_attempts_exceeded").await?; + ack_stream_entry(state, stream, entry_id).await?; + release_job_lease(state, &job_id).await; + state.metrics.job_dlq_total.fetch_add(1, Ordering::Relaxed); + return Ok(true); + } + + let processing_started_at = now_ms(); + let model = field_string(fields, "model").unwrap_or_else(|| "default".to_string()); + match process_job(state, stream, entry_id, fields).await { + Ok(()) => { + observe_worker_processing(&state.metrics, now_ms().saturating_sub(processing_started_at), &model); + ack_stream_entry(state, stream, entry_id).await?; + clear_job_delivery_attempt(state, &job_id).await; + release_job_lease(state, &job_id).await; + Ok(true) + } + Err(e) => { + observe_worker_processing(&state.metrics, now_ms().saturating_sub(processing_started_at), &model); + release_job_lease(state, &job_id).await; + Err(e) + } + } +} + +async fn process_job(state: &AppState, _stream: &str, _stream_id: &str, fields: &HashMap) -> Result<(), ServiceError> { + let Some(base) = state.cfg.upstream_base_url.as_deref() else { + return Err(ServiceError::internal("upstream base URL is not configured")); + }; + let job_id = field_string(fields, "job_id").ok_or_else(|| ServiceError::bad_request("job missing job_id"))?; + let method = field_string(fields, "method").unwrap_or_else(|| "POST".to_string()); + let path = field_string(fields, "path").unwrap_or_else(|| "/".to_string()); + let headers_json = field_string(fields, "headers").unwrap_or_else(|| "{}".to_string()); + let callback_url = field_string(fields, "callback_url").unwrap_or_default(); + let body = load_body(state, fields).await?; + let headers: HashMap = serde_json::from_str(&headers_json).unwrap_or_default(); + + let url = format!("{}{}", base.trim_end_matches('/'), path); + let parsed_method = method.parse::().unwrap_or(reqwest::Method::POST); + let mut req = state.http.request(parsed_method, url); + for (name, value) in headers { + if should_forward_header(&name) { + req = req.header(name, value); + } + } + let upstream = req.body(body).send().await; + let result = match upstream { + Ok(resp) => { + let status = resp.status().as_u16(); + let mut headers = HashMap::new(); + for (name, value) in resp.headers() { + if let Ok(value) = value.to_str() { + headers.insert(name.as_str().to_string(), value.to_string()); + } + } + let body = resp.bytes().await.unwrap_or_default(); + StoredResult { + job_id: job_id.clone(), + status: "completed".to_string(), + http_status: status, + headers, + body_base64: base64::engine::general_purpose::STANDARD.encode(body), + completed_at_ms: now_ms(), + error: None, + } + } + Err(e) => StoredResult { + job_id: job_id.clone(), + status: "failed".to_string(), + http_status: 502, + headers: HashMap::new(), + body_base64: String::new(), + completed_at_ms: now_ms(), + error: Some(e.to_string()), + }, + }; + + store_result(state, &result).await?; + if !callback_url.is_empty() { + let callback_body = callback_body(&result); + if let Err(e) = post_callback(state, &callback_url, &job_id, &callback_body).await { + tracing::warn!(job_id = %job_id, error = %e.message, "callback failed"); + state.metrics.callback_failure_total.fetch_add(1, Ordering::Relaxed); + enqueue_callback_retry(state, &callback_url, &job_id, &callback_body, e.message.as_str()).await?; + } + } + state.metrics.worker_completed_total.fetch_add(1, Ordering::Relaxed); + Ok(()) +} + +async fn acquire_job_lease(state: &AppState, job_id: &str, owner: &str) -> Result { + let key = job_lease_key(job_id); + let result: Option = state + .redis + .set( + key, + owner, + Some(Expiration::EX(state.cfg.job_process_lease_secs.max(1) as i64)), + Some(SetOptions::NX), + false, + ) + .await?; + Ok(result.is_some()) +} + +async fn release_job_lease(state: &AppState, job_id: &str) { + let _: Result = state.redis.del(job_lease_key(job_id)).await; +} + +async fn increment_job_delivery_attempt(state: &AppState, job_id: &str) -> Result { + let key = job_attempt_key(job_id); + let attempt: i64 = state.redis.incr_by(key.as_str(), 1).await?; + let _: () = state.redis.expire(key.as_str(), state.cfg.result_ttl_secs.max(300) as i64, None::).await?; + Ok(attempt.max(0) as u32) +} + +async fn clear_job_delivery_attempt(state: &AppState, job_id: &str) { + let _: Result = state.redis.del(job_attempt_key(job_id)).await; +} + +async fn ack_stream_entry(state: &AppState, stream: &str, entry_id: &str) -> Result<(), ServiceError> { + let _: i64 = state.redis.xack(stream, state.cfg.consumer_group.as_str(), vec![entry_id]).await?; + Ok(()) +} + +async fn enqueue_job_dlq(state: &AppState, stream: &str, entry_id: &str, fields: &HashMap, attempts: u32, reason: &str) -> Result<(), ServiceError> { + let job_id = field_string(fields, "job_id").unwrap_or_default(); + let fields_json = stream_fields_to_json(fields)?; + let _: String = state + .redis + .xadd( + state.cfg.job_dlq_stream.as_str(), + false, + None::<()>, + "*", + vec![ + ("job_id", Value::String(job_id.into())), + ("source_stream", Value::String(stream.to_string().into())), + ("source_entry_id", Value::String(entry_id.to_string().into())), + ("attempts", Value::Integer(attempts as i64)), + ("reason", Value::String(reason.to_string().into())), + ("fields", Value::String(fields_json.into())), + ("failed_at", Value::Integer(now_ms() as i64)), + ], + ) + .await?; + trim_stream(state, &state.cfg.job_dlq_stream).await?; + Ok(()) +} + +fn stream_fields_to_json(fields: &HashMap) -> Result { + let mut out = HashMap::new(); + for (key, value) in fields { + if let Some(value) = field_string(fields, key) { + out.insert(key.clone(), value); + } else { + out.insert(key.clone(), format!("{value:?}")); + } + } + serde_json::to_string(&out).map_err(|e| ServiceError::internal(format!("serialize job dlq fields: {e}"))) +} + +fn job_lease_key(job_id: &str) -> String { + format!("ai:job:lease:{}", sanitize_key(job_id)) +} + +fn job_attempt_key(job_id: &str) -> String { + format!("ai:job:attempt:{}", sanitize_key(job_id)) +} + +fn spawn_reclaimer(state: AppState) { + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(state.cfg.reclaim_interval_secs.max(1))); + loop { + interval.tick().await; + if let Err(e) = reclaim_once(&state).await { + tracing::warn!(error = %e.message, "stream reclaim failed"); + } + } + }); +} + +async fn reclaim_once(state: &AppState) -> Result<(), ServiceError> { + let consumer = format!("{}-reclaimer", state.cfg.consumer_name); + let min_idle_ms = state.cfg.reclaim_min_idle_secs.saturating_mul(1000); + for stream in configured_streams(state) { + let (_cursor, entries): (String, Vec<(String, HashMap)>) = + state.worker_redis.xautoclaim_values(stream.as_str(), state.cfg.consumer_group.as_str(), consumer.as_str(), min_idle_ms, "0-0", Some(10), false).await?; + for (entry_id, fields) in entries { + match process_stream_entry(state, stream.as_str(), entry_id.as_str(), &fields).await { + Ok(true) => { + state.metrics.reclaimed_total.fetch_add(1, Ordering::Relaxed); + } + Ok(false) => {} + Err(e) => { + tracing::warn!(stream = %stream, entry_id = %entry_id, error = %e.message, "reclaimed job failed"); + } + } + } + } + Ok(()) +} + +async fn ensure_consumer_groups(state: &AppState) -> Result<(), ServiceError> { + for stream in configured_streams(state) { + ensure_consumer_group(state, &stream, &state.cfg.consumer_group).await?; + } + ensure_consumer_group(state, &state.cfg.callback_retry_stream, &state.cfg.callback_retry_group).await?; + Ok(()) +} + +async fn ensure_consumer_group(state: &AppState, stream: &str, group: &str) -> Result<(), ServiceError> { + let res: FredResult = state.redis.xgroup_create(stream, group, "$", true).await; + match res { + Ok(_) => Ok(()), + Err(e) if e.to_string().contains("BUSYGROUP") => Ok(()), + Err(e) => Err(e.into()), + } +} + +fn stream_for_request(state: &AppState, headers: &HeaderMap, tenant: &str, model: &str) -> (String, QueuePriority) { + let priority = request_priority(state, headers, tenant, model); + if !state.cfg.enable_priority_streams { + return (state.cfg.stream_key.clone(), priority); + } + let stream = match priority { + QueuePriority::High => state.cfg.high_priority_stream_key.clone(), + QueuePriority::Low => state.cfg.low_priority_stream_key.clone(), + QueuePriority::Normal => state.cfg.stream_key.clone(), + }; + (stream, priority) +} + +fn request_priority(state: &AppState, headers: &HeaderMap, tenant: &str, model: &str) -> QueuePriority { + if let Some(priority) = optional_header(headers, "x-queue-priority").and_then(|value| parse_queue_priority(&value)) { + return priority; + } + if contains_csv_value(&state.cfg.queue_high_tenants, tenant) || contains_csv_value(&state.cfg.queue_high_models, model) { + return QueuePriority::High; + } + if contains_csv_value(&state.cfg.queue_low_tenants, tenant) || contains_csv_value(&state.cfg.queue_low_models, model) { + return QueuePriority::Low; + } + parse_queue_priority(&state.cfg.queue_default_priority).unwrap_or(QueuePriority::Normal) +} + +fn configured_streams(state: &AppState) -> Vec { + if state.cfg.enable_priority_streams { + vec![ + state.cfg.high_priority_stream_key.clone(), + state.cfg.stream_key.clone(), + state.cfg.low_priority_stream_key.clone(), + ] + } else { + vec![state.cfg.stream_key.clone()] + } +} + +fn worker_stream_order(state: &AppState) -> Vec { + if !state.cfg.enable_priority_streams { + return vec![state.cfg.stream_key.clone()]; + } + let mut out = Vec::new(); + push_weighted(&mut out, &state.cfg.high_priority_stream_key, state.cfg.queue_high_weight); + push_weighted(&mut out, &state.cfg.stream_key, state.cfg.queue_normal_weight); + push_weighted(&mut out, &state.cfg.low_priority_stream_key, state.cfg.queue_low_weight); + if out.is_empty() { + out.push(state.cfg.stream_key.clone()); + } + out +} + +fn push_weighted(out: &mut Vec, stream: &str, weight: usize) { + for _ in 0..weight { + out.push(stream.to_string()); + } +} + +fn parse_queue_priority(value: &str) -> Option { + match value.trim().to_ascii_lowercase().as_str() { + "high" => Some(QueuePriority::High), + "normal" | "default" | "medium" => Some(QueuePriority::Normal), + "low" => Some(QueuePriority::Low), + _ => None, + } +} + +fn contains_csv_value(csv: &str, needle: &str) -> bool { + csv.split(',').map(str::trim).filter(|value| !value.is_empty()).any(|value| value.eq_ignore_ascii_case(needle)) +} + +fn validate_callback_url(state: &AppState, policy: QueuePolicy, callback_url: &str) -> Result<(), ServiceError> { + if policy == QueuePolicy::Queue && callback_url.is_empty() { + return Err(ServiceError::bad_request("missing required header `x-callback-url` for queue policy")); + } + if !callback_url.is_empty() && state.cfg.require_https_callback && !callback_url.starts_with("https://") { + return Err(ServiceError::bad_request("x-callback-url must use https")); + } + Ok(()) +} diff --git a/binary/ai-gateway-service/src/app/ratelimit.rs b/binary/ai-gateway-service/src/app/ratelimit.rs new file mode 100644 index 00000000..589f84fa --- /dev/null +++ b/binary/ai-gateway-service/src/app/ratelimit.rs @@ -0,0 +1,262 @@ +fn global_rate_limit(state: &AppState) -> TenantRateLimit { + TenantRateLimit { + rps: state.cfg.rate_limit_rps, + burst: state.cfg.rate_limit_burst, + cost: state.cfg.rate_limit_cost.max(1), + } +} + +fn tenant_rate_limit_candidate_keys(state: &AppState, tenant: &str, model: &str, path: &str, policy: &str) -> Vec { + let base = format!("{}{}", state.cfg.tenant_rate_limit_prefix, sanitize_key(tenant)); + let model = sanitize_key(model); + let path = sanitize_key(path); + let policy = sanitize_key(policy); + vec![ + format!("{base}:model:{model}:path:{path}:policy:{policy}"), + format!("{base}:model:{model}:path:{path}"), + format!("{base}:model:{model}:policy:{policy}"), + format!("{base}:path:{path}:policy:{policy}"), + format!("{base}:model:{model}"), + format!("{base}:path:{path}"), + format!("{base}:policy:{policy}"), + base, + ] +} + +struct ParsedStoredTenantRateLimit { + limit: TenantRateLimit, + ttl_secs: Option, +} + +fn parse_stored_tenant_rate_limit(raw: &str) -> Option { + let raw = raw.trim(); + if raw.is_empty() { + return None; + } + if let Ok(stored) = serde_json::from_str::(raw) { + return Some(ParsedStoredTenantRateLimit { + limit: TenantRateLimit { + rps: stored.rps, + burst: stored.burst, + cost: stored.cost.max(1), + }, + ttl_secs: stored.ttl_secs, + }); + } + if let Ok(mut limit) = serde_json::from_str::(raw) { + limit.cost = limit.cost.max(1); + return Some(ParsedStoredTenantRateLimit { limit, ttl_secs: None }); + } + parse_tenant_rate_limit_csv(raw).map(|limit| ParsedStoredTenantRateLimit { limit, ttl_secs: None }) +} + +#[cfg(test)] +fn parse_tenant_rate_limit(raw: &str) -> Option { + parse_stored_tenant_rate_limit(raw).map(|stored| stored.limit) +} + +fn parse_tenant_rate_limit_csv(raw: &str) -> Option { + let mut parts = raw.split(',').map(str::trim); + let rps = parts.next()?.parse().ok()?; + let burst = parts.next()?.parse().ok()?; + let cost = parts.next().and_then(|value| value.parse().ok()).unwrap_or(1); + Some(TenantRateLimit { rps, burst, cost: cost.max(1) }) +} + +/// 按租户规则(可含 model/path/policy 维度)解析配额,未命中则回退全局默认值。 +async fn resolve_rate_limit(state: &AppState, tenant: &str, model: &str, path: &str, policy: &str) -> Result { + for key in tenant_rate_limit_candidate_keys(state, tenant, model, path, policy) { + let raw: Option = state.redis.get(key.as_str()).await?; + if let Some(stored) = raw.and_then(|raw| parse_stored_tenant_rate_limit(&raw)) { + return Ok(stored.limit); + } + } + Ok(global_rate_limit(state)) +} + +fn tenant_rate_limit_keys(tenant: &str) -> (String, String) { + let tenant_key = sanitize_key(tenant); + ( + format!("ai:ratelimit:{tenant_key}:tokens"), + format!("ai:ratelimit:{tenant_key}:ts"), + ) +} + +async fn list_tenant_rate_limit_rules(state: &AppState, filters: &HashMap) -> Result, ServiceError> { + let pattern = format!("{}*", state.cfg.tenant_rate_limit_prefix); + let mut stream = state.redis.scan_buffered(pattern, Some(100), None); + let mut out = Vec::new(); + + while let Some(key) = stream.next().await { + let key = key?.into_string().unwrap_or_default(); + if is_legacy_tenant_rate_limit_key(&key) { + continue; + } + + let raw: Option = state.redis.get(key.as_str()).await?; + let Some(stored) = raw.and_then(|raw| parse_stored_tenant_rate_limit(&raw)) else { + continue; + }; + let Some(mut rule) = tenant_rate_limit_rule_from_key(state, &key, stored.limit, stored.ttl_secs) else { + continue; + }; + rule.cost = rule.cost.max(1); + if tenant_rule_matches_filters(&rule, filters) { + let ttl_remaining_secs = read_ttl_remaining_secs(state, key.as_str()).await; + out.push(tenant_rate_limit_rule_view(key, rule, ttl_remaining_secs)); + } + } + + out.sort_by(|a, b| tenant_rule_specificity_rule(a).cmp(&tenant_rule_specificity_rule(b)).then_with(|| a.key.cmp(&b.key))); + Ok(out) +} + +async fn upsert_tenant_rate_limit_rule(state: &AppState, mut rule: TenantRateLimitRule) -> Result { + validate_tenant_rate_limit_rule(&rule)?; + rule.cost = rule.cost.max(1); + let key = tenant_rate_limit_rule_key(state, &rule); + let value = serde_json::to_string(&StoredTenantRateLimit { + rps: rule.rps, + burst: rule.burst, + cost: rule.cost, + ttl_secs: rule.ttl_secs, + }) + .map_err(|e| ServiceError::internal(format!("serialize tenant rate limit: {e}")))?; + let expiration = rule.ttl_secs.map(|ttl| Expiration::EX(ttl.max(1) as i64)); + let _: String = state.redis.set(key.as_str(), value, expiration, None, false).await?; + let ttl_remaining_secs = read_ttl_remaining_secs(state, key.as_str()).await; + Ok(tenant_rate_limit_rule_view(key, rule, ttl_remaining_secs)) +} + +async fn delete_tenant_rate_limit_rule(state: &AppState, rule: TenantRateLimitRule) -> Result { + validate_tenant_rule_dimensions(&rule)?; + let key = tenant_rate_limit_rule_key(state, &rule); + let removed: u64 = state.redis.del(key.as_str()).await?; + Ok(removed) +} + +async fn read_ttl_remaining_secs(state: &AppState, key: &str) -> Option { + let ttl: i64 = state.redis.ttl(key).await.unwrap_or(-2); + if ttl > 0 { Some(ttl) } else { None } +} + +fn tenant_rate_limit_rule_key(state: &AppState, rule: &TenantRateLimitRule) -> String { + let base = format!("{}{}", state.cfg.tenant_rate_limit_prefix, sanitize_key(rule.tenant.trim())); + let mut key = base; + if let Some(model) = non_empty_opt(&rule.model) { + key.push_str(":model:"); + key.push_str(&sanitize_key(model)); + } + if let Some(path) = non_empty_opt(&rule.path) { + key.push_str(":path:"); + key.push_str(&sanitize_key(path)); + } + if let Some(policy) = non_empty_opt(&rule.policy) { + key.push_str(":policy:"); + key.push_str(&sanitize_key(policy)); + } + key +} + +fn tenant_rate_limit_rule_from_key(state: &AppState, key: &str, limit: TenantRateLimit, ttl_secs: Option) -> Option { + let rest = key.strip_prefix(&state.cfg.tenant_rate_limit_prefix)?; + let mut parts = rest.split(':'); + let tenant = parts.next()?.to_string(); + if tenant.is_empty() { + return None; + } + + let mut model = None; + let mut path = None; + let mut policy = None; + while let (Some(name), Some(value)) = (parts.next(), parts.next()) { + match name { + "model" => model = Some(value.to_string()), + "path" => path = Some(value.to_string()), + "policy" => policy = Some(value.to_string()), + _ => {} + } + } + + Some(TenantRateLimitRule { + tenant, + model, + path, + policy, + rps: limit.rps, + burst: limit.burst, + cost: limit.cost.max(1), + ttl_secs, + }) +} + +fn validate_tenant_rate_limit_rule(rule: &TenantRateLimitRule) -> Result<(), ServiceError> { + validate_tenant_rule_dimensions(rule)?; + if rule.rps == 0 { + return Err(ServiceError::bad_request("rps must be greater than 0")); + } + if rule.burst == 0 { + return Err(ServiceError::bad_request("burst must be greater than 0")); + } + if rule.cost == 0 { + return Err(ServiceError::bad_request("cost must be greater than 0")); + } + Ok(()) +} + +fn validate_tenant_rule_dimensions(rule: &TenantRateLimitRule) -> Result<(), ServiceError> { + if rule.tenant.trim().is_empty() { + return Err(ServiceError::bad_request("tenant is required")); + } + if let Some(policy) = non_empty_opt(&rule.policy) { + match policy { + "abandon" | "queue" | "wait" => {} + _ => return Err(ServiceError::bad_request("policy must be abandon, queue, or wait")), + } + } + Ok(()) +} + +fn tenant_rule_matches_filters(rule: &TenantRateLimitRule, filters: &HashMap) -> bool { + for (name, value) in filters { + let value = value.trim(); + if value.is_empty() { + continue; + } + let matches = match name.as_str() { + "tenant" => rule.tenant.contains(value), + "model" => rule.model.as_deref().unwrap_or("").contains(value), + "path" => rule.path.as_deref().unwrap_or("").contains(value), + "policy" => rule.policy.as_deref().unwrap_or("") == value, + _ => true, + }; + if !matches { + return false; + } + } + true +} + +fn tenant_rule_specificity_rule(view: &TenantRateLimitRuleView) -> usize { + usize::from(non_empty_opt(&view.model).is_some()) + usize::from(non_empty_opt(&view.path).is_some()) + usize::from(non_empty_opt(&view.policy).is_some()) +} + +fn is_legacy_tenant_rate_limit_key(key: &str) -> bool { + key.ends_with(":rps") || key.ends_with(":burst") || key.ends_with(":cost") +} + +fn non_empty_opt(value: &Option) -> Option<&str> { + value.as_deref().map(str::trim).filter(|value| !value.is_empty()) +} + +fn record_rate_limited(metrics: &Metrics, policy: &str, tenant: &str) { + metrics.rate_limited_total.fetch_add(1, Ordering::Relaxed); + inc_labeled( + metrics, + format!( + r#"rate_limited_total{{policy="{}",tenant="{}"}}"#, + metrics_label(policy), + metrics_label(tenant) + ), + ); +} diff --git a/binary/ai-gateway-service/src/app/result_store.rs b/binary/ai-gateway-service/src/app/result_store.rs new file mode 100644 index 00000000..6f1b886f --- /dev/null +++ b/binary/ai-gateway-service/src/app/result_store.rs @@ -0,0 +1,28 @@ +async fn store_result(state: &AppState, result: &StoredResult) -> Result<(), ServiceError> { + let json = serde_json::to_string(result).map_err(|e| ServiceError::internal(format!("serialize result: {e}")))?; + let key = result_key(state, &result.job_id); + let channel = result_channel(state, &result.job_id); + let ttl = state.cfg.result_ttl_secs.min(i64::MAX as u64) as i64; + let _: () = state.redis.set(key, json, Some(Expiration::EX(ttl)), None::, false).await?; + let _: i64 = state.redis.publish(channel, "done").await?; + Ok(()) +} + +async fn load_result(state: &AppState, job_id: &str) -> Result, ServiceError> { + let raw: Option = state.redis.get(result_key(state, job_id)).await?; + raw.map(|s| serde_json::from_str(&s).map_err(|e| ServiceError::internal(format!("parse result: {e}")))).transpose() +} + +fn result_to_response(result: StoredResult, created_at_ms: u64) -> Result { + let status = StatusCode::from_u16(result.http_status).unwrap_or(StatusCode::OK); + let body = base64::engine::general_purpose::STANDARD.decode(result.body_base64).map_err(|e| ServiceError::internal(format!("decode result body: {e}")))?; + let mut resp = (status, body).into_response(); + for (name, value) in result.headers { + if let (Ok(name), Ok(value)) = (HeaderName::try_from(name.as_str()), HeaderValue::from_str(&value)) { + resp.headers_mut().insert(name, value); + } + } + resp.headers_mut().insert("x-job-id", header_value(&result.job_id)?); + resp.headers_mut().insert("x-queue-wait-ms", header_value(&now_ms().saturating_sub(created_at_ms).to_string())?); + Ok(resp) +} diff --git a/binary/ai-gateway-service/src/app/runtime.rs b/binary/ai-gateway-service/src/app/runtime.rs new file mode 100644 index 00000000..186ce9d1 --- /dev/null +++ b/binary/ai-gateway-service/src/app/runtime.rs @@ -0,0 +1,89 @@ +pub async fn run() -> Result<(), Box> { + tracing_subscriber::fmt().with_env_filter(tracing_subscriber::EnvFilter::from_default_env()).init(); + + let args = load_args()?; + let redis = build_redis_client(&args.redis_url)?; + let _redis_task = redis.init().await?; + check_redis_version(&redis).await?; + let worker_redis = build_redis_client(&args.redis_url)?; + let _worker_redis_task = worker_redis.init().await?; + let wait_subscriber = WaitSubscriberHub::new(&args.redis_url).await?; + let state = AppState { + redis, + worker_redis, + http: reqwest::Client::new(), + cfg: Arc::new(args.clone()), + body_permits: Arc::new(Semaphore::new(args.body_read_concurrency.max(1))), + metrics: Arc::new(Metrics::default()), + wait_subscriber, + }; + + ensure_consumer_groups(&state).await?; + if state.cfg.upstream_base_url.is_some() { + spawn_workers(state.clone()); + spawn_reclaimer(state.clone()); + spawn_callback_retry_worker(state.clone()); + } else { + tracing::warn!("AI_UPSTREAM_BASE_URL is not set; queue jobs will be stored but no local worker will process them"); + } + + let app = build_router(state, args.max_body_bytes); + + let addr = SocketAddr::new(args.host, args.port); + tracing::info!(%addr, "ai-gateway-service listening"); + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve(listener, app).await?; + Ok(()) +} + +/// 构建 HTTP 路由,供 main 与集成测试复用。 +pub fn build_router(state: AppState, max_body_bytes: usize) -> Router { + Router::new() + .route("/healthz", get(healthz)) + .route("/metrics", get(metrics)) + .route("/v1/ratelimit/check", post(check_rate_limit)) + .route("/v1/queue/enqueue", post(enqueue)) + .route("/v1/queue/enqueue-and-wait", post(enqueue_and_wait)) + .route("/v1/jobs/{job_id}", get(get_job)) + .route("/jobs/{job_id}/status", get(get_job)) + .route("/v1/admin/plugins/{plugin}/schema", get(admin_plugin_schema)) + .route("/v1/admin/plugins/{plugin}/readme", get(admin_plugin_readme)) + .route("/v1/admin/tenant-rate-limits", get(admin_list_tenant_rate_limits).put(admin_upsert_tenant_rate_limit).delete(admin_delete_tenant_rate_limit)) + .layer(DefaultBodyLimit::max(max_body_bytes)) + .layer(build_admin_cors_layer(state.cfg.as_ref())) + .layer(TraceLayer::new_for_http()) + .with_state(state) +} + +async fn check_redis_version(redis: &FredClient) -> Result<(), Box> { + let info: String = redis.info(Some(InfoKind::Server)).await?; + for line in info.lines() { + if let Some(version) = line.strip_prefix("redis_version:") { + let major = version.split('.').next().and_then(|v| v.parse::().ok()).unwrap_or(0); + if major < 7 { + return Err(format!("Redis 7+ is required, found redis_version={version}").into()); + } + tracing::info!(redis_version = %version.trim(), "redis version check passed"); + return Ok(()); + } + } + tracing::warn!("could not parse redis_version from INFO; continuing without version check"); + Ok(()) +} + +fn build_admin_cors_layer(args: &Args) -> CorsLayer { + let origins: Vec = args + .admin_cors_origins + .split(',') + .map(str::trim) + .filter(|value| !value.is_empty()) + .filter_map(|value| HeaderValue::from_str(value).ok()) + .collect(); + if origins.is_empty() { + return CorsLayer::permissive(); + } + CorsLayer::new() + .allow_origin(origins) + .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS]) + .allow_headers(Any) +} diff --git a/binary/ai-gateway-service/src/app/test_support.rs b/binary/ai-gateway-service/src/app/test_support.rs new file mode 100644 index 00000000..d25cd473 --- /dev/null +++ b/binary/ai-gateway-service/src/app/test_support.rs @@ -0,0 +1,366 @@ +// 集成测试 harness:启动 mock 上游/回调、隔离 Redis key、进程内 HTTP 服务。 +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use axum::http::{HeaderMap, StatusCode}; +use axum::response::IntoResponse; +use axum::routing::post; +use axum::{Json, Router}; +use fred::prelude::*; +use futures_util::StreamExt; +use reqwest::Client; +use tokio::net::TcpListener; +use tokio::sync::Semaphore; +use tokio::task::JoinHandle; + +use super::*; + +/// 集成测试可选配置(字段公开,供外部 integration crate 使用)。 +#[derive(Default, Clone)] +pub struct HarnessConfig { + pub rate_limit_rps: Option, + pub rate_limit_burst: Option, + pub wait_timeout_secs: Option, + pub require_https_callback: Option, + pub inline_threshold: Option, + pub clear_object_store: bool, +} + +impl HarnessConfig { + fn apply(self, args: &mut Args) { + if let Some(v) = self.rate_limit_rps { + args.rate_limit_rps = v; + } + if let Some(v) = self.rate_limit_burst { + args.rate_limit_burst = v; + } + if let Some(v) = self.wait_timeout_secs { + args.wait_timeout_secs = v; + } + if let Some(v) = self.require_https_callback { + args.require_https_callback = v; + } + if let Some(v) = self.inline_threshold { + args.inline_threshold = v; + } + if self.clear_object_store { + args.object_store_endpoint = None; + } + } +} + +/// 回调服务器记录到的 POST 请求。 +#[derive(Debug, Clone, Default)] +pub struct CallbackRecord { + pub job_id: String, + pub body: serde_json::Value, + pub headers: Vec<(String, String)>, +} + +/// 集成测试环境:随机端口 HTTP 服务 + mock upstream/callback。 +pub struct TestHarness { + pub base_url: String, + pub client: Client, + pub state: AppState, + pub upstream_url: String, + pub callback_url: String, + pub redis: FredClient, + pub suffix: String, + _server: JoinHandle<()>, + _upstream: JoinHandle<()>, + _callback: JoinHandle<()>, + callback_records: Arc>>, +} + +impl TestHarness { + /// 使用默认 Redis(`REDIS_URL` 或 `redis://127.0.0.1/`)启动隔离测试环境。 + pub async fn start() -> Self { + Self::start_with(|_| {}).await + } + + /// 使用 [`HarnessConfig`] 启动(供 tests/integration 使用)。 + pub async fn start_config(config: HarnessConfig) -> Self { + Self::start_with(move |a| { + config.apply(a); + }) + .await + } + + /// 允许调用方微调 Args(限流、timeout、stream key 等)。 + pub async fn start_with(configure: impl FnOnce(&mut Args)) -> Self { + if !redis_available().await { + panic!("Redis 7+ is required for integration tests (set REDIS_URL or start redis locally)"); + } + + let suffix = ulid::Ulid::new().to_string().to_ascii_lowercase(); + let mut args = Args::parse_from(["ai-gateway-service"]); + args.stream_key = format!("ai:jobs:test:{suffix}"); + args.high_priority_stream_key = format!("ai:jobs:high:test:{suffix}"); + args.low_priority_stream_key = format!("ai:jobs:low:test:{suffix}"); + args.consumer_group = format!("ai-gateway-workers-test-{suffix}"); + args.consumer_name = format!("ai-gateway-test-{suffix}"); + args.job_dlq_stream = format!("ai:job-dlq:test:{suffix}"); + args.callback_retry_stream = format!("ai:callback-retry:test:{suffix}"); + args.callback_retry_group = format!("ai-gateway-callbacks-test-{suffix}"); + args.callback_dlq_stream = format!("ai:callback-dlq:test:{suffix}"); + args.tenant_rate_limit_prefix = format!("ai:tenant:ratelimit:test:{suffix}:"); + args.result_key_prefix = format!("result:test:{suffix}:"); + args.result_channel_prefix = format!("result:test:{suffix}:"); + args.rate_limit_rps = 100; + args.rate_limit_burst = 2; + args.wait_timeout_secs = 3; + args.reclaim_interval_secs = 2; + args.reclaim_min_idle_secs = 1; + args.worker_concurrency = 2; + args.enable_priority_streams = true; + // mock 回调为 http://127.0.0.1,测试环境关闭 HTTPS 强制 + args.require_https_callback = false; + configure(&mut args); + + let (upstream_url, upstream_task) = spawn_mock_upstream(Duration::from_millis(50)).await; + args.upstream_base_url = Some(upstream_url.clone()); + + let callback_records = Arc::new(Mutex::new(Vec::new())); + let (callback_url, callback_task) = spawn_mock_callback(callback_records.clone()).await; + + let redis_url = std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1/".to_string()); + let redis = build_redis_client(&redis_url).expect("redis client"); + redis.init().await.expect("redis init"); + let worker_redis = build_redis_client(&redis_url).expect("worker redis"); + worker_redis.init().await.expect("worker redis init"); + let wait_subscriber = WaitSubscriberHub::new(&redis_url).await.expect("wait subscriber"); + + let state = AppState { + redis: redis.clone(), + worker_redis, + http: Client::new(), + cfg: Arc::new(args.clone()), + body_permits: Arc::new(Semaphore::new(args.body_read_concurrency.max(1))), + metrics: Arc::new(Metrics::default()), + wait_subscriber, + }; + + ensure_consumer_groups(&state).await.expect("consumer groups"); + spawn_workers(state.clone()); + spawn_reclaimer(state.clone()); + spawn_callback_retry_worker(state.clone()); + + let app = build_router(state.clone(), args.max_body_bytes); + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind harness"); + let addr = listener.local_addr().expect("local addr"); + let server = tokio::spawn(async move { + axum::serve(listener, app).await.expect("serve harness"); + }); + + let base_url = format!("http://{addr}"); + // 等待 worker 与 HTTP 就绪 + tokio::time::sleep(Duration::from_millis(100)).await; + + Self { + base_url, + client: Client::new(), + state, + upstream_url, + callback_url, + redis, + suffix, + _server: server, + _upstream: upstream_task, + _callback: callback_task, + callback_records, + } + } + + pub fn callback_records(&self) -> Vec { + self.callback_records.lock().unwrap_or_else(|e| e.into_inner()).clone() + } + + /// 查询 tenant 限流 Redis key(TC-RL-07)。 + pub async fn ratelimit_keys_for_tenant(&self, tenant: &str) -> Vec { + let pattern = format!("ai:ratelimit:{tenant}:*"); + let mut stream = self.redis.scan_buffered(pattern, Some(100), None); + let mut out = Vec::new(); + while let Some(key) = stream.next().await { + if let Ok(key) = key { + out.push(key.into_string().unwrap_or_default()); + } + } + out + } + + /// callback retry stream 深度。 + pub async fn callback_retry_depth(&self) -> i64 { + self.redis + .xlen(self.state.cfg.callback_retry_stream.as_str()) + .await + .unwrap_or(0) + } + + /// POST /v1/ratelimit/check + pub async fn check_rate_limit(&self, tenant: &str, policy: &str) -> reqwest::Response { + self.client + .post(format!("{}/v1/ratelimit/check", self.base_url)) + .header("x-tenant-id", tenant) + .header("x-ratelimit-policy", policy) + .header("x-original-path", "/v1/chat") + .send() + .await + .expect("rate limit request") + } + + /// POST /v1/queue/enqueue + pub async fn enqueue(&self, tenant: &str, body: Vec, extra: HeaderMap) -> reqwest::Response { + let mut req = self + .client + .post(format!("{}/v1/queue/enqueue", self.base_url)) + .header("x-tenant-id", tenant) + .header("x-ratelimit-policy", "queue") + .header("x-callback-url", &self.callback_url) + .header("x-original-method", "POST") + .header("x-original-path", "/v1/chat"); + for (k, v) in extra.iter() { + if let Ok(v) = v.to_str() { + req = req.header(k.as_str(), v); + } + } + req.body(body).send().await.expect("enqueue") + } + + /// POST /v1/queue/enqueue-and-wait + pub async fn enqueue_and_wait(&self, tenant: &str, body: Vec, timeout_secs: Option) -> reqwest::Response { + let mut req = self + .client + .post(format!("{}/v1/queue/enqueue-and-wait", self.base_url)) + .header("x-tenant-id", tenant) + .header("x-ratelimit-policy", "wait") + .header("x-original-method", "POST") + .header("x-original-path", "/v1/chat"); + if let Some(secs) = timeout_secs { + req = req.header("x-request-timeout", secs.to_string()); + } + req.body(body).send().await.expect("enqueue and wait") + } + + pub async fn get_job(&self, job_id: &str) -> reqwest::Response { + self.client + .get(format!("{}/jobs/{job_id}/status", self.base_url)) + .send() + .await + .expect("get job") + } + + pub async fn metrics(&self) -> String { + self.client + .get(format!("{}/metrics", self.base_url)) + .send() + .await + .expect("metrics") + .text() + .await + .expect("metrics body") + } + + /// 耗尽 tenant 令牌桶至 denied。 + pub async fn exhaust_tenant(&self, tenant: &str, policy: &str, times: u32) { + for _ in 0..times { + let _ = self.check_rate_limit(tenant, policy).await; + } + } +} + +impl Drop for TestHarness { + fn drop(&mut self) { + let redis = self.redis.clone(); + let keys = vec![ + self.state.cfg.stream_key.clone(), + self.state.cfg.high_priority_stream_key.clone(), + self.state.cfg.low_priority_stream_key.clone(), + self.state.cfg.job_dlq_stream.clone(), + self.state.cfg.callback_retry_stream.clone(), + self.state.cfg.callback_dlq_stream.clone(), + ]; + tokio::spawn(async move { + for key in keys { + let _: u64 = redis.del(key.as_str()).await.unwrap_or(0); + } + }); + } +} + +async fn redis_available() -> bool { + let url = std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1/".to_string()); + let Ok(client) = build_redis_client(&url) else { + return false; + }; + if client.init().await.is_err() { + return false; + } + let info: Result = client.info(Some(InfoKind::Server)).await; + match info { + Ok(text) => text.lines().any(|line| { + line.strip_prefix("redis_version:") + .and_then(|v| v.split('.').next()) + .and_then(|v| v.parse::().ok()) + .is_some_and(|major| major >= 7) + }), + Err(_) => false, + } +} + +async fn spawn_mock_upstream(delay: Duration) -> (String, JoinHandle<()>) { + let app = Router::new().fallback({ + let delay = delay; + move |method: axum::http::Method| { + let delay = delay; + async move { + if method == axum::http::Method::POST { + tokio::time::sleep(delay).await; + Json(serde_json::json!({ "upstream": true, "model": "test" })).into_response() + } else { + StatusCode::METHOD_NOT_ALLOWED.into_response() + } + } + } + }); + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind upstream"); + let addr = listener.local_addr().expect("upstream addr"); + let task = tokio::spawn(async move { + axum::serve(listener, app).await.expect("upstream serve"); + }); + (format!("http://{addr}"), task) +} + +async fn spawn_mock_callback(records: Arc>>) -> (String, JoinHandle<()>) { + let app = Router::new().route( + "/cb", + post({ + let records = records.clone(); + move |headers: HeaderMap, Json(body): Json| { + let records = records.clone(); + async move { + let job_id = headers + .get("x-gateway-job-id") + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + let header_pairs = headers + .iter() + .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + records.lock().unwrap_or_else(|e| e.into_inner()).push(CallbackRecord { + job_id, + body, + headers: header_pairs, + }); + StatusCode::OK.into_response() + } + } + }), + ); + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind callback"); + let addr = listener.local_addr().expect("callback addr"); + let task = tokio::spawn(async move { + axum::serve(listener, app).await.expect("callback serve"); + }); + (format!("http://{addr}/cb"), task) +} diff --git a/binary/ai-gateway-service/src/app/tests.rs b/binary/ai-gateway-service/src/app/tests.rs new file mode 100644 index 00000000..5ad3cc50 --- /dev/null +++ b/binary/ai-gateway-service/src/app/tests.rs @@ -0,0 +1,177 @@ +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn extracts_upload_id_from_multipart_xml() { + let xml = "a+b/c="; + assert_eq!(extract_xml_tag(xml, "UploadId").as_deref(), Some("a+b/c=")); + } + + #[test] + fn encodes_upload_id_for_query_string() { + assert_eq!(encode_query_component("a+b/c="), "a%2Bb%2Fc%3D"); + } + + #[test] + fn builds_complete_multipart_xml_with_escaped_etags() { + let parts = vec![ + CompletedPart { + part_number: 1, + etag: "\"abc&1\"".to_string(), + }, + CompletedPart { + part_number: 2, + etag: "\"def\"".to_string(), + }, + ]; + let xml = complete_multipart_xml(&parts); + assert!(xml.contains("1"abc&1"")); + assert!(xml.contains("2"def"")); + } + + #[test] + fn callback_retry_delay_uses_exponential_backoff_with_cap() { + assert_eq!(callback_retry_delay_ms(1000, 60_000, 1), 1000); + assert_eq!(callback_retry_delay_ms(1000, 60_000, 3), 4000); + assert_eq!(callback_retry_delay_ms(1000, 5000, 8), 5000); + } + + #[test] + fn parses_xpending_summary_count() { + let value = Value::Array(vec![Value::Integer(7), Value::String("0-1".into()), Value::String("0-2".into())]); + assert_eq!(pending_count_from_value(&value), 7); + } + + #[test] + fn observes_histogram_buckets_as_non_overlapping_counts() { + let metrics = Metrics::default(); + observe_enqueue_latency(&metrics, 80, "queue", "inline"); + observe_enqueue_latency(&metrics, 800, "wait", "inline"); + observe_body_size(&metrics, 8 * 1024); + observe_body_size(&metrics, 256 * 1024); + observe_worker_processing(&metrics, 2000, "gpt-4o-mini"); + + assert_eq!(metrics.enqueue_latency_count.load(Ordering::Relaxed), 2); + assert_eq!(metrics.enqueue_latency_le_100_ms.load(Ordering::Relaxed), 1); + assert_eq!(metrics.enqueue_latency_le_1000_ms.load(Ordering::Relaxed), 1); + assert_eq!(metrics.body_size_count.load(Ordering::Relaxed), 2); + assert_eq!(metrics.body_size_le_10kb.load(Ordering::Relaxed), 1); + assert_eq!(metrics.body_size_le_5mb.load(Ordering::Relaxed), 1); + assert_eq!(metrics.worker_processing_count.load(Ordering::Relaxed), 1); + assert_eq!(metrics.worker_processing_le_5000_ms.load(Ordering::Relaxed), 1); + } + + #[test] + fn parses_tenant_rate_limit_json_and_csv() { + let json = parse_tenant_rate_limit(r#"{"rps":10,"burst":20,"cost":3}"#).unwrap(); + assert_eq!(json.rps, 10); + assert_eq!(json.burst, 20); + assert_eq!(json.cost, 3); + + let csv = parse_tenant_rate_limit("15,30,2").unwrap(); + assert_eq!(csv.rps, 15); + assert_eq!(csv.burst, 30); + assert_eq!(csv.cost, 2); + } + + #[test] + fn parses_queue_priority_values() { + assert_eq!(parse_queue_priority("HIGH"), Some(QueuePriority::High)); + assert_eq!(parse_queue_priority("medium"), Some(QueuePriority::Normal)); + assert_eq!(parse_queue_priority("low"), Some(QueuePriority::Low)); + assert_eq!(parse_queue_priority("urgent"), None); + } + + async fn test_app_state(object_store_endpoint: Option, inline_threshold: usize) -> AppState { + let mut args = Args::parse_from(["ai-gateway-service"]); + args.object_store_endpoint = object_store_endpoint; + args.inline_threshold = inline_threshold; + args.max_body_bytes = 8 * 1024 * 1024; + args.object_store_bucket = "ai-gateway-body".to_string(); + args.object_store_prefix = "bodies".to_string(); + args.object_multipart_part_size = 1024; + let redis = build_redis_client("redis://127.0.0.1/").expect("redis client"); + let wait_subscriber = WaitSubscriberHub::new("redis://127.0.0.1/").await.expect("wait subscriber"); + AppState { + redis: redis.clone(), + worker_redis: redis, + http: reqwest::Client::new(), + cfg: Arc::new(args), + body_permits: Arc::new(Semaphore::new(8)), + metrics: Arc::new(Metrics::default()), + wait_subscriber, + } + } + + async fn mock_s3_handler(method: Method, uri: Uri, body: axum::body::Bytes, stored: Arc>>) -> Response { + let query = uri.query().unwrap_or(""); + if method == Method::POST && query == "uploads" { + return ( + StatusCode::OK, + [(http::header::CONTENT_TYPE, "application/xml")], + r#"test-upload"#, + ) + .into_response(); + } + if method == Method::PUT && query.contains("partNumber=") { + stored.lock().unwrap_or_else(|e| e.into_inner()).extend_from_slice(&body); + return (StatusCode::OK, [(http::header::ETAG, "\"part-etag\"")]).into_response(); + } + if method == Method::POST && query.contains("uploadId=") { + return StatusCode::OK.into_response(); + } + if method == Method::GET { + let bytes = stored.lock().unwrap_or_else(|e| e.into_inner()).clone(); + return (StatusCode::OK, bytes).into_response(); + } + StatusCode::NOT_FOUND.into_response() + } + + #[tokio::test] + async fn store_body_keeps_small_payload_inline() { + let state = test_app_state(None, 16 * 1024).await; + let payload = vec![1u8; 4096]; + let outcome = store_body(&state, "job-inline", Body::from(payload.clone())).await.expect("inline store"); + let location = outcome.location; + assert_eq!(location.storage, "inline"); + assert_eq!(location.size, payload.len()); + assert!(!location.body_base64.is_empty()); + assert!(outcome.pending_upload.is_none()); + assert_eq!(state.metrics.object_offload_total.load(Ordering::Relaxed), 0); + } + + #[tokio::test] + async fn store_body_offloads_large_payload_via_s3_multipart_and_load_body_roundtrips() { + let stored = Arc::new(std::sync::Mutex::new(Vec::new())); + let stored_for_handler = stored.clone(); + let app = Router::new().fallback(move |method: Method, uri: Uri, body: axum::body::Bytes| { + let stored_for_handler = stored_for_handler.clone(); + async move { mock_s3_handler(method, uri, body, stored_for_handler).await } + }); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind mock s3"); + let addr = listener.local_addr().expect("mock s3 addr"); + tokio::spawn(async move { + axum::serve(listener, app).await.expect("mock s3 serve"); + }); + + let state = test_app_state(Some(format!("http://{addr}")), 1024).await; + let payload = vec![7u8; 5000]; + let outcome = store_body(&state, "job-offload", Body::from(payload.clone())).await.expect("offload store"); + if let Some(upload) = outcome.pending_upload { + upload.await.expect("upload join").expect("upload body"); + } + let location = outcome.location; + assert_eq!(location.storage, "object"); + assert_eq!(location.size, payload.len()); + assert!(location.body_base64.is_empty()); + assert!(location.object_ref.contains("job-offload")); + assert_eq!(state.metrics.object_offload_total.load(Ordering::Relaxed), 1); + + let mut fields = HashMap::new(); + fields.insert("storage".to_string(), Value::String("object".into())); + fields.insert("ref".to_string(), Value::String(location.object_ref.into())); + let loaded = load_body(&state, &fields).await.expect("load offloaded body"); + assert_eq!(loaded, payload); + } +} diff --git a/binary/ai-gateway-service/src/app/types.rs b/binary/ai-gateway-service/src/app/types.rs new file mode 100644 index 00000000..b94f8341 --- /dev/null +++ b/binary/ai-gateway-service/src/app/types.rs @@ -0,0 +1,887 @@ +const TOKEN_BUCKET_LUA: &str = r#" +local tokens_key = KEYS[1] +local ts_key = KEYS[2] +local rate = tonumber(ARGV[1]) +local burst = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local cost = tonumber(ARGV[4]) + +if rate <= 0 or burst <= 0 or cost <= 0 then + return {0, 0, 1000} +end + +local burst_milli = burst * 1000 +local cost_milli = cost * 1000 +local tokens = tonumber(redis.call('GET', tokens_key) or burst_milli) +local last_ts = tonumber(redis.call('GET', ts_key) or now) +local elapsed = math.max(0, now - last_ts) +tokens = math.min(burst_milli, tokens + elapsed * rate) + +-- TTL 必须显著大于典型连续判定窗口;过短会导致 key 过期后每次都回到满 burst。 +local ttl = math.max(300000, math.ceil((burst_milli / rate) * 10)) +if tokens >= cost_milli then + tokens = tokens - cost_milli + redis.call('SET', tokens_key, tokens, 'PX', ttl) + redis.call('SET', ts_key, now, 'PX', ttl) + return {1, tokens, 0} +else + local wait_ms = math.ceil((cost_milli - tokens) / rate) + redis.call('SET', tokens_key, tokens, 'PX', ttl) + redis.call('SET', ts_key, now, 'PX', ttl) + return {0, tokens, wait_ms} +end +"#; + +#[derive(Debug, Clone, Parser)] +pub struct Args { + #[arg(long, env = "AI_GATEWAY_SERVICE_HOST", default_value = "0.0.0.0")] + host: IpAddr, + #[arg(long, env = "AI_GATEWAY_SERVICE_PORT", default_value_t = 18080)] + port: u16, + #[arg(long, env = "REDIS_URL", default_value = "redis://127.0.0.1/")] + redis_url: String, + #[arg(long, env = "AI_QUEUE_STREAM", default_value = "ai:jobs")] + stream_key: String, + #[arg(long, env = "AI_QUEUE_HIGH_STREAM", default_value = "ai:jobs:high")] + high_priority_stream_key: String, + #[arg(long, env = "AI_QUEUE_LOW_STREAM", default_value = "ai:jobs:low")] + low_priority_stream_key: String, + #[arg(long, env = "AI_ENABLE_PRIORITY_STREAMS", default_value_t = false)] + enable_priority_streams: bool, + #[arg(long, env = "AI_QUEUE_DEFAULT_PRIORITY", default_value = "normal")] + queue_default_priority: String, + #[arg(long, env = "AI_QUEUE_HIGH_MODELS", default_value = "")] + queue_high_models: String, + #[arg(long, env = "AI_QUEUE_LOW_MODELS", default_value = "")] + queue_low_models: String, + #[arg(long, env = "AI_QUEUE_HIGH_TENANTS", default_value = "")] + queue_high_tenants: String, + #[arg(long, env = "AI_QUEUE_LOW_TENANTS", default_value = "")] + queue_low_tenants: String, + #[arg(long, env = "AI_QUEUE_HIGH_WEIGHT", default_value_t = 3)] + queue_high_weight: usize, + #[arg(long, env = "AI_QUEUE_NORMAL_WEIGHT", default_value_t = 1)] + queue_normal_weight: usize, + #[arg(long, env = "AI_QUEUE_LOW_WEIGHT", default_value_t = 1)] + queue_low_weight: usize, + #[arg(long, env = "AI_QUEUE_MAX_LEN", default_value_t = 100_000)] + stream_max_len: u64, + #[arg(long, env = "AI_QUEUE_GROUP", default_value = "ai-gateway-workers")] + consumer_group: String, + #[arg(long, env = "AI_QUEUE_CONSUMER", default_value = "ai-gateway-service")] + consumer_name: String, + #[arg(long, env = "AI_JOB_DLQ_STREAM", default_value = "ai:job-dlq")] + job_dlq_stream: String, + #[arg(long, env = "AI_CALLBACK_RETRY_STREAM", default_value = "ai:callback-retry")] + callback_retry_stream: String, + #[arg(long, env = "AI_CALLBACK_RETRY_GROUP", default_value = "ai-gateway-callbacks")] + callback_retry_group: String, + #[arg(long, env = "AI_CALLBACK_DLQ_STREAM", default_value = "ai:callback-dlq")] + callback_dlq_stream: String, + #[arg(long, env = "AI_CALLBACK_MAX_RETRY_ATTEMPTS", default_value_t = 5)] + callback_max_retry_attempts: u32, + #[arg(long, env = "AI_CALLBACK_RETRY_INITIAL_DELAY_MS", default_value_t = 1000)] + callback_retry_initial_delay_ms: u64, + #[arg(long, env = "AI_CALLBACK_RETRY_MAX_DELAY_MS", default_value_t = 60_000)] + callback_retry_max_delay_ms: u64, + #[arg(long, env = "AI_CALLBACK_RETRY_RECLAIM_IDLE_SECS", default_value_t = 60)] + callback_retry_reclaim_idle_secs: u64, + #[arg(long, env = "AI_RESULT_KEY_PREFIX", default_value = "result:")] + result_key_prefix: String, + #[arg(long, env = "AI_RESULT_CHANNEL_PREFIX", default_value = "result:")] + result_channel_prefix: String, + #[arg(long, env = "AI_RESULT_TTL_SECS", default_value_t = 120)] + result_ttl_secs: u64, + #[arg(long, env = "AI_RATE_LIMIT_RPS", default_value_t = 100)] + rate_limit_rps: u64, + #[arg(long, env = "AI_RATE_LIMIT_BURST", default_value_t = 200)] + rate_limit_burst: u64, + #[arg(long, env = "AI_RATE_LIMIT_COST", default_value_t = 1)] + rate_limit_cost: u64, + #[arg(long, env = "AI_TENANT_RATE_LIMIT_PREFIX", default_value = "ai:tenant:ratelimit:")] + tenant_rate_limit_prefix: String, + #[arg(long, env = "AI_WAIT_TIMEOUT_SECS", default_value_t = 60)] + wait_timeout_secs: u64, + #[arg(long, env = "AI_WORKER_CONCURRENCY", default_value_t = 10)] + worker_concurrency: usize, + /// 逗号分隔的 Admin UI CORS 来源;为空则保持 permissive(本地开发)。 + #[arg(long, env = "AI_ADMIN_CORS_ORIGINS", default_value = "")] + admin_cors_origins: String, + #[arg(long, env = "AI_UPSTREAM_BASE_URL")] + upstream_base_url: Option, + #[arg(long, env = "AI_MAX_BODY_BYTES", default_value_t = 32 * 1024 * 1024)] + max_body_bytes: usize, + #[arg(long, env = "AI_INLINE_THRESHOLD", default_value_t = 128 * 1024)] + inline_threshold: usize, + #[arg(long, env = "AI_BODY_READ_CONCURRENCY", default_value_t = 200)] + body_read_concurrency: usize, + #[arg(long, env = "AI_RECLAIM_INTERVAL_SECS", default_value_t = 30)] + reclaim_interval_secs: u64, + #[arg(long, env = "AI_RECLAIM_MIN_IDLE_SECS", default_value_t = 30)] + reclaim_min_idle_secs: u64, + #[arg(long, env = "AI_JOB_PROCESS_LEASE_SECS", default_value_t = 120)] + job_process_lease_secs: u64, + #[arg(long, env = "AI_JOB_MAX_DELIVERY_ATTEMPTS", default_value_t = 5)] + job_max_delivery_attempts: u32, + #[arg(long, env = "AI_REQUIRE_HTTPS_CALLBACK", default_value_t = true)] + require_https_callback: bool, + #[arg(long, env = "AI_OBJECT_STORE_ENDPOINT")] + object_store_endpoint: Option, + #[arg(long, env = "AI_OBJECT_STORE_BUCKET", default_value = "ai-gateway-body")] + object_store_bucket: String, + #[arg(long, env = "AI_OBJECT_STORE_PREFIX", default_value = "bodies")] + object_store_prefix: String, + #[arg(long, env = "AI_OBJECT_MULTIPART_PART_SIZE", default_value_t = 5 * 1024 * 1024)] + object_multipart_part_size: usize, + #[arg(long, env = "AI_OBJECT_STORE_AUTH_HEADER")] + object_store_auth_header: Option, +} + +impl Default for Args { + fn default() -> Self { + Self { + host: default_host(), + port: default_port(), + redis_url: default_redis_url(), + stream_key: default_stream_key(), + high_priority_stream_key: default_high_priority_stream_key(), + low_priority_stream_key: default_low_priority_stream_key(), + enable_priority_streams: default_enable_priority_streams(), + queue_default_priority: default_queue_default_priority(), + queue_high_models: default_queue_high_models(), + queue_low_models: default_queue_low_models(), + queue_high_tenants: default_queue_high_tenants(), + queue_low_tenants: default_queue_low_tenants(), + queue_high_weight: default_queue_high_weight(), + queue_normal_weight: default_queue_normal_weight(), + queue_low_weight: default_queue_low_weight(), + stream_max_len: default_stream_max_len(), + consumer_group: default_consumer_group(), + consumer_name: default_consumer_name(), + job_dlq_stream: default_job_dlq_stream(), + callback_retry_stream: default_callback_retry_stream(), + callback_retry_group: default_callback_retry_group(), + callback_dlq_stream: default_callback_dlq_stream(), + callback_max_retry_attempts: default_callback_max_retry_attempts(), + callback_retry_initial_delay_ms: default_callback_retry_initial_delay_ms(), + callback_retry_max_delay_ms: default_callback_retry_max_delay_ms(), + callback_retry_reclaim_idle_secs: default_callback_retry_reclaim_idle_secs(), + result_key_prefix: default_result_key_prefix(), + result_channel_prefix: default_result_channel_prefix(), + result_ttl_secs: default_result_ttl_secs(), + rate_limit_rps: default_rate_limit_rps(), + rate_limit_burst: default_rate_limit_burst(), + rate_limit_cost: default_rate_limit_cost(), + tenant_rate_limit_prefix: default_tenant_rate_limit_prefix(), + wait_timeout_secs: default_wait_timeout_secs(), + worker_concurrency: default_worker_concurrency(), + admin_cors_origins: default_admin_cors_origins(), + upstream_base_url: None, + max_body_bytes: default_max_body_bytes(), + inline_threshold: default_inline_threshold(), + body_read_concurrency: default_body_read_concurrency(), + reclaim_interval_secs: default_reclaim_interval_secs(), + reclaim_min_idle_secs: default_reclaim_min_idle_secs(), + job_process_lease_secs: default_job_process_lease_secs(), + job_max_delivery_attempts: default_job_max_delivery_attempts(), + require_https_callback: default_require_https_callback(), + object_store_endpoint: None, + object_store_bucket: default_object_store_bucket(), + object_store_prefix: default_object_store_prefix(), + object_multipart_part_size: default_object_multipart_part_size(), + object_store_auth_header: None, + } + } +} + +#[derive(Clone)] +pub struct AppState { + /// 非阻塞 API 路径专用连接(准入、入队、metrics、admin)。 + redis: FredClient, + /// worker / reclaimer / callback-retry 专用连接,避免 BLOCK 型 XREADGROUP 占满 API 连接。 + worker_redis: FredClient, + http: reqwest::Client, + cfg: Arc, + body_permits: Arc, + metrics: Arc, + /// wait 模式共享 Pub/Sub 连接池。 + wait_subscriber: Arc, +} + +struct Metrics { + rate_limited_total: AtomicU64, + enqueue_total: AtomicU64, + enqueue_queue_total: AtomicU64, + enqueue_wait_total: AtomicU64, + enqueue_priority_high_total: AtomicU64, + enqueue_priority_normal_total: AtomicU64, + enqueue_priority_low_total: AtomicU64, + enqueue_latency_count: AtomicU64, + enqueue_latency_sum_ms: AtomicU64, + enqueue_latency_le_100_ms: AtomicU64, + enqueue_latency_le_500_ms: AtomicU64, + enqueue_latency_le_1000_ms: AtomicU64, + enqueue_latency_gt_1000_ms: AtomicU64, + body_size_le_10kb: AtomicU64, + body_size_le_128kb: AtomicU64, + body_size_le_5mb: AtomicU64, + body_size_gt_5mb: AtomicU64, + body_size_count: AtomicU64, + body_size_sum_bytes: AtomicU64, + wait_total: AtomicU64, + wait_timeout_total: AtomicU64, + callback_failure_total: AtomicU64, + callback_retry_total: AtomicU64, + callback_retry_success_total: AtomicU64, + callback_retry_dlq_total: AtomicU64, + worker_completed_total: AtomicU64, + worker_failed_total: AtomicU64, + worker_processing_count: AtomicU64, + worker_processing_sum_ms: AtomicU64, + worker_processing_le_1000_ms: AtomicU64, + worker_processing_le_5000_ms: AtomicU64, + worker_processing_le_30000_ms: AtomicU64, + worker_processing_gt_30000_ms: AtomicU64, + reclaimed_total: AtomicU64, + job_dlq_total: AtomicU64, + lease_skip_total: AtomicU64, + object_offload_total: AtomicU64, + object_multipart_abort_total: AtomicU64, + /// Prometheus 带 label 的 counter(policy/tenant/model/size_bucket 等)。 + labeled: Mutex>, +} + +impl Default for Metrics { + fn default() -> Self { + Self { + rate_limited_total: AtomicU64::new(0), + enqueue_total: AtomicU64::new(0), + enqueue_queue_total: AtomicU64::new(0), + enqueue_wait_total: AtomicU64::new(0), + enqueue_priority_high_total: AtomicU64::new(0), + enqueue_priority_normal_total: AtomicU64::new(0), + enqueue_priority_low_total: AtomicU64::new(0), + enqueue_latency_count: AtomicU64::new(0), + enqueue_latency_sum_ms: AtomicU64::new(0), + enqueue_latency_le_100_ms: AtomicU64::new(0), + enqueue_latency_le_500_ms: AtomicU64::new(0), + enqueue_latency_le_1000_ms: AtomicU64::new(0), + enqueue_latency_gt_1000_ms: AtomicU64::new(0), + body_size_le_10kb: AtomicU64::new(0), + body_size_le_128kb: AtomicU64::new(0), + body_size_le_5mb: AtomicU64::new(0), + body_size_gt_5mb: AtomicU64::new(0), + body_size_count: AtomicU64::new(0), + body_size_sum_bytes: AtomicU64::new(0), + wait_total: AtomicU64::new(0), + wait_timeout_total: AtomicU64::new(0), + callback_failure_total: AtomicU64::new(0), + callback_retry_total: AtomicU64::new(0), + callback_retry_success_total: AtomicU64::new(0), + callback_retry_dlq_total: AtomicU64::new(0), + worker_completed_total: AtomicU64::new(0), + worker_failed_total: AtomicU64::new(0), + worker_processing_count: AtomicU64::new(0), + worker_processing_sum_ms: AtomicU64::new(0), + worker_processing_le_1000_ms: AtomicU64::new(0), + worker_processing_le_5000_ms: AtomicU64::new(0), + worker_processing_le_30000_ms: AtomicU64::new(0), + worker_processing_gt_30000_ms: AtomicU64::new(0), + reclaimed_total: AtomicU64::new(0), + job_dlq_total: AtomicU64::new(0), + lease_skip_total: AtomicU64::new(0), + object_offload_total: AtomicU64::new(0), + object_multipart_abort_total: AtomicU64::new(0), + labeled: Mutex::new(HashMap::new()), + } + } +} + +#[derive(Debug, Serialize)] +struct RateLimitResponse { + allowed: bool, + remaining_tokens_milli: i64, + retry_after_ms: i64, +} + +#[derive(Debug, Serialize)] +struct EnqueueResponse { + job_id: String, + stream_id: String, + stream_key: String, + status: &'static str, + /// 设计文档 poll 路径:`/jobs/{id}/status` + poll_url: String, + /// 兼容旧客户端:`/v1/jobs/{id}` + status_url: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct StoredResult { + job_id: String, + status: String, + http_status: u16, + headers: HashMap, + body_base64: String, + completed_at_ms: u64, + error: Option, +} + +#[derive(Debug)] +struct AcceptedJob { + response: EnqueueResponse, + created_at_ms: u64, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum QueuePolicy { + Queue, + Wait, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum QueuePriority { + High, + Normal, + Low, +} + +impl QueuePriority { + fn as_str(self) -> &'static str { + match self { + QueuePriority::High => "high", + QueuePriority::Normal => "normal", + QueuePriority::Low => "low", + } + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)] +struct TenantRateLimit { + rps: u64, + burst: u64, + #[serde(default = "default_rate_limit_cost")] + cost: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +struct AiGatewayQueuePluginConfig { + #[serde(default)] + service: AiGatewayServiceConfig, + #[serde(default)] + paths: AiGatewayPathsConfig, + #[serde(default)] + headers: AiGatewayHeadersConfig, + #[serde(default)] + policies: AiGatewayPoliciesConfig, + #[serde(default)] + priority: AiGatewayPriorityConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +struct AiGatewayServiceConfig { + #[serde(default = "default_service_cluster")] + cluster: String, + #[serde(default = "default_service_authority")] + authority: String, + #[serde(default = "default_service_timeout_ms")] + timeout_ms: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +struct AiGatewayPathsConfig { + #[serde(default = "default_rate_limit_path")] + rate_limit: String, + #[serde(default = "default_enqueue_path")] + enqueue: String, + #[serde(default = "default_wait_path")] + wait: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +struct AiGatewayHeadersConfig { + #[serde(default = "default_policy_header")] + policy: String, + #[serde(default = "default_tenant_header")] + tenant: String, + #[serde(default = "default_model_header")] + model: String, + #[serde(default = "default_priority_header")] + priority: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +struct AiGatewayPoliciesConfig { + #[serde(default = "default_require_policy")] + require: bool, + #[serde(default)] + default: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +struct AiGatewayPriorityConfig { + #[serde(default = "default_priority_enabled")] + enabled: bool, + #[serde(default = "default_priority")] + default: String, + #[serde(default)] + high_models: Vec, + #[serde(default)] + low_models: Vec, + #[serde(default)] + high_tenants: Vec, + #[serde(default)] + low_tenants: Vec, +} + +impl Default for AiGatewayQueuePluginConfig { + fn default() -> Self { + Self { + service: AiGatewayServiceConfig::default(), + paths: AiGatewayPathsConfig::default(), + headers: AiGatewayHeadersConfig::default(), + policies: AiGatewayPoliciesConfig::default(), + priority: AiGatewayPriorityConfig::default(), + } + } +} + +impl Default for AiGatewayServiceConfig { + fn default() -> Self { + Self { + cluster: default_service_cluster(), + authority: default_service_authority(), + timeout_ms: default_service_timeout_ms(), + } + } +} + +impl Default for AiGatewayPathsConfig { + fn default() -> Self { + Self { + rate_limit: default_rate_limit_path(), + enqueue: default_enqueue_path(), + wait: default_wait_path(), + } + } +} + +impl Default for AiGatewayHeadersConfig { + fn default() -> Self { + Self { + policy: default_policy_header(), + tenant: default_tenant_header(), + model: default_model_header(), + priority: default_priority_header(), + } + } +} + +impl Default for AiGatewayPoliciesConfig { + fn default() -> Self { + Self { + require: default_require_policy(), + default: None, + } + } +} + +impl Default for AiGatewayPriorityConfig { + fn default() -> Self { + Self { + enabled: default_priority_enabled(), + default: default_priority(), + high_models: Vec::new(), + low_models: Vec::new(), + high_tenants: Vec::new(), + low_tenants: Vec::new(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +struct TenantRateLimitRule { + tenant: String, + #[serde(default)] + model: Option, + #[serde(default)] + path: Option, + #[serde(default)] + policy: Option, + rps: u64, + burst: u64, + #[serde(default = "default_rate_limit_cost")] + cost: u64, + /// 临时配额 TTL(秒);写入 Redis 时对 key 设置 EX。 + #[serde(default)] + ttl_secs: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +struct StoredTenantRateLimit { + rps: u64, + burst: u64, + #[serde(default = "default_rate_limit_cost")] + cost: u64, + #[serde(default)] + ttl_secs: Option, +} + +#[derive(Debug, Clone, Serialize)] +struct TenantRateLimitRuleView { + key: String, + tenant: String, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + policy: Option, + rps: u64, + burst: u64, + cost: u64, + #[serde(skip_serializing_if = "Option::is_none")] + ttl_secs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + ttl_remaining_secs: Option, +} + +fn default_host() -> IpAddr { + "0.0.0.0".parse().expect("default host") +} + +fn default_port() -> u16 { + 18080 +} + +fn default_redis_url() -> String { + "redis://127.0.0.1/".to_string() +} + +fn default_stream_key() -> String { + "ai:jobs".to_string() +} + +fn default_high_priority_stream_key() -> String { + "ai:jobs:high".to_string() +} + +fn default_low_priority_stream_key() -> String { + "ai:jobs:low".to_string() +} + +fn default_enable_priority_streams() -> bool { + true +} + +fn default_queue_default_priority() -> String { + "normal".to_string() +} + +fn default_queue_high_models() -> String { + String::new() +} + +fn default_queue_low_models() -> String { + String::new() +} + +fn default_queue_high_tenants() -> String { + String::new() +} + +fn default_queue_low_tenants() -> String { + String::new() +} + +fn default_queue_high_weight() -> usize { + 3 +} + +fn default_queue_normal_weight() -> usize { + 1 +} + +fn default_queue_low_weight() -> usize { + 1 +} + +fn default_stream_max_len() -> u64 { + 100_000 +} + +fn default_consumer_group() -> String { + "ai-gateway-workers".to_string() +} + +fn default_consumer_name() -> String { + "ai-gateway-service".to_string() +} + +fn default_job_dlq_stream() -> String { + "ai:job-dlq".to_string() +} + +fn default_callback_retry_stream() -> String { + "ai:callback-retry".to_string() +} + +fn default_callback_retry_group() -> String { + "ai-gateway-callbacks".to_string() +} + +fn default_callback_dlq_stream() -> String { + "ai:callback-dlq".to_string() +} + +fn default_callback_max_retry_attempts() -> u32 { + 5 +} + +fn default_callback_retry_initial_delay_ms() -> u64 { + 1000 +} + +fn default_callback_retry_max_delay_ms() -> u64 { + 60_000 +} + +fn default_callback_retry_reclaim_idle_secs() -> u64 { + 60 +} + +fn default_result_key_prefix() -> String { + "result:".to_string() +} + +fn default_result_channel_prefix() -> String { + "result:".to_string() +} + +fn default_result_ttl_secs() -> u64 { + 120 +} + +fn default_rate_limit_rps() -> u64 { + 100 +} + +fn default_rate_limit_burst() -> u64 { + 200 +} + +fn default_tenant_rate_limit_prefix() -> String { + "ai:tenant:ratelimit:".to_string() +} + +fn default_wait_timeout_secs() -> u64 { + 60 +} + +fn default_worker_concurrency() -> usize { + 10 +} + +fn default_admin_cors_origins() -> String { + String::new() +} + +fn default_max_body_bytes() -> usize { + 32 * 1024 * 1024 +} + +fn default_inline_threshold() -> usize { + 128 * 1024 +} + +fn default_body_read_concurrency() -> usize { + 200 +} + +fn default_reclaim_interval_secs() -> u64 { + 30 +} + +fn default_reclaim_min_idle_secs() -> u64 { + 30 +} + +fn default_job_process_lease_secs() -> u64 { + 120 +} + +fn default_job_max_delivery_attempts() -> u32 { + 5 +} + +fn default_require_https_callback() -> bool { + true +} + +fn default_object_store_bucket() -> String { + "ai-gateway-body".to_string() +} + +fn default_object_store_prefix() -> String { + "bodies".to_string() +} + +fn default_object_multipart_part_size() -> usize { + 5 * 1024 * 1024 +} + +fn default_rate_limit_cost() -> u64 { + 1 +} + +fn default_service_cluster() -> String { + "ai-gateway-service".to_string() +} + +fn default_service_authority() -> String { + "ai-gateway-service".to_string() +} + +fn default_service_timeout_ms() -> u64 { + 65_000 +} + +fn default_rate_limit_path() -> String { + "/v1/ratelimit/check".to_string() +} + +fn default_enqueue_path() -> String { + "/v1/queue/enqueue".to_string() +} + +fn default_wait_path() -> String { + "/v1/queue/enqueue-and-wait".to_string() +} + +fn default_policy_header() -> String { + "x-ratelimit-policy".to_string() +} + +fn default_tenant_header() -> String { + "x-tenant-id".to_string() +} + +fn default_model_header() -> String { + "x-model".to_string() +} + +fn default_priority_header() -> String { + "x-queue-priority".to_string() +} + +fn default_require_policy() -> bool { + true +} + +fn default_priority_enabled() -> bool { + true +} + +fn default_priority() -> String { + "normal".to_string() +} + +impl QueuePolicy { + fn as_str(self) -> &'static str { + match self { + QueuePolicy::Queue => "queue", + QueuePolicy::Wait => "wait", + } + } +} + +#[derive(Debug)] +struct BodyStoreOutcome { + location: BodyLocation, + /// S3 卸载上传仍在后台进行时,入队需与其并行并在返回前 join。 + pending_upload: Option>>, +} + +#[derive(Debug)] +struct BodyLocation { + body_base64: String, + object_ref: String, + size: usize, + storage: &'static str, +} + +#[derive(Debug)] +struct CompletedPart { + part_number: usize, + etag: String, +} + +#[derive(Debug)] +struct ServiceError { + status: StatusCode, + message: String, +} + +impl ServiceError { + fn bad_request(message: impl Into) -> Self { + Self { + status: StatusCode::BAD_REQUEST, + message: message.into(), + } + } + + fn internal(message: impl Into) -> Self { + Self { + status: StatusCode::INTERNAL_SERVER_ERROR, + message: message.into(), + } + } + + fn gateway_timeout(message: impl Into) -> Self { + Self { + status: StatusCode::GATEWAY_TIMEOUT, + message: message.into(), + } + } + + fn payload_too_large(message: impl Into) -> Self { + Self { + status: StatusCode::PAYLOAD_TOO_LARGE, + message: message.into(), + } + } + + fn not_implemented(message: impl Into) -> Self { + Self { + status: StatusCode::NOT_IMPLEMENTED, + message: message.into(), + } + } +} + +impl IntoResponse for ServiceError { + fn into_response(self) -> Response { + let body = Json(serde_json::json!({ "error": self.message })); + (self.status, body).into_response() + } +} + +impl std::fmt::Display for ServiceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for ServiceError {} + +impl From for ServiceError { + fn from(value: fred::error::Error) -> Self { + Self::internal(format!("redis: {value}")) + } +} + +impl From for ServiceError { + fn from(value: reqwest::Error) -> Self { + Self::internal(format!("http: {value}")) + } +} diff --git a/binary/ai-gateway-service/src/app/util.rs b/binary/ai-gateway-service/src/app/util.rs new file mode 100644 index 00000000..3e255e8c --- /dev/null +++ b/binary/ai-gateway-service/src/app/util.rs @@ -0,0 +1,207 @@ +fn required_header(headers: &HeaderMap, name: &str) -> Result { + optional_header(headers, name).ok_or_else(|| ServiceError::bad_request(format!("missing required header `{name}`"))) +} + +fn optional_header(headers: &HeaderMap, name: &str) -> Option { + headers.get(name).and_then(|value| value.to_str().ok()).map(str::trim).filter(|value| !value.is_empty()).map(ToOwned::to_owned) +} + +fn headers_to_json(headers: &HeaderMap) -> Result { + let mut out = HashMap::new(); + for (name, value) in headers { + if let Ok(value) = value.to_str() { + out.insert(name.as_str().to_string(), value.to_string()); + } + } + serde_json::to_string(&out).map_err(|e| ServiceError::internal(format!("serialize headers: {e}"))) +} + +fn should_forward_header(name: &str) -> bool { + let name = name.to_ascii_lowercase(); + !matches!( + name.as_str(), + "host" | "connection" | "content-length" | "transfer-encoding" | "x-original-method" | "x-original-path" | "x-ratelimit-policy" | "x-callback-url" | "x-request-timeout" + ) +} + +fn header_value(value: &str) -> Result { + HeaderValue::from_str(value).map_err(|e| ServiceError::internal(format!("invalid response header value: {e}"))) +} + +fn result_key(state: &AppState, job_id: &str) -> String { + format!("{}{}", state.cfg.result_key_prefix, job_id) +} + +fn result_channel(state: &AppState, job_id: &str) -> String { + format!("{}{}", state.cfg.result_channel_prefix, job_id) +} + +fn new_job_id() -> String { + ulid::Ulid::new().to_string() +} + +fn now_ms() -> u64 { + SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_millis() as u64 +} + +fn sanitize_key(input: &str) -> String { + input.chars().map(|ch| if ch.is_ascii_alphanumeric() || matches!(ch, ':' | '_' | '-' | '.') { ch } else { '_' }).collect() +} + +fn build_redis_client(url: &str) -> Result { + let config = Config::from_url(url)?; + Builder::from_config(config).build() +} + +fn build_subscriber_client(url: &str) -> Result { + let config = Config::from_url(url)?; + Builder::from_config(config).build_subscriber_client() +} + +fn field_string(fields: &HashMap, key: &str) -> Option { + fields.get(key).and_then(|value| match value { + Value::String(value) => Some(value.to_string()), + Value::Bytes(value) => String::from_utf8(value.to_vec()).ok(), + Value::Integer(value) => Some(value.to_string()), + _ => None, + }) +} + +fn field_bytes(fields: &HashMap, key: &str) -> Option> { + fields.get(key).and_then(|value| match value { + Value::Bytes(value) => Some(value.to_vec()), + Value::String(value) => Some(value.as_bytes().to_vec()), + _ => None, + }) +} + +fn field_u64(fields: &HashMap, key: &str) -> Option { + fields.get(key).and_then(|value| match value { + Value::Integer(value) => (*value).try_into().ok(), + Value::String(value) => value.parse().ok(), + Value::Bytes(value) => std::str::from_utf8(value).ok().and_then(|value| value.parse().ok()), + _ => None, + }) +} + +fn field_u32(fields: &HashMap, key: &str) -> Option { + field_u64(fields, key).and_then(|value| value.try_into().ok()) +} + +fn job_poll_url(job_id: &str) -> String { + format!("/jobs/{job_id}/status") +} + +fn job_status_url_legacy(job_id: &str) -> String { + format!("/v1/jobs/{job_id}") +} + +fn metrics_label(value: &str) -> String { + sanitize_key(value).chars().take(64).collect() +} + +fn body_size_bucket(size: usize, storage: &str) -> &'static str { + if storage == "object" || storage == "s3" { + "s3" + } else if size <= 10 * 1024 { + "inline_small" + } else if size <= 128 * 1024 { + "inline" + } else { + "inline_large" + } +} + +fn format_completed_at_rfc3339(ms: u64) -> String { + let days = (ms / 86_400_000) as i64; + let rem_ms = ms % 86_400_000; + let (year, month, day) = civil_from_days(days); + format!( + "{year:04}-{month:02}-{day:02}T{:02}:{:02}:{:02}.{:03}Z", + rem_ms / 3_600_000, + (rem_ms % 3_600_000) / 60_000, + (rem_ms % 60_000) / 1_000, + rem_ms % 1_000, + ) +} + +fn civil_from_days(z: i64) -> (i64, u32, u32) { + let z = z + 719468; + let era = if z >= 0 { z } else { z - 146096 } / 146097; + let doe = (z - era * 146097) as u64; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let mut y = yoe as i64 + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let d = doy - (153 * mp + 2) / 5 + 1; + let m = if mp < 10 { mp + 3 } else { mp - 9 }; + if m <= 2 { + y += 1; + } + (y, m as u32, d as u32) +} + +fn decode_callback_result(body_base64: &str) -> serde_json::Value { + if body_base64.is_empty() { + return serde_json::Value::Null; + } + let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(body_base64) else { + return serde_json::json!({ "raw_base64": body_base64 }); + }; + if let Ok(value) = serde_json::from_slice::(&bytes) { + return value; + } + serde_json::json!({ "raw_base64": body_base64 }) +} + +fn poll_result_to_response(result: StoredResult) -> Result { + let status = StatusCode::from_u16(result.http_status).unwrap_or(StatusCode::OK); + let body = base64::engine::general_purpose::STANDARD.decode(result.body_base64).map_err(|e| ServiceError::internal(format!("decode poll result body: {e}")))?; + let mut resp = (status, body).into_response(); + for (name, value) in result.headers { + if let (Ok(name), Ok(value)) = (HeaderName::try_from(name.as_str()), HeaderValue::from_str(&value)) { + resp.headers_mut().insert(name, value); + } + } + resp.headers_mut().insert("x-job-id", header_value(&result.job_id)?); + Ok(resp) +} + +fn tenant_rate_limit_rule_view(key: String, rule: TenantRateLimitRule, ttl_remaining_secs: Option) -> TenantRateLimitRuleView { + TenantRateLimitRuleView { + key, + tenant: rule.tenant, + model: rule.model, + path: rule.path, + policy: rule.policy, + rps: rule.rps, + burst: rule.burst, + cost: rule.cost, + ttl_secs: rule.ttl_secs, + ttl_remaining_secs, + } +} + +/// Parse `XREADGROUP` into the map form used by workers. +/// +/// Redis returns `nil` when a blocking read times out or the stream has no new entries for `>`. +/// Without explicit handling, fred fails to convert that into `HashMap` and the worker aborts +/// before polling the next priority stream — even when another stream already has backlog. +async fn xreadgroup_map_or_empty( + redis: &FredClient, + group: &str, + consumer: &str, + count: Option, + block: Option, + noack: bool, + keys: Vec<&str>, + ids: Vec<&str>, +) -> Result, ServiceError> { + let value: Value = redis.xreadgroup(group, consumer, count, block, noack, keys, ids).await?; + if value.is_null() { + return Ok(HashMap::new()); + } + value + .into_xread_response() + .map_err(|e| ServiceError::internal(format!("parse xreadgroup response: {e}"))) +} diff --git a/binary/ai-gateway-service/src/app/wait_subscriber.rs b/binary/ai-gateway-service/src/app/wait_subscriber.rs new file mode 100644 index 00000000..d336efe8 --- /dev/null +++ b/binary/ai-gateway-service/src/app/wait_subscriber.rs @@ -0,0 +1,59 @@ +/// 共享 Redis Pub/Sub 连接,多 wait 请求复用同一物理连接(设计文档 §连接数)。 +struct WaitSubscriberHub { + client: SubscriberClient, + waiters: tokio::sync::Mutex>>>, +} + +impl WaitSubscriberHub { + async fn new(redis_url: &str) -> Result, ServiceError> { + let client = build_subscriber_client(redis_url).map_err(|e| ServiceError::internal(format!("wait subscriber: {e}")))?; + client.init().await.map_err(|e| ServiceError::internal(format!("wait subscriber init: {e}")))?; + let hub = Arc::new(Self { + client, + waiters: tokio::sync::Mutex::new(HashMap::new()), + }); + let reader = hub.clone(); + tokio::spawn(async move { + reader.run_dispatch_loop().await; + }); + Ok(hub) + } + + async fn wait_for_channel(self: &Arc, channel: &str, timeout: Duration) -> Result<(), ServiceError> { + let (tx, rx) = oneshot::channel(); + { + let mut waiters = self.waiters.lock().await; + waiters.entry(channel.to_string()).or_default().push(tx); + } + self.client + .subscribe(channel) + .await + .map_err(|e| ServiceError::internal(format!("pubsub subscribe: {e}")))?; + + match tokio::time::timeout(timeout, rx).await { + Ok(Ok(())) => Ok(()), + Ok(Err(_)) => Err(ServiceError::internal("wait subscriber channel closed")), + Err(_) => Err(ServiceError::gateway_timeout(format!("timed out waiting for channel {channel}"))), + } + } + + async fn run_dispatch_loop(self: Arc) { + let mut messages = self.client.message_rx(); + loop { + let message = match messages.recv().await { + Ok(message) => message, + Err(e) => { + tracing::warn!(error = %e, "wait subscriber message loop ended"); + break; + } + }; + let channel = message.channel.to_string(); + let mut waiters = self.waiters.lock().await; + if let Some(list) = waiters.remove(&channel) { + for tx in list { + let _ = tx.send(()); + } + } + } + } +} diff --git a/binary/ai-gateway-service/src/lib.rs b/binary/ai-gateway-service/src/lib.rs new file mode 100644 index 00000000..332c62aa --- /dev/null +++ b/binary/ai-gateway-service/src/lib.rs @@ -0,0 +1,5 @@ +//! AI Gateway Service library — 供集成测试与二进制共用。 +pub mod app; + +#[cfg(feature = "test-support")] +pub use app::test_support::{CallbackRecord, HarnessConfig, TestHarness}; diff --git a/binary/ai-gateway-service/src/main.rs b/binary/ai-gateway-service/src/main.rs new file mode 100644 index 00000000..bc547401 --- /dev/null +++ b/binary/ai-gateway-service/src/main.rs @@ -0,0 +1,4 @@ +#[tokio::main] +async fn main() -> Result<(), Box> { + ai_gateway_service::app::run().await +} diff --git a/binary/ai-gateway-service/tests/fixtures/small.json b/binary/ai-gateway-service/tests/fixtures/small.json new file mode 100644 index 00000000..ae09ea7b --- /dev/null +++ b/binary/ai-gateway-service/tests/fixtures/small.json @@ -0,0 +1 @@ +{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]} diff --git a/binary/ai-gateway-service/tests/hurl/admin.hurl b/binary/ai-gateway-service/tests/hurl/admin.hurl new file mode 100644 index 00000000..951cd3a8 --- /dev/null +++ b/binary/ai-gateway-service/tests/hurl/admin.hurl @@ -0,0 +1,31 @@ +# TC-RL-05 Admin 租户规则 +PUT {{service_url}}/v1/admin/tenant-rate-limits +Content-Type: application/json +``` +{ + "tenant": "hurl-admin", + "rps": 1, + "burst": 1, + "cost": 1 +} +``` + +HTTP 200 +[Asserts] +jsonpath "$.tenant" == "hurl-admin" + +POST {{service_url}}/v1/ratelimit/check +X-Tenant-Id: hurl-admin +X-RateLimit-Policy: abandon + +HTTP 200 +[Asserts] +jsonpath "$.allowed" == true + +POST {{service_url}}/v1/ratelimit/check +X-Tenant-Id: hurl-admin +X-RateLimit-Policy: abandon + +HTTP 200 +[Asserts] +jsonpath "$.allowed" == false diff --git a/binary/ai-gateway-service/tests/hurl/metrics.hurl b/binary/ai-gateway-service/tests/hurl/metrics.hurl new file mode 100644 index 00000000..ac6a0fc2 --- /dev/null +++ b/binary/ai-gateway-service/tests/hurl/metrics.hurl @@ -0,0 +1,13 @@ +# TC-MET-01 / TC-DEP-02 +GET {{service_url}}/healthz + +HTTP 200 +[Asserts] +body == "ok" + +GET {{service_url}}/metrics + +HTTP 200 +[Asserts] +body contains "queue_depth" +body contains "pel_size" diff --git a/binary/ai-gateway-service/tests/hurl/queue.hurl b/binary/ai-gateway-service/tests/hurl/queue.hurl new file mode 100644 index 00000000..3c042534 --- /dev/null +++ b/binary/ai-gateway-service/tests/hurl/queue.hurl @@ -0,0 +1,21 @@ +# TC-Q-02 / TC-HDR-04 / TC-HDR-05 +POST {{service_url}}/v1/queue/enqueue +X-Tenant-Id: hurl-queue +X-RateLimit-Policy: queue +X-Callback-URL: {{callback_url}} +Content-Type: application/json +file,small.json; + +HTTP 202 +[Asserts] +header "X-Job-Id" exists +jsonpath "$.status" == "queued" +jsonpath "$.poll_url" matches "/jobs/.+/status" + +POST {{service_url}}/v1/queue/enqueue +X-Tenant-Id: hurl-queue +X-RateLimit-Policy: queue +Content-Type: application/json +file,small.json; + +HTTP 400 diff --git a/binary/ai-gateway-service/tests/hurl/ratelimit.hurl b/binary/ai-gateway-service/tests/hurl/ratelimit.hurl new file mode 100644 index 00000000..2a2771ca --- /dev/null +++ b/binary/ai-gateway-service/tests/hurl/ratelimit.hurl @@ -0,0 +1,15 @@ +# TC-RL-02 / TC-HDR-03:限流 check 与缺 tenant +POST {{service_url}}/v1/ratelimit/check +X-Tenant-Id: hurl-tenant-a +X-RateLimit-Policy: abandon +X-Original-Path: /v1/chat + +HTTP 200 +[Asserts] +jsonpath "$.allowed" == true +jsonpath "$.retry_after_ms" == 0 + +POST {{service_url}}/v1/ratelimit/check +X-RateLimit-Policy: abandon + +HTTP 400 diff --git a/binary/ai-gateway-service/tests/hurl/wait.hurl b/binary/ai-gateway-service/tests/hurl/wait.hurl new file mode 100644 index 00000000..6118c287 --- /dev/null +++ b/binary/ai-gateway-service/tests/hurl/wait.hurl @@ -0,0 +1,15 @@ +# TC-W-02:wait 成功 +POST {{service_url}}/v1/queue/enqueue-and-wait +X-Tenant-Id: hurl-wait +X-RateLimit-Policy: wait +X-Request-Timeout: 10 +X-Original-Method: POST +X-Original-Path: /v1/chat +Content-Type: application/json +file,small.json; + +HTTP 200 +[Asserts] +header "X-Job-Id" exists +header "X-Queue-Wait-Ms" exists +jsonpath "$.upstream" == true diff --git a/binary/ai-gateway-service/tests/integration/admin_tenant_limit.rs b/binary/ai-gateway-service/tests/integration/admin_tenant_limit.rs new file mode 100644 index 00000000..ddca8941 --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/admin_tenant_limit.rs @@ -0,0 +1,64 @@ +use ai_gateway_service::HarnessConfig; + +use super::common::TestHarness; + +/// TC-RL-05 / TC-RL-06:Admin 租户规则写入并生效。 +#[tokio::test] +async fn tc_rl_05_admin_tenant_rate_limit() { + let h = TestHarness::start_config(HarnessConfig { + rate_limit_rps: Some(100), + rate_limit_burst: Some(100), + ..Default::default() + }) + .await; + + let rule = serde_json::json!({ + "tenant": "admin-tenant", + "rps": 1, + "burst": 1, + "cost": 1 + }); + let put = h + .client + .put(format!("{}/v1/admin/tenant-rate-limits", h.base_url)) + .json(&rule) + .send() + .await + .expect("put rule"); + assert_eq!(put.status(), 200); + + let first = h.check_rate_limit("admin-tenant", "abandon").await.json::().await.unwrap(); + assert_eq!(first["allowed"], true); + let second = h.check_rate_limit("admin-tenant", "abandon").await.json::().await.unwrap(); + assert_eq!(second["allowed"], false); +} + +/// TC-RL-06:model 维度规则更具体时生效。 +#[tokio::test] +async fn tc_rl_06_model_specific_rule() { + let h = TestHarness::start().await; + let rule = serde_json::json!({ + "tenant": "model-tenant", + "model": "gpt-4", + "rps": 1, + "burst": 1 + }); + h.client + .put(format!("{}/v1/admin/tenant-rate-limits", h.base_url)) + .json(&rule) + .send() + .await + .expect("put"); + + let resp = h + .client + .post(format!("{}/v1/ratelimit/check", h.base_url)) + .header("x-tenant-id", "model-tenant") + .header("x-model", "gpt-4") + .header("x-ratelimit-policy", "abandon") + .send() + .await + .expect("check"); + let first = resp.json::().await.unwrap(); + assert_eq!(first["allowed"], true); +} diff --git a/binary/ai-gateway-service/tests/integration/body_store.rs b/binary/ai-gateway-service/tests/integration/body_store.rs new file mode 100644 index 00000000..9419753e --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/body_store.rs @@ -0,0 +1,39 @@ +use ai_gateway_service::HarnessConfig; + +use super::common::{small_body, TestHarness}; + +/// TC-BODY-01:小 body inline 入队。 +#[tokio::test] +async fn tc_body_01_inline_enqueue() { + let h = TestHarness::start().await; + let resp = h.enqueue("inline-t", small_body(), axum::http::HeaderMap::new()).await; + assert_eq!(resp.status(), 202); +} + +/// TC-BODY-03:无 S3 时大 body 413。 +#[tokio::test] +async fn tc_body_03_large_body_without_s3_rejected() { + let h = TestHarness::start_config(HarnessConfig { + inline_threshold: Some(1024), + clear_object_store: true, + ..Default::default() + }) + .await; + let large = vec![0u8; 2048]; + let resp = h.enqueue("large-t", large, axum::http::HeaderMap::new()).await; + assert_eq!(resp.status(), 413); +} + +/// TC-HDR-03:缺 tenant。 +#[tokio::test] +async fn tc_hdr_03_missing_tenant() { + let h = TestHarness::start().await; + let resp = h + .client + .post(format!("{}/v1/ratelimit/check", h.base_url)) + .header("x-ratelimit-policy", "abandon") + .send() + .await + .expect("check"); + assert_eq!(resp.status(), 400); +} diff --git a/binary/ai-gateway-service/tests/integration/common.rs b/binary/ai-gateway-service/tests/integration/common.rs new file mode 100644 index 00000000..47781794 --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/common.rs @@ -0,0 +1,9 @@ +pub use ai_gateway_service::TestHarness; + +pub fn small_body() -> Vec { + br#"{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}"#.to_vec() +} + +pub async fn parse_rate_limit(resp: reqwest::Response) -> serde_json::Value { + resp.json().await.expect("rate limit json") +} diff --git a/binary/ai-gateway-service/tests/integration/enqueue_queue.rs b/binary/ai-gateway-service/tests/integration/enqueue_queue.rs new file mode 100644 index 00000000..b8941e98 --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/enqueue_queue.rs @@ -0,0 +1,108 @@ +use axum::http::HeaderMap; + +use ai_gateway_service::HarnessConfig; + +use super::common::{small_body, TestHarness}; + +/// TC-HDR-04:queue 缺 callback。 +#[tokio::test] +async fn tc_hdr_04_queue_missing_callback() { + let h = TestHarness::start().await; + let resp = h + .client + .post(format!("{}/v1/queue/enqueue", h.base_url)) + .header("x-tenant-id", "t1") + .header("x-ratelimit-policy", "queue") + .body(small_body()) + .send() + .await + .expect("enqueue"); + assert_eq!(resp.status(), 400); +} + +/// TC-HDR-05:非 HTTPS 回调(生产配置)。 +#[tokio::test] +async fn tc_hdr_05_https_callback_required() { + let h = TestHarness::start_config(HarnessConfig { + require_https_callback: Some(true), + ..Default::default() + }) + .await; + let resp = h + .client + .post(format!("{}/v1/queue/enqueue", h.base_url)) + .header("x-tenant-id", "t1") + .header("x-ratelimit-policy", "queue") + .header("x-callback-url", "http://insecure.example/cb") + .body(small_body()) + .send() + .await + .expect("enqueue"); + assert_eq!(resp.status(), 400); +} + +/// TC-Q-02 / TC-Q-03:入队 202 + ULID job_id + poll_url。 +#[tokio::test] +async fn tc_q_02_enqueue_returns_202_with_job_id() { + let h = TestHarness::start().await; + + let resp = h.enqueue("queue-t", small_body(), HeaderMap::new()).await; + assert_eq!(resp.status(), 202); + let job_id = resp.headers().get("x-job-id").unwrap().to_str().unwrap().to_string(); + assert_eq!(job_id.len(), 26); + let json: serde_json::Value = resp.json().await.expect("json"); + assert_eq!(json["status"], "queued"); + assert!(json["poll_url"].as_str().unwrap().contains(&job_id)); +} + +/// TC-Q-04 / TC-Q-05:Worker 回调四字段 JSON。 +#[tokio::test] +async fn tc_q_04_callback_payload_shape() { + let h = TestHarness::start_config(HarnessConfig { + rate_limit_burst: Some(10), + rate_limit_rps: Some(100), + ..Default::default() + }) + .await; + + let resp = h.enqueue("cb-t", small_body(), HeaderMap::new()).await; + assert_eq!(resp.status(), 202); + let job_id = resp.headers().get("x-job-id").unwrap().to_str().unwrap().to_string(); + + for _ in 0..40 { + if !h.callback_records().is_empty() { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + let records = h.callback_records(); + assert!(!records.is_empty(), "expected callback"); + let rec = records.iter().find(|r| r.job_id == job_id).expect("job callback"); + assert!(rec.body.get("job_id").is_some()); + assert!(rec.body.get("status").is_some()); + assert!(rec.body.get("result").is_some()); + assert!(rec.body.get("completed_at").is_some()); + assert!(rec.body.get("http_status").is_none()); +} + +/// TC-Q-07:dev 允许 http 回调。 +#[tokio::test] +async fn tc_q_07_http_callback_when_disabled_check() { + let h = TestHarness::start_config(HarnessConfig { + require_https_callback: Some(false), + ..Default::default() + }) + .await; + let resp = h + .client + .post(format!("{}/v1/queue/enqueue", h.base_url)) + .header("x-tenant-id", "t-http") + .header("x-ratelimit-policy", "queue") + .header("x-callback-url", "http://127.0.0.1:9/cb") + .body(small_body()) + .send() + .await + .expect("enqueue"); + assert_eq!(resp.status(), 202); +} diff --git a/binary/ai-gateway-service/tests/integration/enqueue_wait.rs b/binary/ai-gateway-service/tests/integration/enqueue_wait.rs new file mode 100644 index 00000000..717bed23 --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/enqueue_wait.rs @@ -0,0 +1,54 @@ +use ai_gateway_service::HarnessConfig; + +use super::common::{small_body, TestHarness}; + +/// TC-W-02:wait 成功返回上游 body 与等待头。 +#[tokio::test] +async fn tc_w_02_wait_success_headers() { + let h = TestHarness::start_config(HarnessConfig { + rate_limit_burst: Some(10), + wait_timeout_secs: Some(10), + ..Default::default() + }) + .await; + + let resp = h.enqueue_and_wait("wait-ok", small_body(), None).await; + assert_eq!(resp.status(), 200); + assert!(resp.headers().contains_key("x-job-id")); + assert!(resp.headers().contains_key("x-queue-wait-ms")); + let body: serde_json::Value = resp.json().await.expect("json"); + assert_eq!(body["upstream"], true); +} + +/// TC-W-04:wait 超时 504(短 timeout)。 +#[tokio::test] +async fn tc_w_04_wait_timeout_504() { + let h = TestHarness::start_config(HarnessConfig { + wait_timeout_secs: Some(1), + ..Default::default() + }) + .await; + + let resp = h.enqueue_and_wait("wait-to", small_body(), Some(1)).await; + assert!(resp.status() == 200 || resp.status() == 504); + if resp.status() == 504 { + let json: serde_json::Value = resp.json().await.expect("json"); + assert_eq!(json["error"], "timeout"); + assert!(json.get("job_id").is_some()); + assert!(json.get("waited_ms").is_some()); + } +} + +/// TC-W-05:完成后 poll 返回 LLM 原始响应。 +#[tokio::test] +async fn tc_w_05_poll_returns_upstream_body() { + let h = TestHarness::start().await; + let resp = h.enqueue_and_wait("poll-t", small_body(), Some(10)).await; + assert_eq!(resp.status(), 200); + let job_id = resp.headers().get("x-job-id").unwrap().to_str().unwrap().to_string(); + + let poll = h.get_job(&job_id).await; + assert_eq!(poll.status(), 200); + let body: serde_json::Value = poll.json().await.expect("poll json"); + assert_eq!(body["upstream"], true); +} diff --git a/binary/ai-gateway-service/tests/integration/metrics.rs b/binary/ai-gateway-service/tests/integration/metrics.rs new file mode 100644 index 00000000..f30ea520 --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/metrics.rs @@ -0,0 +1,35 @@ +use ai_gateway_service::HarnessConfig; + +use super::common::TestHarness; + +/// TC-MET-01:metrics 含 queue_depth / pel_size。 +#[tokio::test] +async fn tc_met_01_metrics_endpoint() { + let h = TestHarness::start().await; + let body = h.metrics().await; + assert!(body.contains("queue_depth")); + assert!(body.contains("pel_size")); + assert!(body.contains("enqueue_total")); +} + +/// TC-MET-02:rate_limited 带标签(触发后)。 +#[tokio::test] +async fn tc_met_02_labeled_rate_limited() { + let h = TestHarness::start_config(HarnessConfig { + rate_limit_burst: Some(1), + ..Default::default() + }) + .await; + h.exhaust_tenant("met-t", "wait", 2).await; + let body = h.metrics().await; + assert!(body.contains("rate_limited_total{policy=\"wait\",tenant=\"met-t\"}")); +} + +/// TC-DEP-02 smoke:healthz。 +#[tokio::test] +async fn tc_dep_02_healthz() { + let h = TestHarness::start().await; + let resp = h.client.get(format!("{}/healthz", h.base_url)).send().await.expect("healthz"); + assert_eq!(resp.status(), 200); + assert_eq!(resp.text().await.unwrap(), "ok"); +} diff --git a/binary/ai-gateway-service/tests/integration/mod.rs b/binary/ai-gateway-service/tests/integration/mod.rs new file mode 100644 index 00000000..ebcf7229 --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/mod.rs @@ -0,0 +1,9 @@ +mod common; + +mod ratelimit; +mod enqueue_queue; +mod enqueue_wait; +mod body_store; +mod worker_reliability; +mod admin_tenant_limit; +mod metrics; diff --git a/binary/ai-gateway-service/tests/integration/ratelimit.rs b/binary/ai-gateway-service/tests/integration/ratelimit.rs new file mode 100644 index 00000000..5a8ca47f --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/ratelimit.rs @@ -0,0 +1,70 @@ +use ai_gateway_service::HarnessConfig; + +use super::common::{parse_rate_limit, TestHarness}; + +/// TC-RL-01 / TC-RL-02:租户隔离与配额内 allowed。 +#[tokio::test] +async fn tc_rl_01_tenant_isolation_and_allowed() { + let h = TestHarness::start_config(HarnessConfig { + rate_limit_rps: Some(1), + rate_limit_burst: Some(1), + ..Default::default() + }) + .await; + + let a1 = parse_rate_limit(h.check_rate_limit("tenant-a", "abandon").await).await; + assert_eq!(a1["allowed"], true); + + let a2 = parse_rate_limit(h.check_rate_limit("tenant-a", "abandon").await).await; + assert_eq!(a2["allowed"], false); + + let b1 = parse_rate_limit(h.check_rate_limit("tenant-b", "abandon").await).await; + assert_eq!(b1["allowed"], true); +} + +/// TC-RL-03:超额时 metrics 计数。 +#[tokio::test] +async fn tc_rl_03_rate_limited_metrics() { + let h = TestHarness::start_config(HarnessConfig { + rate_limit_burst: Some(1), + rate_limit_rps: Some(1), + ..Default::default() + }) + .await; + + h.exhaust_tenant("metrics-tenant", "queue", 2).await; + let body = h.metrics().await; + assert!(body.contains("rate_limited_total")); + assert!(body.contains("policy=\"queue\"")); + assert!(body.contains("tenant=\"metrics-tenant\"")); +} + +/// TC-RL-04:burst 超发后第三次拒绝。 +#[tokio::test] +async fn tc_rl_04_burst_then_deny() { + let h = TestHarness::start_config(HarnessConfig { + rate_limit_burst: Some(2), + rate_limit_rps: Some(100), + ..Default::default() + }) + .await; + + for _ in 0..2 { + let v = parse_rate_limit(h.check_rate_limit("burst-t", "abandon").await).await; + assert_eq!(v["allowed"], true); + } + let v = parse_rate_limit(h.check_rate_limit("burst-t", "abandon").await).await; + assert_eq!(v["allowed"], false); + assert!(v["retry_after_ms"].as_i64().unwrap_or(0) > 0); +} + +/// TC-RL-07:Redis 限流 key 仅 tenant 维度。 +#[tokio::test] +async fn tc_rl_07_tenant_only_redis_keys() { + let h = TestHarness::start().await; + let _ = h.check_rate_limit("key-tenant", "abandon").await; + let keys = h.ratelimit_keys_for_tenant("key-tenant").await; + assert!(keys.iter().any(|k| k.ends_with(":tokens"))); + assert!(keys.iter().any(|k| k.ends_with(":ts"))); + assert!(!keys.iter().any(|k| k.contains("model") || k.contains("path"))); +} diff --git a/binary/ai-gateway-service/tests/integration/worker_reliability.rs b/binary/ai-gateway-service/tests/integration/worker_reliability.rs new file mode 100644 index 00000000..35bcc0e4 --- /dev/null +++ b/binary/ai-gateway-service/tests/integration/worker_reliability.rs @@ -0,0 +1,42 @@ +use ai_gateway_service::HarnessConfig; + +use super::common::{small_body, TestHarness}; + +/// TC-WK-01:多条 job 均可被 worker 完成(smoke)。 +#[tokio::test] +async fn tc_wk_01_multiple_jobs_complete() { + let h = TestHarness::start().await; + for i in 0..3 { + let resp = h.enqueue(&format!("wk-{i}"), small_body(), axum::http::HeaderMap::new()).await; + assert_eq!(resp.status(), 202); + } + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + let metrics = h.metrics().await; + assert!(metrics.contains("worker_completed_total")); +} + +/// TC-WK-03:回调失败进入 retry(不可达 URL)。 +#[tokio::test] +async fn tc_wk_03_callback_failure_retry_stream() { + let h = TestHarness::start_config(HarnessConfig { + require_https_callback: Some(false), + ..Default::default() + }) + .await; + + let resp = h + .client + .post(format!("{}/v1/queue/enqueue", h.base_url)) + .header("x-tenant-id", "retry-t") + .header("x-ratelimit-policy", "queue") + .header("x-callback-url", "http://127.0.0.1:1/unreachable") + .body(small_body()) + .send() + .await + .expect("enqueue"); + assert_eq!(resp.status(), 202); + + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + let depth = h.callback_retry_depth().await; + assert!(depth >= 0); +} diff --git a/binary/spacegate/Cargo.toml b/binary/spacegate/Cargo.toml index 30c6f799..911eadf6 100644 --- a/binary/spacegate/Cargo.toml +++ b/binary/spacegate/Cargo.toml @@ -28,6 +28,7 @@ axum = ["spacegate-shell/ext-axum"] static-openssl = ["openssl/vendored"] dylib = ["spacegate-shell/plugin-dylib"] plugin-all = ["spacegate-shell/plugin-all"] +wasm = ["spacegate-shell/plugin-wasm"] [dependencies] # envy = { } clap = { version = "4.5", features = ["derive", "env"] } @@ -35,8 +36,6 @@ serde = { workspace = true, features = ["derive"] } spacegate-shell = { workspace = true } openssl = { version = "0.10" } # tardis = { workspace = true, features = ["console-subscriber"] } -# tardis = { workspace = true } -tracing-subscriber = { workspace = true, features = ["env-filter"] } tokio = { version = "1", features = ["full"] } [dev-dependencies] diff --git a/binary/spacegate/src/main.rs b/binary/spacegate/src/main.rs index 12d177b4..32be3427 100644 --- a/binary/spacegate/src/main.rs +++ b/binary/spacegate/src/main.rs @@ -2,8 +2,6 @@ use clap::Parser; use spacegate_shell::BoxError; mod args; fn main() -> Result<(), BoxError> { - // TODO: more subscriber required - tracing_subscriber::fmt().with_env_filter(tracing_subscriber::EnvFilter::from_default_env()).init(); let args = args::Args::parse(); #[allow(unused_variables)] if let Some(plugins) = args.plugins { diff --git a/crates/config/Cargo.toml b/crates/config/Cargo.toml index 2bb92beb..11e08e90 100644 --- a/crates/config/Cargo.toml +++ b/crates/config/Cargo.toml @@ -47,6 +47,7 @@ tokio-rustls.workspace = true ipnet = { workspace = true, features = ["serde"] } bytes = { workspace = true } +base64 = { workspace = true } kube = { workspace = true, optional = true } k8s-openapi = { workspace = true, optional = true } diff --git a/crates/config/src/service.rs b/crates/config/src/service.rs index eae9eb5f..d942fc69 100644 --- a/crates/config/src/service.rs +++ b/crates/config/src/service.rs @@ -134,6 +134,7 @@ pub trait Retrieve: Sync + Send { gateways, plugins: PluginInstanceMap::from_config_vec(plugins), api_port: None, + observability: Default::default(), }) } } diff --git a/crates/config/src/service/fs/create.rs b/crates/config/src/service/fs/create.rs index 8c15b95c..38be51a3 100644 --- a/crates/config/src/service/fs/create.rs +++ b/crates/config/src/service/fs/create.rs @@ -15,14 +15,15 @@ where F: ConfigFormat + Send + Sync, { async fn create_plugin(&self, id: &spacegate_model::PluginInstanceId, value: serde_json::Value) -> Result<(), BoxError> { - self.modify_cached(|config| { - if config.plugins.get(id).is_some() { - return Err("plugin existed".into()); - } - config.plugins.insert(id.clone(), value); - Ok(()) - }) - .await + let path = self.plugin_path(id); + if path.exists() { + return Err("plugin existed".into()); + } + // 仅写入新插件文件,避免 rewrite 整个 /etc/spacegate + tokio::fs::create_dir_all(self.plugin_dir()).await?; + let b_spec = self.format.ser(&value)?; + tokio::fs::write(&path, &b_spec).await?; + Ok(()) } async fn create_config_item(&self, gateway_name: &str, item: ConfigItem) -> Result<(), BoxError> { self.modify_cached(|config| { diff --git a/crates/config/src/service/fs/mod.rs b/crates/config/src/service/fs/mod.rs index e839e811..6e7a362b 100644 --- a/crates/config/src/service/fs/mod.rs +++ b/crates/config/src/service/fs/mod.rs @@ -99,8 +99,17 @@ where pub async fn save_config(&self, config: Config) -> Result<(), BoxError> { // save config - let Config { plugins, gateways, api_port } = config; - let main_config_to_save: Config = Config { api_port, ..Default::default() }; + let Config { + plugins, + gateways, + api_port, + observability, + } = config; + let main_config_to_save: Config = Config { + api_port, + observability, + ..Default::default() + }; let b_main_config = self.format.ser(&main_config_to_save)?; tokio::fs::write(self.entrance_config_path(), &b_main_config).await?; if !plugins.is_empty() { diff --git a/crates/config/src/service/fs/model.rs b/crates/config/src/service/fs/model.rs index ddcffc7c..3267c332 100644 --- a/crates/config/src/service/fs/model.rs +++ b/crates/config/src/service/fs/model.rs @@ -5,7 +5,7 @@ use std::{ use serde::{Deserialize, Serialize}; use serde_json::Value; -use spacegate_model::{constants::DEFAULT_API_PORT, ConfigItem, PluginInstanceId, PluginInstanceMap, PluginInstanceName, SgGateway, SgHttpRoute}; +use spacegate_model::{constants::DEFAULT_API_PORT, ConfigItem, ObservabilityConfig, PluginInstanceId, PluginInstanceMap, PluginInstanceName, SgGateway, SgHttpRoute}; #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(untagged)] @@ -53,6 +53,7 @@ pub struct MainFileConfig

{ pub gateways: Vec>, pub plugins: PluginConfigs, pub api_port: u16, + pub observability: ObservabilityConfig, } impl

Default for MainFileConfig

{ @@ -61,6 +62,7 @@ impl

Default for MainFileConfig

{ gateways: Default::default(), plugins: Default::default(), api_port: DEFAULT_API_PORT, + observability: Default::default(), } } } @@ -140,6 +142,7 @@ impl MainFileConfig { gateways, plugins: self.plugins, api_port: self.api_port, + observability: self.observability, } } } @@ -191,6 +194,7 @@ impl MainFileConfig { gateways, plugins, api_port: Some(self.api_port), + observability: self.observability, } } } @@ -228,6 +232,7 @@ impl From for MainFileConfig { gateways, plugins, api_port: value.api_port.unwrap_or(DEFAULT_API_PORT), + observability: value.observability, } } } diff --git a/crates/config/src/service/fs/update.rs b/crates/config/src/service/fs/update.rs index d04bf2c2..e7d34a80 100644 --- a/crates/config/src/service/fs/update.rs +++ b/crates/config/src/service/fs/update.rs @@ -12,15 +12,12 @@ where F: ConfigFormat + Send + Sync, { async fn update_plugin(&self, id: &spacegate_model::PluginInstanceId, value: serde_json::Value) -> Result<(), BoxError> { - self.modify_cached(|config| { - if let Some(prev_spec) = config.plugins.get_mut(id) { - *prev_spec = value; - Ok(()) - } else { - Err("plugin not exists".into()) - } - }) - .await + // 仅更新单个插件 JSON,避免 modify_cached 清空整棵配置树(Docker 共享挂载会 EBUSY/EROFS) + tokio::fs::create_dir_all(self.plugin_dir()).await?; + let path = self.plugin_path(id); + let b_spec = self.format.ser(&value)?; + tokio::fs::write(&path, &b_spec).await?; + Ok(()) } async fn update_config_item_gateway(&self, gateway_name: &str, gateway: SgGateway) -> Result<(), BoxError> { self.modify_cached(|config| { diff --git a/crates/config/src/service/k8s/convert.rs b/crates/config/src/service/k8s/convert.rs index 875d6b79..a88ddda6 100644 --- a/crates/config/src/service/k8s/convert.rs +++ b/crates/config/src/service/k8s/convert.rs @@ -2,6 +2,7 @@ use spacegate_model::ext::k8s::crd::sg_filter::K8sSgFilterSpecTargetRef; pub mod filter_k8s_conv; pub mod gateway_k8s_conv; +pub mod higress_wasm_plugin_conv; pub mod route_k8s_conv; pub(crate) trait ToTarget { diff --git a/crates/config/src/service/k8s/convert/gateway_k8s_conv.rs b/crates/config/src/service/k8s/convert/gateway_k8s_conv.rs index f85188db..32019bd0 100644 --- a/crates/config/src/service/k8s/convert/gateway_k8s_conv.rs +++ b/crates/config/src/service/k8s/convert/gateway_k8s_conv.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeMap, hash::Hasher}; use k8s_gateway_api::{Gateway, GatewaySpec, GatewayTlsConfig, Listener, SecretObjectReference}; use k8s_openapi::{api::core::v1::Secret, ByteString}; use kube::{api::ObjectMeta, ResourceExt}; -use spacegate_model::{ext::k8s::helper_struct::SgTargetKind, PluginInstanceId}; +use spacegate_model::{ext::k8s::helper_struct::SgTargetKind, ObservabilityConfig, PluginInstanceId}; use crate::{constants, ext::k8s::crd::sg_filter::K8sSgFilterSpecTargetRef, service::k8s::K8s, SgGateway, SgParameters}; @@ -81,7 +81,7 @@ impl SgGatewayConv for SgGateway { } } -pub(crate) trait SgParametersConv { +pub trait SgParametersConv { fn from_kube_gateway(gateway: &Gateway) -> Self; fn into_kube_gateway(self) -> BTreeMap; } @@ -107,18 +107,71 @@ impl SgParametersConv for SgParameters { if let Some(enable_x_request_id) = self.enable_x_request_id { ann.insert(crate::constants::GATEWAY_ANNOTATION_ENABLE_X_REQUEST_ID.to_string(), enable_x_request_id.to_string()); } + if self.observability.enabled { + ann.insert(crate::constants::GATEWAY_ANNOTATION_OTEL_ENABLED.to_string(), self.observability.enabled.to_string()); + ann.insert(crate::constants::GATEWAY_ANNOTATION_OTEL_SERVICE_NAME.to_string(), self.observability.service_name); + ann.insert(crate::constants::GATEWAY_ANNOTATION_OTEL_ENDPOINT.to_string(), self.observability.otlp_endpoint); + ann.insert(crate::constants::GATEWAY_ANNOTATION_OTEL_PROTOCOL.to_string(), self.observability.protocol.to_string()); + ann.insert( + crate::constants::GATEWAY_ANNOTATION_OTEL_TRACES_ENABLED.to_string(), + self.observability.traces.enabled.to_string(), + ); + ann.insert( + crate::constants::GATEWAY_ANNOTATION_OTEL_TRACES_SAMPLE_RATIO.to_string(), + self.observability.traces.sample_ratio.to_string(), + ); + ann.insert( + crate::constants::GATEWAY_ANNOTATION_OTEL_METRICS_ENABLED.to_string(), + self.observability.metrics.enabled.to_string(), + ); + ann.insert( + crate::constants::GATEWAY_ANNOTATION_OTEL_METRICS_EXPORT_INTERVAL_MS.to_string(), + self.observability.metrics.export_interval_ms.to_string(), + ); + ann.insert( + crate::constants::GATEWAY_ANNOTATION_OTEL_LOGS_ENABLED.to_string(), + self.observability.logs.enabled.to_string(), + ); + ann.insert(crate::constants::GATEWAY_ANNOTATION_OTEL_LOGS_LEVEL.to_string(), self.observability.logs.level); + } ann } fn from_kube_gateway(gateway: &Gateway) -> Self { let gateway_annotations = gateway.metadata.annotations.clone(); if let Some(gateway_annotations) = gateway_annotations { + let mut observability = ObservabilityConfig { + enabled: gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_OTEL_ENABLED).and_then(|v| v.parse::().ok()).unwrap_or_default(), + service_name: gateway_annotations + .get(crate::constants::GATEWAY_ANNOTATION_OTEL_SERVICE_NAME) + .cloned() + .unwrap_or_else(|| ObservabilityConfig::default().service_name), + otlp_endpoint: gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_OTEL_ENDPOINT).cloned().unwrap_or_else(|| ObservabilityConfig::default().otlp_endpoint), + protocol: gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_OTEL_PROTOCOL).and_then(|v| v.parse().ok()).unwrap_or_default(), + ..Default::default() + }; + observability.traces.enabled = + gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_OTEL_TRACES_ENABLED).and_then(|v| v.parse::().ok()).unwrap_or_default(); + observability.traces.sample_ratio = gateway_annotations + .get(crate::constants::GATEWAY_ANNOTATION_OTEL_TRACES_SAMPLE_RATIO) + .and_then(|v| v.parse::().ok()) + .unwrap_or_else(|| ObservabilityConfig::default().traces.sample_ratio); + observability.metrics.enabled = + gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_OTEL_METRICS_ENABLED).and_then(|v| v.parse::().ok()).unwrap_or_default(); + observability.metrics.export_interval_ms = gateway_annotations + .get(crate::constants::GATEWAY_ANNOTATION_OTEL_METRICS_EXPORT_INTERVAL_MS) + .and_then(|v| v.parse::().ok()) + .unwrap_or_else(|| ObservabilityConfig::default().metrics.export_interval_ms); + observability.logs.enabled = gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_OTEL_LOGS_ENABLED).and_then(|v| v.parse::().ok()).unwrap_or_default(); + observability.logs.level = + gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_OTEL_LOGS_LEVEL).cloned().unwrap_or_else(|| ObservabilityConfig::default().logs.level); SgParameters { redis_url: gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_REDIS_URL).map(|v| v.to_string()), log_level: gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_LOG_LEVEL).map(|v| v.to_string()), lang: gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_LANGUAGE).map(|v| v.to_string()), ignore_tls_verification: gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_IGNORE_TLS_VERIFICATION).and_then(|v| v.parse::().ok()), enable_x_request_id: gateway_annotations.get(crate::constants::GATEWAY_ANNOTATION_ENABLE_X_REQUEST_ID).and_then(|v| v.parse::().ok()), + observability, } } else { SgParameters { @@ -127,6 +180,7 @@ impl SgParametersConv for SgParameters { lang: None, ignore_tls_verification: None, enable_x_request_id: None, + observability: Default::default(), } } } diff --git a/crates/config/src/service/k8s/convert/higress_wasm_plugin_conv.rs b/crates/config/src/service/k8s/convert/higress_wasm_plugin_conv.rs new file mode 100644 index 00000000..06f32552 --- /dev/null +++ b/crates/config/src/service/k8s/convert/higress_wasm_plugin_conv.rs @@ -0,0 +1,407 @@ +use kube::ResourceExt; +use serde_json::{json, Map, Value}; +use spacegate_model::{ + ext::k8s::crd::wasm_plugin::{HigressWasmPluginMatchRule, WasmPlugin}, + BackendHost, PluginConfig, PluginInstanceId, PluginInstanceName, SgBackendRef, +}; + +const WASM_CODE: &str = "wasm"; + +pub(crate) trait HigressWasmPluginConv { + fn to_spacegate_plugin_id(&self) -> PluginInstanceId; + fn to_spacegate_rule_plugin_id(&self, rule_index: usize) -> PluginInstanceId; + fn to_spacegate_plugin_configs(&self) -> Vec; + fn to_spacegate_plugin_configs_with_oci_auth(&self, oci_auth: Option) -> Vec; + fn to_spacegate_plugin_config_by_id(&self, id: &PluginInstanceId) -> Option; + fn to_spacegate_plugin_config_by_id_with_oci_auth(&self, id: &PluginInstanceId, oci_auth: Option) -> Option; + fn gateway_plugin_id(&self) -> Option; + fn route_plugin_ids(&self, route_name: &str, hostnames: Option<&[String]>) -> Vec; + fn backend_plugin_ids

(&self, backend: &SgBackendRef

) -> Vec; + fn priority(&self) -> i32; + fn phase_rank(&self) -> i32; + fn validate_for_spacegate(&self) -> Result<(), String>; + fn digest(&self) -> Option; + fn oci_registry(&self) -> Option; +} + +impl HigressWasmPluginConv for WasmPlugin { + fn to_spacegate_plugin_id(&self) -> PluginInstanceId { + PluginInstanceId { + code: WASM_CODE.into(), + name: PluginInstanceName::named(format!("higress-{}", self.name_any())), + } + } + + fn to_spacegate_rule_plugin_id(&self, rule_index: usize) -> PluginInstanceId { + PluginInstanceId { + code: WASM_CODE.into(), + name: PluginInstanceName::named(format!("higress-{}-rule-{rule_index}", self.name_any())), + } + } + + fn to_spacegate_plugin_configs(&self) -> Vec { + self.to_spacegate_plugin_configs_with_oci_auth(None) + } + + fn to_spacegate_plugin_configs_with_oci_auth(&self, oci_auth: Option) -> Vec { + let mut configs = Vec::new(); + if !self.spec.default_config_disable { + configs.push(build_plugin_config( + self, + self.to_spacegate_plugin_id(), + self.spec.default_config.clone(), + "default", + oci_auth.clone(), + )); + } + configs.extend(self.spec.match_rules.iter().enumerate().filter(|(_, rule)| !rule.config_disable).map(|(idx, rule)| { + build_plugin_config( + self, + self.to_spacegate_rule_plugin_id(idx), + build_higress_rule_config(rule), + &format!("rule-{idx}"), + oci_auth.clone(), + ) + })); + configs + } + + fn to_spacegate_plugin_config_by_id(&self, id: &PluginInstanceId) -> Option { + self.to_spacegate_plugin_config_by_id_with_oci_auth(id, None) + } + + fn to_spacegate_plugin_config_by_id_with_oci_auth(&self, id: &PluginInstanceId, oci_auth: Option) -> Option { + self.to_spacegate_plugin_configs_with_oci_auth(oci_auth).into_iter().find(|cfg| &cfg.id == id) + } + + fn gateway_plugin_id(&self) -> Option { + (!self.spec.default_config_disable).then(|| self.to_spacegate_plugin_id()) + } + + fn route_plugin_ids(&self, route_name: &str, hostnames: Option<&[String]>) -> Vec { + self.spec + .match_rules + .iter() + .enumerate() + .filter(|(_, rule)| !rule.config_disable && rule_matches_route(rule, route_name, hostnames)) + .map(|(idx, _)| self.to_spacegate_rule_plugin_id(idx)) + .collect() + } + + fn backend_plugin_ids

(&self, backend: &SgBackendRef

) -> Vec { + self.spec + .match_rules + .iter() + .enumerate() + .filter(|(_, rule)| !rule.config_disable && rule_matches_backend(rule, backend)) + .map(|(idx, _)| self.to_spacegate_rule_plugin_id(idx)) + .collect() + } + + fn priority(&self) -> i32 { + self.spec.priority.unwrap_or(0) + } + + fn phase_rank(&self) -> i32 { + phase_rank(self.spec.phase.as_deref()) + } + + fn validate_for_spacegate(&self) -> Result<(), String> { + let url = self.spec.url.trim(); + if url.is_empty() { + return Err("spec.url is empty".to_string()); + } + if is_oci_url(url) && parse_oci_registry(url).is_none() { + return Err("spec.url must include OCI registry and repository".to_string()); + } + Ok(()) + } + + fn digest(&self) -> Option { + self.spec.sha256.clone() + } + + fn oci_registry(&self) -> Option { + parse_oci_registry(&self.spec.url) + } +} + +fn build_plugin_config(plugin: &WasmPlugin, id: PluginInstanceId, plugin_config: Value, instance_suffix: &str, oci_auth: Option) -> PluginConfig { + let namespace = plugin.namespace().unwrap_or_else(|| "default".to_string()); + let resource_version = plugin.resource_version().unwrap_or_else(|| "unknown".to_string()); + let plugin_name = plugin.spec.plugin_name.clone().unwrap_or_else(|| plugin.name_any()); + let image_pull_always = plugin.spec.image_pull_policy.as_deref().map(|v| v.eq_ignore_ascii_case("always")).unwrap_or(false); + + let mut spec = json!({ + "url": plugin.spec.url, + "plugin_config": plugin_config, + "plugin_name": plugin_name, + "plugin_root_id": format!("higress-{}-root-{instance_suffix}", plugin.name_any()), + "plugin_vm_id": format!("higress-{}-{}-{instance_suffix}", namespace, plugin.name_any()), + "module_cache_key": format!("higress-wasmplugin:{namespace}:{}:{resource_version}:{instance_suffix}", plugin.name_any()), + "use_cache": !image_pull_always, + }); + + if let Some(sha256) = plugin.spec.sha256.as_deref().filter(|v| !v.trim().is_empty()) { + spec["sha256"] = Value::String(sha256.to_string()); + } + if let Some(fail_strategy) = plugin.spec.fail_strategy.as_deref().and_then(normalize_fail_strategy) { + spec["fail_strategy"] = Value::String(fail_strategy.to_string()); + } + if let Some(oci_auth) = oci_auth { + spec["oci_auth"] = oci_auth; + } + + PluginConfig { id, spec } +} + +pub(crate) fn sort_higress_wasm_plugins(plugins: &mut [WasmPlugin]) { + plugins.sort_by(|a, b| a.phase_rank().cmp(&b.phase_rank()).then_with(|| b.priority().cmp(&a.priority())).then_with(|| a.name_any().cmp(&b.name_any()))); +} + +fn normalize_fail_strategy(value: &str) -> Option<&'static str> { + match value.trim().to_ascii_lowercase().replace('-', "_").as_str() { + "fail_open" | "failopen" => Some("fail_open"), + "fail_close" | "failclose" => Some("fail_close"), + _ => None, + } +} + +fn build_higress_rule_config(rule: &HigressWasmPluginMatchRule) -> Value { + let mut config = value_to_object(rule.config.clone()); + if !rule.ingress.is_empty() { + config.insert("_match_route_".to_string(), strings_value(rule.ingress.clone())); + } + if !rule.domain.is_empty() { + config.insert("_match_domain_".to_string(), strings_value(rule.domain.clone())); + } + if !rule.service.is_empty() { + config.insert("_match_service_".to_string(), strings_value(rule.service.clone())); + } + if rule.config_disable { + config.insert("_config_disable_".to_string(), Value::Bool(true)); + } + Value::Object(config) +} + +fn value_to_object(value: Value) -> Map { + match value { + Value::Object(map) => map, + Value::Null => Map::new(), + other => { + let mut map = Map::new(); + map.insert("_config_".to_string(), other); + map + } + } +} + +fn strings_value(values: Vec) -> Value { + Value::Array(values.into_iter().map(Value::String).collect()) +} + +fn phase_rank(phase: Option<&str>) -> i32 { + match phase.unwrap_or_default().trim().to_ascii_uppercase().as_str() { + "AUTHN" => 10, + "AUTHZ" => 20, + "STATS" => 90, + _ => 50, + } +} + +fn rule_matches_route(rule: &HigressWasmPluginMatchRule, route_name: &str, hostnames: Option<&[String]>) -> bool { + let route_match = !rule.ingress.is_empty() && rule.ingress.iter().any(|name| name.eq_ignore_ascii_case(route_name)); + let domain_match = + !rule.domain.is_empty() && hostnames.map(|hostnames| hostnames.iter().any(|hostname| rule.domain.iter().any(|domain| domain_matches(domain, hostname)))).unwrap_or(false); + let rule_has_no_explicit_target = rule.ingress.is_empty() && rule.domain.is_empty() && rule.service.is_empty(); + route_match || domain_match || rule_has_no_explicit_target +} + +fn rule_matches_backend

(rule: &HigressWasmPluginMatchRule, backend: &SgBackendRef

) -> bool { + !rule.service.is_empty() && rule.service.iter().any(|service| backend_matches_service(backend, service)) +} + +fn domain_matches(pattern: &str, hostname: &str) -> bool { + let pattern = pattern.trim().trim_end_matches('.'); + let hostname = hostname.trim().trim_end_matches('.'); + if pattern == "*" { + return true; + } + if let Some(suffix) = pattern.strip_prefix("*.") { + return hostname.eq_ignore_ascii_case(suffix) || hostname.to_ascii_lowercase().ends_with(&format!(".{}", suffix.to_ascii_lowercase())); + } + pattern.eq_ignore_ascii_case(hostname) +} + +fn backend_matches_service

(backend: &SgBackendRef

, service: &str) -> bool { + let service = service.trim(); + match &backend.host { + BackendHost::K8sService(data) => { + data.name.eq_ignore_ascii_case(service) || data.namespace.as_ref().map(|ns| format!("{}.{}", data.name, ns).eq_ignore_ascii_case(service)).unwrap_or(false) + } + _ => backend.get_host().eq_ignore_ascii_case(service), + } +} + +fn is_oci_url(url: &str) -> bool { + let lower = url.to_ascii_lowercase(); + lower.starts_with("oci://") || lower.starts_with("docker://") || lower.starts_with("image://") || lower.starts_with("oci+http://") +} + +fn parse_oci_registry(url: &str) -> Option { + let trim = url.trim(); + let rest = trim.strip_prefix("oci://").or_else(|| trim.strip_prefix("docker://")).or_else(|| trim.strip_prefix("image://")).or_else(|| trim.strip_prefix("oci+http://"))?; + let (registry, repository) = rest.split_once('/')?; + (!registry.trim().is_empty() && !repository.trim().is_empty()).then(|| registry.to_string()) +} + +#[cfg(test)] +mod tests { + use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; + use serde_json::json; + use spacegate_model::{ + ext::k8s::crd::wasm_plugin::{HigressWasmPluginSpec, WasmPlugin}, + BackendHost, K8sServiceData, SgBackendRef, + }; + + use super::*; + + #[test] + fn converts_higress_wasmplugin_to_spacegate_wasm_plugin_config() { + let plugin = WasmPlugin { + metadata: ObjectMeta { + name: Some("authn".to_string()), + namespace: Some("gw".to_string()), + resource_version: Some("42".to_string()), + ..Default::default() + }, + spec: HigressWasmPluginSpec { + url: "https://example.com/authn.wasm".to_string(), + plugin_name: Some("authn-plugin".to_string()), + sha256: Some("sha256:abc".to_string()), + phase: Some("AUTHN".to_string()), + priority: Some(100), + image_pull_policy: Some("IfNotPresent".to_string()), + image_pull_secret: None, + default_config_disable: false, + default_config: json!({"issuer": "spacegate"}), + match_rules: vec![HigressWasmPluginMatchRule { + ingress: vec!["api-route".to_string()], + domain: vec!["api.example.com".to_string()], + service: vec![], + config_disable: false, + config: json!({"issuer": "route"}), + }], + fail_strategy: Some("FAIL_CLOSE".to_string()), + }, + status: None, + }; + + let cfg = plugin.to_spacegate_plugin_configs().into_iter().next().expect("default config"); + assert_eq!(cfg.id.to_string(), "wasm-n-higress-authn"); + assert_eq!(cfg.spec["url"], "https://example.com/authn.wasm"); + assert_eq!(cfg.spec["sha256"], "sha256:abc"); + assert_eq!(cfg.spec["fail_strategy"], "fail_close"); + assert_eq!(cfg.spec["module_cache_key"], "higress-wasmplugin:gw:authn:42:default"); + assert_eq!(cfg.spec["plugin_config"]["issuer"], "spacegate"); + + let rule_cfg = plugin.to_spacegate_plugin_config_by_id(&plugin.to_spacegate_rule_plugin_id(0)).expect("rule config"); + assert_eq!(rule_cfg.id.to_string(), "wasm-n-higress-authn-rule-0"); + assert_eq!(rule_cfg.spec["plugin_config"]["issuer"], "route"); + assert_eq!(rule_cfg.spec["plugin_config"]["_match_route_"][0], "api-route"); + assert_eq!(rule_cfg.spec["plugin_config"]["_match_domain_"][0], "api.example.com"); + } + + #[test] + fn rule_ids_match_routes_domains_and_services() { + let plugin = WasmPlugin { + metadata: ObjectMeta { + name: Some("ratelimit".to_string()), + namespace: Some("gw".to_string()), + resource_version: Some("7".to_string()), + ..Default::default() + }, + spec: HigressWasmPluginSpec { + url: "file:///tmp/ratelimit.wasm".to_string(), + plugin_name: None, + sha256: None, + phase: Some("AUTHZ".to_string()), + priority: Some(10), + image_pull_policy: None, + image_pull_secret: None, + default_config_disable: true, + default_config: serde_json::Value::Null, + match_rules: vec![ + HigressWasmPluginMatchRule { + ingress: vec!["api-route".to_string()], + domain: vec![], + service: vec![], + config_disable: false, + config: json!({"limit": 10}), + }, + HigressWasmPluginMatchRule { + ingress: vec![], + domain: vec!["*.example.com".to_string()], + service: vec![], + config_disable: false, + config: json!({"limit": 20}), + }, + HigressWasmPluginMatchRule { + ingress: vec![], + domain: vec![], + service: vec!["backend.default".to_string()], + config_disable: false, + config: json!({"limit": 30}), + }, + ], + fail_strategy: None, + }, + status: None, + }; + + let route_ids = plugin.route_plugin_ids("api-route", Some(&["shop.example.com".to_string()])); + assert_eq!( + route_ids.iter().map(ToString::to_string).collect::>(), + vec!["wasm-n-higress-ratelimit-rule-0", "wasm-n-higress-ratelimit-rule-1"] + ); + + let backend = SgBackendRef:: { + host: BackendHost::K8sService(K8sServiceData { + name: "backend".to_string(), + namespace: Some("default".to_string()), + }), + ..Default::default() + }; + let backend_ids = plugin.backend_plugin_ids(&backend); + assert_eq!(backend_ids.iter().map(ToString::to_string).collect::>(), vec!["wasm-n-higress-ratelimit-rule-2"]); + } + + #[test] + fn validates_oci_urls_for_status() { + let plugin = WasmPlugin { + metadata: ObjectMeta { + name: Some("oci-plugin".to_string()), + namespace: Some("gw".to_string()), + resource_version: Some("1".to_string()), + ..Default::default() + }, + spec: HigressWasmPluginSpec { + url: "oci://registry.example.com/plugin:v1".to_string(), + plugin_name: None, + sha256: None, + phase: None, + priority: None, + image_pull_policy: Some("Always".to_string()), + image_pull_secret: Some("pull-secret".to_string()), + default_config_disable: false, + default_config: serde_json::Value::Null, + match_rules: vec![], + fail_strategy: None, + }, + status: None, + }; + + plugin.validate_for_spacegate().expect("OCI should be accepted"); + assert_eq!(plugin.oci_registry().as_deref(), Some("registry.example.com")); + } +} diff --git a/crates/config/src/service/k8s/listen.rs b/crates/config/src/service/k8s/listen.rs index 9cb21d72..569e9ba5 100644 --- a/crates/config/src/service/k8s/listen.rs +++ b/crates/config/src/service/k8s/listen.rs @@ -6,8 +6,9 @@ use std::{ use futures_util::{pin_mut, TryStreamExt}; use k8s_gateway_api::{Gateway, HttpRoute}; +use k8s_openapi::api::core::v1::Secret; use kube::{ - api::ObjectMeta, + api::{ObjectMeta, PostParams}, runtime::{watcher, WatchStreamExt}, Api, Resource, ResourceExt, }; @@ -16,12 +17,19 @@ use spacegate_model::{ ext::k8s::crd::{ http_spaceroute::HttpSpaceroute, sg_filter::{K8sSgFilterSpecTargetRef, SgFilter}, + wasm_plugin::{HigressWasmPluginStatus, WasmPlugin}, }, BoxResult, Config, PluginInstanceId, }; use tracing::debug; -use crate::service::{k8s::convert::filter_k8s_conv::PluginIdConv, ConfigEventType, ConfigType, CreateListener, Listen, ListenEvent, Retrieve as _}; +use crate::service::{ + k8s::{ + convert::{filter_k8s_conv::PluginIdConv, higress_wasm_plugin_conv::HigressWasmPluginConv as _}, + retrieve::oci_auth_from_secret, + }, + ConfigEventType, ConfigType, CreateListener, Listen, ListenEvent, Retrieve as _, +}; use super::K8s; @@ -31,6 +39,43 @@ pub struct K8sListener { impl K8sListener {} impl K8s { + async fn reconcile_wasm_plugin_status(api: &Api, secret_api: &Api, plugin: &WasmPlugin) { + let (phase, message) = match Self::validate_wasm_plugin(api, secret_api, plugin).await { + Ok(()) => ("Accepted".to_string(), "WasmPlugin accepted by Spacegate".to_string()), + Err(e) => ("Unsupported".to_string(), e), + }; + let status = HigressWasmPluginStatus { + observed_generation: plugin.meta().generation, + phase: Some(phase), + digest: plugin.digest(), + message: Some(message), + }; + let mut update = plugin.clone(); + update.status = Some(status); + if let Err(e) = api.replace_status(&plugin.name_any(), &PostParams::default(), serde_json::to_vec(&update).unwrap_or_default()).await { + tracing::warn!(name = %plugin.name_any(), error = %e, "failed to update WasmPlugin status"); + } + } + + async fn validate_wasm_plugin(_api: &Api, secret_api: &Api, plugin: &WasmPlugin) -> Result<(), String> { + plugin.validate_for_spacegate()?; + let Some(registry) = plugin.oci_registry() else { + return Ok(()); + }; + let Some(secret_name) = plugin.spec.image_pull_secret.as_deref().map(str::trim).filter(|v| !v.is_empty()) else { + return Ok(()); + }; + let secret = secret_api + .get_opt(secret_name) + .await + .map_err(|e| format!("read imagePullSecret {secret_name}: {e}"))? + .ok_or_else(|| format!("imagePullSecret {secret_name} not found"))?; + if oci_auth_from_secret(&secret, ®istry).is_none() { + return Err(format!("imagePullSecret {secret_name} does not contain credentials for {registry}")); + } + Ok(()) + } + async fn process_http_spaceroute_event( move_evt_tx: &tokio::sync::mpsc::UnboundedSender<(ConfigType, ConfigEventType)>, move_http_route_names: &[String], @@ -111,6 +156,8 @@ impl CreateListener for K8s { let http_route_api: Api = self.get_namespace_api(); let http_spaceroute_api: Api = self.get_namespace_api(); let sg_filter_api: Api = self.get_namespace_api(); + let wasm_plugin_api: Api = self.get_namespace_api(); + let secret_api: Api = self.get_namespace_api(); let move_gateway_names = config.gateways.clone().into_values().map(|item| item.gateway.name).collect::>(); let move_evt_tx = evt_tx.clone(); @@ -326,6 +373,56 @@ impl CreateListener for K8s { } }); + let move_evt_tx = evt_tx.clone(); + let wasm_plugin_status_api = wasm_plugin_api.clone(); + let wasm_plugin_secret_api = secret_api.clone(); + // watch Higress-compatible WasmPlugin. A WasmPlugin can add/remove gateway-level + // plugins, so the simplest correct reconciliation is a global reload. + tokio::task::spawn(async move { + let mut uid_version_map = HashMap::new(); + let ew = watcher::watcher(wasm_plugin_api, watcher::Config::default()); + pin_mut!(ew); + while let Some(event) = ew.try_next().await.unwrap_or_default() { + match event { + watcher::Event::Applied(plugin) => { + Self::reconcile_wasm_plugin_status(&wasm_plugin_status_api, &wasm_plugin_secret_api, &plugin).await; + if uid_version_map.get(&plugin.uid()) == Some(plugin.meta()) { + continue; + } + uid_version_map.insert(plugin.uid(), plugin.meta().clone()); + move_evt_tx + .send(( + ConfigType::Plugin { + id: plugin.to_spacegate_plugin_id(), + }, + ConfigEventType::Update, + )) + .expect("send event error"); + move_evt_tx.send((ConfigType::Global, ConfigEventType::Update)).expect("send event error"); + } + watcher::Event::Deleted(plugin) => { + uid_version_map.remove(&plugin.uid()); + move_evt_tx + .send(( + ConfigType::Plugin { + id: plugin.to_spacegate_plugin_id(), + }, + ConfigEventType::Delete, + )) + .expect("send event error"); + move_evt_tx.send((ConfigType::Global, ConfigEventType::Update)).expect("send event error"); + } + watcher::Event::Restarted(plugins) => { + for plugin in &plugins { + Self::reconcile_wasm_plugin_status(&wasm_plugin_status_api, &wasm_plugin_secret_api, plugin).await; + } + uid_version_map = plugins.into_iter().map(|plugin| (plugin.uid(), plugin.meta().clone())).collect(); + move_evt_tx.send((ConfigType::Global, ConfigEventType::Update)).expect("send event error"); + } + } + } + }); + let listener = K8sListener { rx: evt_rx }; Ok((config, listener)) diff --git a/crates/config/src/service/k8s/retrieve.rs b/crates/config/src/service/k8s/retrieve.rs index e8e75848..9ce0b939 100644 --- a/crates/config/src/service/k8s/retrieve.rs +++ b/crates/config/src/service/k8s/retrieve.rs @@ -1,14 +1,17 @@ +use base64::{engine::general_purpose, Engine as _}; use futures_util::future::join_all; use gateway::{SgListener, SgParameters, SgProtocolConfig, SgTlsConfig}; use http_route::SgHttpRouteRule; use k8s_gateway_api::{Gateway, HttpRoute, Listener}; use k8s_openapi::api::core::v1::Secret; use kube::{api::ListParams, Api, ResourceExt}; +use serde_json::{json, Value}; use spacegate_model::{ ext::k8s::{ crd::{ http_spaceroute::HttpSpaceroute, sg_filter::{K8sSgFilterSpecTargetRef, SgFilter}, + wasm_plugin::WasmPlugin, }, helper_struct::SgTargetKind, }, @@ -23,7 +26,12 @@ use crate::{ }; use super::{ - convert::{filter_k8s_conv::PluginConfigConv, gateway_k8s_conv::SgParametersConv as _, route_k8s_conv::SgHttpRouteRuleConv as _}, + convert::{ + filter_k8s_conv::PluginConfigConv, + gateway_k8s_conv::SgParametersConv as _, + higress_wasm_plugin_conv::{sort_higress_wasm_plugins, HigressWasmPluginConv as _}, + route_k8s_conv::SgHttpRouteRuleConv as _, + }, K8s, }; @@ -137,14 +145,42 @@ impl Retrieve for K8s { async fn retrieve_all_plugins(&self) -> Result, BoxError> { let filter_api: Api = self.get_namespace_api(); + let wasm_plugin_api: Api = self.get_namespace_api(); - let result = filter_api.list(&ListParams::default()).await?.into_iter().filter_map(PluginConfig::from_first_filter_obj).collect(); + let mut result = filter_api.list(&ListParams::default()).await?.into_iter().filter_map(PluginConfig::from_first_filter_obj).collect::>(); + let mut wasm_plugins = wasm_plugin_api.list(&ListParams::default()).await?.items; + sort_higress_wasm_plugins(&mut wasm_plugins); + for plugin in wasm_plugins { + let oci_auth = self.resolve_higress_wasm_oci_auth(&plugin).await?; + if oci_auth.is_some() { + result.extend(plugin.to_spacegate_plugin_configs_with_oci_auth(oci_auth)); + } else { + result.extend(plugin.to_spacegate_plugin_configs()); + } + } Ok(result) } async fn retrieve_plugin(&self, id: &spacegate_model::PluginInstanceId) -> Result, BoxError> { let filter_api: Api = self.get_namespace_api(); + let wasm_plugin_api: Api = self.get_namespace_api(); + if id.code == "wasm" { + if let spacegate_model::PluginInstanceName::Named { name } = &id.name { + if let Some(wasm_name) = name.strip_prefix("higress-") { + let wasm_name = wasm_name.rsplit_once("-rule-").map(|(base, _)| base).unwrap_or(wasm_name); + if let Some(plugin) = wasm_plugin_api.get_opt(wasm_name).await? { + let oci_auth = self.resolve_higress_wasm_oci_auth(&plugin).await?; + return Ok(if oci_auth.is_some() { + plugin.to_spacegate_plugin_config_by_id_with_oci_auth(id, oci_auth) + } else { + plugin.to_spacegate_plugin_config_by_id(id) + }); + } + return Ok(None); + } + } + } match &id.name { spacegate_model::PluginInstanceName::Anon { uid: _ } => Ok(None), spacegate_model::PluginInstanceName::Named { name } => { @@ -174,13 +210,14 @@ impl K8s { } async fn kube_gateway_2_sg_gateway(&self, gateway_obj: Gateway) -> BoxResult { let gateway_name = gateway_obj.name_any(); - let plugins = self + let mut plugins = self .retrieve_config_item_filters(K8sSgFilterSpecTargetRef { kind: SgTargetKind::Gateway.into(), name: gateway_name.clone(), namespace: gateway_obj.namespace(), }) .await?; + plugins.extend(self.retrieve_higress_gateway_plugins(gateway_obj.namespace()).await?); let result = SgGateway { name: gateway_name, parameters: SgParameters::from_kube_gateway(&gateway_obj), @@ -192,6 +229,7 @@ impl K8s { async fn kube_httpspaceroute_2_sg_route(&self, httpspace_route: HttpSpaceroute) -> BoxResult { let route_name = httpspace_route.name_any(); + let namespace = httpspace_route.namespace(); let kind = if let Some(kind) = httpspace_route.annotations().get(constants::RAW_HTTP_ROUTE_KIND) { kind.clone() } else { @@ -205,7 +243,7 @@ impl K8s { namespace: httpspace_route.namespace(), }) .await?; - Ok(SgHttpRoute { + let mut route = SgHttpRoute { hostnames: httpspace_route.spec.hostnames.clone(), plugins, rules: httpspace_route @@ -216,7 +254,9 @@ impl K8s { .unwrap_or_default(), priority, route_name, - }) + }; + self.apply_higress_wasm_route_plugins(&mut route, namespace).await?; + Ok(route) } async fn kube_httproute_2_sg_route(&self, http_route: HttpRoute) -> BoxResult { @@ -243,6 +283,7 @@ impl K8s { }) .flat_map(|filter_obj| PluginConfig::from_first_filter_obj(filter_obj).map(|f| f.into())) .collect(); + let plugin_ids = plugin_ids; if !plugin_ids.is_empty() { let mut filter_vec = String::new(); @@ -324,4 +365,107 @@ impl K8s { .into_iter() .collect() } + + async fn retrieve_higress_gateway_plugins(&self, namespace: Option) -> BoxResult> { + let namespace = namespace.unwrap_or_else(|| self.namespace.to_string()); + let wasm_plugin_api: Api = self.get_specify_namespace_api(&namespace); + let mut wasm_plugins = wasm_plugin_api.list(&ListParams::default()).await?.items; + sort_higress_wasm_plugins(&mut wasm_plugins); + Ok(wasm_plugins.into_iter().filter_map(|p| p.gateway_plugin_id()).collect()) + } + + async fn apply_higress_wasm_route_plugins(&self, route: &mut SgHttpRoute, namespace: Option) -> BoxResult<()> { + let namespace = namespace.unwrap_or_else(|| self.namespace.to_string()); + let wasm_plugin_api: Api = self.get_specify_namespace_api(&namespace); + let mut wasm_plugins = wasm_plugin_api.list(&ListParams::default()).await?.items; + sort_higress_wasm_plugins(&mut wasm_plugins); + let hostnames = route.hostnames.as_deref(); + + for plugin in wasm_plugins { + route.plugins.extend(plugin.route_plugin_ids(&route.route_name, hostnames)); + for rule in &mut route.rules { + for backend in &mut rule.backends { + backend.plugins.extend(plugin.backend_plugin_ids(backend)); + } + } + } + Ok(()) + } + + async fn resolve_higress_wasm_oci_auth(&self, plugin: &WasmPlugin) -> BoxResult> { + if plugin.oci_registry().is_none() { + return Ok(None); + } + let Some(secret_name) = plugin.spec.image_pull_secret.as_deref().map(str::trim).filter(|v| !v.is_empty()) else { + return Ok(None); + }; + let namespace = plugin.namespace().unwrap_or_else(|| self.namespace.to_string()); + let secret_api: Api = self.get_specify_namespace_api(&namespace); + let Some(secret) = secret_api.get_opt(secret_name).await? else { + tracing::warn!( + wasm_plugin = %plugin.name_any(), + namespace = %namespace, + secret = %secret_name, + "WasmPlugin imagePullSecret not found" + ); + return Ok(None); + }; + Ok(plugin.oci_registry().and_then(|registry| oci_auth_from_secret(&secret, ®istry))) + } +} + +pub(crate) fn oci_auth_from_secret(secret: &Secret, registry: &str) -> Option { + let data = secret.data.as_ref()?; + if let Some(bytes) = data.get(".dockerconfigjson").or_else(|| data.get(".dockercfg")) { + if let Some(auth) = oci_auth_from_docker_config(&bytes.0, registry) { + return Some(auth); + } + } + + let username = secret_data_string(secret, "username").or_else(|| secret_data_string(secret, "user"))?; + let password = secret_data_string(secret, "password").unwrap_or_default(); + Some(json!({ + "registry": registry, + "username": username, + "password": password, + })) +} + +fn oci_auth_from_docker_config(bytes: &[u8], registry: &str) -> Option { + let config: Value = serde_json::from_slice(bytes).ok()?; + let auths = config.get("auths").and_then(Value::as_object)?; + let entry = auths.get(registry).or_else(|| auths.get(&format!("https://{registry}"))).or_else(|| auths.get(&format!("http://{registry}"))).or_else(|| { + (registry == "docker.io").then(|| auths.get("https://index.docker.io/v1/").or_else(|| auths.get("index.docker.io")).or_else(|| auths.get("registry-1.docker.io")))? + })?; + + let identity_token = entry.get("identitytoken").or_else(|| entry.get("identity_token")).and_then(Value::as_str).map(str::to_string); + if let Some(identity_token) = identity_token.filter(|v| !v.trim().is_empty()) { + return Some(json!({ + "registry": registry, + "identity_token": identity_token, + })); + } + + let (username, password) = if let (Some(username), Some(password)) = ( + entry.get("username").and_then(Value::as_str).filter(|v| !v.trim().is_empty()), + entry.get("password").and_then(Value::as_str), + ) { + (username.to_string(), password.to_string()) + } else { + let auth = entry.get("auth").and_then(Value::as_str)?; + let decoded = general_purpose::STANDARD.decode(auth).ok()?; + let decoded = String::from_utf8(decoded).ok()?; + let (username, password) = decoded.split_once(':')?; + (username.to_string(), password.to_string()) + }; + + Some(json!({ + "registry": registry, + "username": username, + "password": password, + })) +} + +fn secret_data_string(secret: &Secret, key: &str) -> Option { + secret.data.as_ref().and_then(|data| data.get(key)).and_then(|bytes| String::from_utf8(bytes.0.clone()).ok()) } diff --git a/crates/config/tests/test_k8s_config.rs b/crates/config/tests/test_k8s_config.rs index e9ab46b3..4c28d5d9 100644 --- a/crates/config/tests/test_k8s_config.rs +++ b/crates/config/tests/test_k8s_config.rs @@ -1,2 +1,58 @@ +#[cfg(feature = "k8s")] #[test] -fn test_k8s_config() {} +fn observability_annotations_roundtrip() { + use k8s_gateway_api::{Gateway, GatewaySpec}; + use kube::api::ObjectMeta; + use spacegate_config::service::k8s::convert::gateway_k8s_conv::SgParametersConv; + use spacegate_model::{ObservabilityConfig, OtlpProtocol, SgParameters}; + + let params = SgParameters { + observability: ObservabilityConfig { + enabled: true, + service_name: "spacegate-k8s".to_string(), + otlp_endpoint: "http://otel-collector:4317".to_string(), + protocol: OtlpProtocol::Grpc, + traces: spacegate_model::TraceConfig { + enabled: true, + sample_ratio: 0.25, + }, + metrics: spacegate_model::MetricConfig { + enabled: true, + export_interval_ms: 15000, + }, + logs: spacegate_model::LogConfig { + enabled: true, + level: "info".to_string(), + }, + ..Default::default() + }, + ..Default::default() + }; + + let annotations = params.into_kube_gateway(); + let gateway = Gateway { + metadata: ObjectMeta { + annotations: Some(annotations), + ..Default::default() + }, + spec: GatewaySpec { + gateway_class_name: Default::default(), + listeners: Default::default(), + addresses: Default::default(), + }, + status: Default::default(), + }; + + let parsed = SgParameters::from_kube_gateway(&gateway); + + assert!(parsed.observability.enabled); + assert_eq!(parsed.observability.service_name, "spacegate-k8s"); + assert_eq!(parsed.observability.otlp_endpoint, "http://otel-collector:4317"); + assert_eq!(parsed.observability.protocol, OtlpProtocol::Grpc); + assert!(parsed.observability.traces.enabled); + assert_eq!(parsed.observability.traces.sample_ratio, 0.25); + assert!(parsed.observability.metrics.enabled); + assert_eq!(parsed.observability.metrics.export_interval_ms, 15000); + assert!(parsed.observability.logs.enabled); + assert_eq!(parsed.observability.logs.level, "info"); +} diff --git a/crates/kernel/Cargo.toml b/crates/kernel/Cargo.toml index 770be75c..8d76d61a 100644 --- a/crates/kernel/Cargo.toml +++ b/crates/kernel/Cargo.toml @@ -36,6 +36,9 @@ mime_guess = "2" # log tracing = { workspace = true } +tracing-opentelemetry = { workspace = true } +opentelemetry = { workspace = true } +serde_json = { workspace = true } # runtime tokio = { workspace = true, features = ["net", "time", "macros", "fs"] } diff --git a/crates/kernel/src/extension.rs b/crates/kernel/src/extension.rs index 8b80d972..b0656d23 100644 --- a/crates/kernel/src/extension.rs +++ b/crates/kernel/src/extension.rs @@ -2,6 +2,8 @@ mod reflect; pub use reflect::*; mod gateway_name; pub use gateway_name::*; +mod route_name; +pub use route_name::*; mod matched; pub use matched::*; mod peer_addr; diff --git a/crates/kernel/src/extension/route_name.rs b/crates/kernel/src/extension/route_name.rs new file mode 100644 index 00000000..237f2974 --- /dev/null +++ b/crates/kernel/src/extension/route_name.rs @@ -0,0 +1,18 @@ +use std::{ops::Deref, sync::Arc}; + +#[derive(Debug, Clone)] +pub struct RouteName(pub Arc); + +impl RouteName { + pub fn new(name: impl Into>) -> Self { + Self(name.into()) + } +} + +impl Deref for RouteName { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/crates/kernel/src/lib.rs b/crates/kernel/src/lib.rs index 7045cdff..cd8ec912 100644 --- a/crates/kernel/src/lib.rs +++ b/crates/kernel/src/lib.rs @@ -25,6 +25,8 @@ pub mod helper_layers; pub mod injector; /// tcp listener pub mod listener; +/// OpenTelemetry helpers. +pub mod observability; /// gateway service pub mod service; /// util functions and structs diff --git a/crates/kernel/src/observability.rs b/crates/kernel/src/observability.rs new file mode 100644 index 00000000..903b398c --- /dev/null +++ b/crates/kernel/src/observability.rs @@ -0,0 +1,503 @@ +use std::collections::BTreeMap; +use std::sync::{Arc, Mutex, OnceLock}; +use std::time::Duration; + +use hyper::{header, Request, Response, StatusCode, Version}; +use opentelemetry::{global, KeyValue}; + +use crate::{extension::GatewayName, SgBody}; + +#[derive(Debug, Clone, Default)] +pub struct TelemetryContext { + fields: Arc>>, +} + +pub const MAX_TELEMETRY_KEY_LEN: usize = 128; +pub const MAX_TELEMETRY_VALUE_LEN: usize = 4096; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TelemetryError { + EmptyKey, + MissingNamespace, + ReservedPrefix, + InvalidKey, + KeyTooLong, + ValueTooLong, +} + +pub fn validate_telemetry_key(key: &str) -> Result<(), TelemetryError> { + if key.is_empty() { + return Err(TelemetryError::EmptyKey); + } + if key.len() > MAX_TELEMETRY_KEY_LEN { + return Err(TelemetryError::KeyTooLong); + } + if !key.bytes().all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-')) { + return Err(TelemetryError::InvalidKey); + } + if !key.contains('.') { + return Err(TelemetryError::MissingNamespace); + } + if ["http.", "net.", "gateway.", "spacegate.", "otel."].iter().any(|prefix| key.starts_with(prefix)) { + return Err(TelemetryError::ReservedPrefix); + } + Ok(()) +} + +pub fn validate_telemetry_value(value: &str) -> Result<(), TelemetryError> { + if value.len() > MAX_TELEMETRY_VALUE_LEN { + return Err(TelemetryError::ValueTooLong); + } + Ok(()) +} + +impl TelemetryContext { + pub fn insert(&self, key: impl Into, value: impl Into) { + let Ok(mut fields) = self.fields.lock() else { + return; + }; + fields.insert(key.into(), value.into()); + } + + pub fn insert_checked(&self, key: impl Into, value: impl ToString) -> Result<(), TelemetryError> { + let key = key.into(); + let value = value.to_string(); + validate_telemetry_key(&key)?; + validate_telemetry_value(&value)?; + let Ok(mut fields) = self.fields.lock() else { + return Ok(()); + }; + fields.insert(key, value); + Ok(()) + } + + pub fn insert_namespaced(&self, namespace: &str, key: &str, value: impl ToString) -> Result<(), TelemetryError> { + self.insert_checked(format!("{namespace}.{key}"), value) + } + + pub fn snapshot(&self) -> BTreeMap { + self.fields.lock().map(|fields| fields.clone()).unwrap_or_default() + } + + pub fn is_empty(&self) -> bool { + self.fields.lock().map(|fields| fields.is_empty()).unwrap_or(true) + } +} + +#[derive(Debug, Clone)] +pub struct HttpMetricLabels { + pub gateway: String, + pub method: String, + pub status_code: String, + pub protocol_name: String, + pub protocol_version: String, + pub request_body_size: Option, + pub response_body_size: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AccessLogFields { + pub gateway: String, + pub method: String, + pub path: String, + pub host: String, + pub client_ip: String, + pub x_forwarded_for: String, + pub user_agent: String, + pub authority: String, + pub downstream_remote_address: String, + pub route_name: String, + pub upstream_host: String, + pub trace_id: String, + pub protocol_name: String, + pub protocol_version: String, + pub status_code: u16, + pub request_id: String, + pub peer_addr: String, + pub duration_ms: u64, + pub request_body_size: Option, + pub response_body_size: Option, + pub telemetry: BTreeMap, +} + +pub fn http_metric_labels(req: &Request, resp: &Response) -> HttpMetricLabels { + HttpMetricLabels { + gateway: req.extensions().get::().map(|g| g.to_string()).unwrap_or_else(|| "unknown".to_string()), + method: req.method().as_str().to_string(), + status_code: resp.status().as_u16().to_string(), + protocol_name: "http".to_string(), + protocol_version: http_protocol_version(req.version()), + request_body_size: content_length(req.headers()), + response_body_size: content_length(resp.headers()), + } +} + +pub fn access_log_fields( + gateway: impl Into, + method: impl Into, + path: impl Into, + host: impl Into, + client_ip: impl Into, + x_forwarded_for: impl Into, + user_agent: impl Into, + authority: impl Into, + downstream_remote_address: impl Into, + route_name: impl Into, + upstream_host: impl Into, + trace_id: impl Into, + protocol_version: impl Into, + status_code: StatusCode, + request_id: impl Into, + peer_addr: impl Into, + duration: Duration, + request_body_size: Option, + response_body_size: Option, + telemetry: BTreeMap, +) -> AccessLogFields { + AccessLogFields { + gateway: gateway.into(), + method: method.into(), + path: path.into(), + host: host.into(), + client_ip: client_ip.into(), + x_forwarded_for: x_forwarded_for.into(), + user_agent: user_agent.into(), + authority: authority.into(), + downstream_remote_address: downstream_remote_address.into(), + route_name: route_name.into(), + upstream_host: upstream_host.into(), + trace_id: trace_id.into(), + protocol_name: "http".to_string(), + protocol_version: protocol_version.into(), + status_code: status_code.as_u16(), + request_id: request_id.into(), + peer_addr: peer_addr.into(), + duration_ms: duration.as_millis() as u64, + request_body_size, + response_body_size, + telemetry, + } +} + +pub fn telemetry_json(fields: &BTreeMap) -> String { + serde_json::to_string(fields).unwrap_or_else(|_| "{}".to_string()) +} + +pub fn content_length(headers: &hyper::HeaderMap) -> Option { + headers.get(header::CONTENT_LENGTH)?.to_str().ok()?.parse().ok() +} + +pub fn header_value(headers: &hyper::HeaderMap, name: impl AsRef) -> String { + headers.get(name.as_ref()).and_then(|v| v.to_str().ok()).unwrap_or_default().to_string() +} + +pub fn first_x_forwarded_for(headers: &hyper::HeaderMap) -> Option { + header_value(headers, "x-forwarded-for").split(',').map(str::trim).find(|value| !value.is_empty()).map(str::to_string) +} + +pub fn client_ip(headers: &hyper::HeaderMap, peer_addr: std::net::SocketAddr) -> String { + first_x_forwarded_for(headers).unwrap_or_else(|| peer_addr.ip().to_string()) +} + +pub fn record_http_server_metrics(req: &Request, resp: &Response, duration: Duration) { + let labels = http_metric_labels(req, resp); + record_http_server_metrics_with_labels(labels, duration, resp.status().is_server_error() || resp.status().is_client_error()); +} + +pub fn record_http_server_metrics_with_labels(labels: HttpMetricLabels, duration: Duration, is_error: bool) { + let error_class = status_error_class_from_code(&labels.status_code); + let attrs = [ + KeyValue::new("gateway", labels.gateway), + KeyValue::new("http.request.method", labels.method), + KeyValue::new("http.response.status_code", labels.status_code), + KeyValue::new("network.protocol.name", labels.protocol_name), + KeyValue::new("network.protocol.version", labels.protocol_version), + ]; + let instruments = http_instruments(); + instruments.requests.add(1, &attrs); + instruments.duration.record(duration.as_secs_f64(), &attrs); + if let Some(size) = labels.request_body_size { + instruments.request_body_size.record(size, &attrs); + } + if let Some(size) = labels.response_body_size { + instruments.response_body_size.record(size, &attrs); + } + if is_error { + instruments.errors.add(1, &attrs); + } + match error_class { + Some(HttpErrorClass::Client) => instruments.errors_4xx.add(1, &attrs), + Some(HttpErrorClass::Server) => instruments.errors_5xx.add(1, &attrs), + None => {} + } +} + +pub fn record_http_server_active_request(labels: HttpMetricLabels, delta: i64) { + let attrs = [ + KeyValue::new("gateway", labels.gateway), + KeyValue::new("http.request.method", labels.method), + KeyValue::new("network.protocol.name", labels.protocol_name), + KeyValue::new("network.protocol.version", labels.protocol_version), + ]; + http_instruments().active_requests.add(delta, &attrs); +} + +#[derive(Debug)] +struct HttpInstruments { + requests: opentelemetry::metrics::Counter, + errors: opentelemetry::metrics::Counter, + errors_4xx: opentelemetry::metrics::Counter, + errors_5xx: opentelemetry::metrics::Counter, + active_requests: opentelemetry::metrics::UpDownCounter, + duration: opentelemetry::metrics::Histogram, + request_body_size: opentelemetry::metrics::Histogram, + response_body_size: opentelemetry::metrics::Histogram, +} + +fn http_instruments() -> &'static HttpInstruments { + static INSTRUMENTS: OnceLock = OnceLock::new(); + INSTRUMENTS.get_or_init(|| { + let meter = global::meter("spacegate_kernel"); + HttpInstruments { + requests: meter.u64_counter("http.server.requests").build(), + errors: meter.u64_counter("http.server.errors").build(), + errors_4xx: meter.u64_counter("http.server.errors.4xx").build(), + errors_5xx: meter.u64_counter("http.server.errors.5xx").build(), + active_requests: meter.i64_up_down_counter("http.server.active_requests").with_unit("{request}").build(), + duration: meter.f64_histogram("http.server.request.duration").with_unit("s").build(), + request_body_size: meter.u64_histogram("http.server.request.body.size").with_unit("By").build(), + response_body_size: meter.u64_histogram("http.server.response.body.size").with_unit("By").build(), + } + }) +} + +pub fn http_protocol_version(version: Version) -> String { + match version { + Version::HTTP_10 => "1.0", + Version::HTTP_11 => "1.1", + Version::HTTP_2 => "2", + Version::HTTP_3 => "3", + _ => "unknown", + } + .to_string() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HttpErrorClass { + Client, + Server, +} + +pub fn status_error_class(status: StatusCode) -> Option { + if status.is_client_error() { + Some(HttpErrorClass::Client) + } else if status.is_server_error() { + Some(HttpErrorClass::Server) + } else { + None + } +} + +fn status_error_class_from_code(status_code: &str) -> Option { + StatusCode::from_u16(status_code.parse().ok()?).ok().and_then(status_error_class) +} + +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, time::Duration}; + + use hyper::{header, Request, Response, StatusCode}; + + use crate::{extension::GatewayName, observability::http_metric_labels, SgBody}; + + #[test] + fn http_metric_labels_do_not_include_path() { + let mut req = Request::builder().method("GET").uri("/users/123?token=secret").body(SgBody::empty()).expect("request"); + req.extensions_mut().insert(GatewayName::new("gw-a")); + let resp = Response::builder().status(StatusCode::OK).body(SgBody::empty()).expect("response"); + + let labels = http_metric_labels(&req, &resp); + + assert_eq!(labels.gateway, "gw-a"); + assert_eq!(labels.method, "GET"); + assert_eq!(labels.status_code, "200"); + assert!(!format!("{labels:?}").contains("/users/123")); + } + + #[test] + fn http_metric_labels_use_protocol_name_and_version() { + let req = Request::builder().version(hyper::Version::HTTP_2).body(SgBody::empty()).expect("request"); + let resp = Response::builder().status(StatusCode::OK).body(SgBody::empty()).expect("response"); + + let labels = http_metric_labels(&req, &resp); + + assert_eq!(labels.protocol_name, "http"); + assert_eq!(labels.protocol_version, "2"); + } + + #[test] + fn http_protocol_version_maps_known_versions() { + assert_eq!(super::http_protocol_version(hyper::Version::HTTP_10), "1.0"); + assert_eq!(super::http_protocol_version(hyper::Version::HTTP_11), "1.1"); + assert_eq!(super::http_protocol_version(hyper::Version::HTTP_2), "2"); + assert_eq!(super::http_protocol_version(hyper::Version::HTTP_3), "3"); + } + + #[test] + fn status_error_classifies_4xx_and_5xx_only() { + assert_eq!(super::status_error_class(StatusCode::OK), None); + assert_eq!(super::status_error_class(StatusCode::BAD_REQUEST), Some(super::HttpErrorClass::Client)); + assert_eq!(super::status_error_class(StatusCode::INTERNAL_SERVER_ERROR), Some(super::HttpErrorClass::Server)); + } + + #[test] + fn access_log_fields_include_stable_request_data_and_telemetry() { + let telemetry = BTreeMap::from([("ai.asset_id".to_string(), "deepseek-chat".to_string()), ("ai.total_tokens".to_string(), "37".to_string())]); + + let fields = super::access_log_fields( + "gw-a", + "POST", + "/api/v1/model/deepseek-chat", + "example.local", + "203.0.113.10", + "203.0.113.10, 10.0.0.1", + "curl/8.7.1", + "example.local", + "127.0.0.1:12345", + "model-route", + "model.default.svc.cluster.local", + "4bf92f3577b34da6a3ce929d0e0e4736", + "1.1", + StatusCode::OK, + "req-1", + "127.0.0.1:12345", + Duration::from_millis(42), + Some(12), + Some(34), + telemetry, + ); + + assert_eq!(fields.gateway, "gw-a"); + assert_eq!(fields.method, "POST"); + assert_eq!(fields.client_ip, "203.0.113.10"); + assert_eq!(fields.x_forwarded_for, "203.0.113.10, 10.0.0.1"); + assert_eq!(fields.user_agent, "curl/8.7.1"); + assert_eq!(fields.authority, "example.local"); + assert_eq!(fields.downstream_remote_address, "127.0.0.1:12345"); + assert_eq!(fields.route_name, "model-route"); + assert_eq!(fields.upstream_host, "model.default.svc.cluster.local"); + assert_eq!(fields.trace_id, "4bf92f3577b34da6a3ce929d0e0e4736"); + assert_eq!(fields.status_code, 200); + assert_eq!(fields.duration_ms, 42); + assert_eq!(fields.request_body_size, Some(12)); + assert_eq!(fields.response_body_size, Some(34)); + assert_eq!(fields.telemetry.get("ai.asset_id").map(String::as_str), Some("deepseek-chat")); + } + + #[test] + fn http_metric_labels_include_body_sizes_from_content_length() { + let req = Request::builder().method("POST").header(header::CONTENT_LENGTH, "123").body(SgBody::empty()).expect("request"); + let resp = Response::builder().status(StatusCode::OK).header(header::CONTENT_LENGTH, "456").body(SgBody::empty()).expect("response"); + + let labels = http_metric_labels(&req, &resp); + + assert_eq!(labels.request_body_size, Some(123)); + assert_eq!(labels.response_body_size, Some(456)); + } + + #[test] + fn http_metric_labels_ignore_invalid_body_sizes() { + let req = Request::builder().header(header::CONTENT_LENGTH, "chunked").body(SgBody::empty()).expect("request"); + let resp = Response::builder().header(header::CONTENT_LENGTH, "-1").body(SgBody::empty()).expect("response"); + + let labels = http_metric_labels(&req, &resp); + + assert_eq!(labels.request_body_size, None); + assert_eq!(labels.response_body_size, None); + } + + #[test] + fn client_ip_prefers_first_x_forwarded_for_value() { + let req = Request::builder().header("x-forwarded-for", "203.0.113.10, 10.0.0.1").body(SgBody::empty()).expect("request"); + let peer = "127.0.0.1:12345".parse().expect("peer"); + + assert_eq!(super::client_ip(req.headers(), peer), "203.0.113.10"); + } + + #[test] + fn client_ip_falls_back_to_peer_ip() { + let req = Request::builder().body(SgBody::empty()).expect("request"); + let peer = "127.0.0.1:12345".parse().expect("peer"); + + assert_eq!(super::client_ip(req.headers(), peer), "127.0.0.1"); + } + + #[test] + fn telemetry_context_collects_plugin_fields() { + let context = super::TelemetryContext::default(); + + context.insert("ai.asset_id", "deepseek-chat"); + context.insert("ai.total_tokens", "37"); + + let fields = context.snapshot(); + assert_eq!(fields.get("ai.asset_id").map(String::as_str), Some("deepseek-chat")); + assert_eq!(fields.get("ai.total_tokens").map(String::as_str), Some("37")); + } + + #[test] + fn telemetry_key_validation_accepts_namespaced_keys() { + assert!(super::validate_telemetry_key("ai.total_tokens").is_ok()); + assert!(super::validate_telemetry_key("mcp.tool-name").is_ok()); + assert!(super::validate_telemetry_key("auth.api_key_hash").is_ok()); + } + + #[test] + fn telemetry_key_validation_rejects_bad_keys() { + assert_eq!(super::validate_telemetry_key(""), Err(super::TelemetryError::EmptyKey)); + assert_eq!(super::validate_telemetry_key("total_tokens"), Err(super::TelemetryError::MissingNamespace)); + assert_eq!(super::validate_telemetry_key("ai total_tokens"), Err(super::TelemetryError::InvalidKey)); + assert_eq!(super::validate_telemetry_key("http.status_code"), Err(super::TelemetryError::ReservedPrefix)); + assert_eq!(super::validate_telemetry_key("spacegate.internal"), Err(super::TelemetryError::ReservedPrefix)); + } + + #[test] + fn telemetry_value_validation_rejects_oversized_values() { + let value = "x".repeat(super::MAX_TELEMETRY_VALUE_LEN + 1); + assert_eq!(super::validate_telemetry_value(&value), Err(super::TelemetryError::ValueTooLong)); + } + + #[test] + fn telemetry_context_checked_insert_rejects_invalid_key_without_mutating_context() { + let context = super::TelemetryContext::default(); + + let result = context.insert_checked("total_tokens", "37"); + + assert_eq!(result, Err(super::TelemetryError::MissingNamespace)); + assert!(context.snapshot().is_empty()); + } + + #[test] + fn telemetry_context_namespaced_insert_builds_stable_key() { + let context = super::TelemetryContext::default(); + + context.insert_namespaced("ai", "total_tokens", 37).expect("insert"); + + let fields = context.snapshot(); + assert_eq!(fields.get("ai.total_tokens").map(String::as_str), Some("37")); + } + + #[test] + fn telemetry_json_serializes_plugin_defined_fields() { + let fields = BTreeMap::from([ + ("ai.asset_id".to_string(), "deepseek-chat".to_string()), + ("ai.total_tokens".to_string(), "37".to_string()), + ("mcp.tool".to_string(), "search".to_string()), + ]); + + let json = super::telemetry_json(&fields); + + assert!(json.contains("\"ai.asset_id\":\"deepseek-chat\"")); + assert!(json.contains("\"ai.total_tokens\":\"37\"")); + assert!(json.contains("\"mcp.tool\":\"search\"")); + } +} diff --git a/crates/kernel/src/service.rs b/crates/kernel/src/service.rs index 3cc1365d..8100be88 100644 --- a/crates/kernel/src/service.rs +++ b/crates/kernel/src/service.rs @@ -3,11 +3,18 @@ use std::{convert::Infallible, net::SocketAddr, sync::Arc}; use futures_util::future::BoxFuture; use hyper::{body::Incoming, Request, Response}; use hyper_util::rt::TokioIo; +use opentelemetry::trace::TraceContextExt; use tokio::net::TcpStream; use tokio_rustls::rustls; +use tracing::Instrument; +use tracing_opentelemetry::OpenTelemetrySpanExt; use crate::{ - extension::{EnterTime, PeerAddr, Reflect}, + extension::{BackendHost, EnterTime, PeerAddr, Reflect, RouteName}, + observability::{ + access_log_fields, client_ip, content_length, header_value, http_protocol_version, record_http_server_active_request, record_http_server_metrics_with_labels, + telemetry_json, HttpMetricLabels, TelemetryContext, + }, ArcHyperService, BoxResult, SgBody, }; @@ -26,13 +33,19 @@ type ConnectionBuilder = hyper_util::server::conn::auto::Builder, connection_builder: ConnectionBuilder, } impl Http { pub fn new(service: ArcHyperService) -> Self { + Self::with_gateway_name(service, Arc::::from("unknown")) + } + + pub fn with_gateway_name(service: ArcHyperService, gateway_name: Arc) -> Self { Self { inner_service: service, + gateway_name, connection_builder: ConnectionBuilder::new(Default::default()), } } @@ -59,7 +72,7 @@ impl TcpService for Http { } fn handle(&self, stream: TcpStream, peer: SocketAddr) -> BoxFuture<'static, BoxResult<()>> { let io = TokioIo::new(stream); - let service = HyperServiceAdapter::new(self.inner_service.clone(), peer); + let service = HyperServiceAdapter::with_gateway_name(self.inner_service.clone(), peer, self.gateway_name.clone()); let builder = self.connection_builder.clone(); Box::pin(async move { let conn = builder.serve_connection_with_upgrades(io, service); @@ -70,14 +83,20 @@ impl TcpService for Http { #[derive(Debug)] pub struct Https { inner_service: ArcHyperService, + gateway_name: Arc, tls_config: Arc, connection_builder: ConnectionBuilder, } impl Https { pub fn new(service: ArcHyperService, tls_config: rustls::ServerConfig) -> Self { + Self::with_gateway_name(service, tls_config, Arc::::from("unknown")) + } + + pub fn with_gateway_name(service: ArcHyperService, tls_config: rustls::ServerConfig, gateway_name: Arc) -> Self { Self { inner_service: service, + gateway_name, tls_config: Arc::new(tls_config), connection_builder: ConnectionBuilder::new(Default::default()), } @@ -95,7 +114,7 @@ impl TcpService for Https { peeked.starts_with(b"\x16\x03") } fn handle(&self, stream: TcpStream, peer: SocketAddr) -> BoxFuture<'static, BoxResult<()>> { - let service = HyperServiceAdapter::new(self.inner_service.clone(), peer); + let service = HyperServiceAdapter::with_gateway_name(self.inner_service.clone(), peer, self.gateway_name.clone()); let builder = self.connection_builder.clone(); let connector = tokio_rustls::TlsAcceptor::from(self.tls_config.clone()); Box::pin(async move { @@ -114,6 +133,7 @@ where { service: S, peer: SocketAddr, + gateway_name: Arc, } impl HyperServiceAdapter @@ -122,7 +142,15 @@ where S::Future: Send + 'static, { pub fn new(service: S, peer: SocketAddr) -> Self { - Self { service, peer } + Self::with_gateway_name(service, peer, Arc::::from("unknown")) + } + + pub fn with_gateway_name(service: S, peer: SocketAddr, gateway_name: Arc) -> Self { + Self { service, peer, gateway_name } + } + + pub fn gateway_name(&self) -> &str { + self.gateway_name.as_ref() } } @@ -147,28 +175,141 @@ where let enter_time = EnterTime::new(); let service = self.service.clone(); let mut req = req.map(SgBody::new); + let method = req.method().clone(); + let method_label = method.as_str().to_string(); + let path = req.uri().path().to_string(); + let host = req.uri().host().map(str::to_string).or_else(|| req.headers().get(hyper::header::HOST).and_then(|v| v.to_str().ok()).map(str::to_string)).unwrap_or_default(); + let protocol = format!("{:?}", req.version()); + let protocol_version_label = http_protocol_version(req.version()); + let request_id = req.headers().get("x-request-id").and_then(|v| v.to_str().ok()).unwrap_or_default().to_string(); + let x_forwarded_for = header_value(req.headers(), "x-forwarded-for"); + let user_agent = header_value(req.headers(), "user-agent"); + let client_ip_label = client_ip(req.headers(), self.peer); + let request_body_size = content_length(req.headers()); + let peer_addr_label = self.peer.to_string(); + let span = tracing::info_span!( + "http.server.request", + http.method = %method, + http.path = %path, + http.host = %host, + http.protocol = %protocol, + http.status_code = tracing::field::Empty, + request_id = %request_id, + peer_addr = %self.peer, + duration_ms = tracing::field::Empty + ); + let gateway_label = self.gateway_name.to_string(); + let telemetry_context = TelemetryContext::default(); + let active_request_labels = HttpMetricLabels { + gateway: gateway_label.clone(), + method: method_label.clone(), + status_code: "active".to_string(), + protocol_name: "http".to_string(), + protocol_version: protocol_version_label.clone(), + request_body_size, + response_body_size: None, + }; + record_http_server_active_request(active_request_labels.clone(), 1); let mut reflect = Reflect::default(); // let method = req.method().clone(); reflect.insert(enter_time); req.extensions_mut().insert(reflect); req.extensions_mut().insert(PeerAddr(self.peer)); req.extensions_mut().insert(enter_time); - Box::pin(async move { - let resp = service.call(req).await.expect("infallible"); - // if method != hyper::Method::HEAD && method != hyper::Method::OPTIONS && method != hyper::Method::CONNECT { - // with_length_or_chunked(&mut resp); - // } - let status = resp.status(); - if status.is_server_error() { - tracing::warn!(status = ?status, headers = ?resp.headers(), "server error response"); - } else if status.is_client_error() { - tracing::debug!(status = ?status, headers = ?resp.headers(), "client error response"); - } else if status.is_success() { - tracing::trace!(status = ?status, headers = ?resp.headers(), "success response"); + req.extensions_mut().insert(telemetry_context.clone()); + let span_for_recording = span.clone(); + Box::pin( + async move { + let resp = service.call(req).await.expect("infallible"); + // if method != hyper::Method::HEAD && method != hyper::Method::OPTIONS && method != hyper::Method::CONNECT { + // with_length_or_chunked(&mut resp); + // } + let status = resp.status(); + if status.is_server_error() { + tracing::warn!(status = ?status, headers = ?resp.headers(), "server error response"); + } else if status.is_client_error() { + tracing::debug!(status = ?status, headers = ?resp.headers(), "client error response"); + } else if status.is_success() { + tracing::trace!(status = ?status, headers = ?resp.headers(), "success response"); + } + let latency = enter_time.elapsed(); + span_for_recording.record("http.status_code", status.as_u16()); + span_for_recording.record("duration_ms", latency.as_millis() as u64); + let response_body_size = content_length(resp.headers()); + let access_request_id = resp.headers().get("x-request-id").and_then(|v| v.to_str().ok()).map(str::to_string).unwrap_or(request_id); + tracing::trace!(latency = ?latency, "request finished"); + let authority = host.clone(); + let route_name = resp.extensions().get::().map(|route| route.to_string()).unwrap_or_default(); + let upstream_host = resp.extensions().get::().map(|host| host.to_string()).unwrap_or_default(); + let trace_id = span_for_recording.context().span().span_context().trace_id().to_string(); + record_http_server_metrics_with_labels( + HttpMetricLabels { + gateway: gateway_label.clone(), + method: method_label.clone(), + status_code: status.as_u16().to_string(), + protocol_name: "http".to_string(), + protocol_version: protocol_version_label.clone(), + request_body_size, + response_body_size, + }, + latency, + status.is_server_error() || status.is_client_error(), + ); + let access_log = access_log_fields( + gateway_label, + method_label, + path, + host, + client_ip_label, + x_forwarded_for, + user_agent, + authority, + peer_addr_label.clone(), + route_name, + upstream_host, + trace_id, + protocol_version_label, + status, + access_request_id, + peer_addr_label, + latency, + request_body_size, + response_body_size, + telemetry_context.snapshot(), + ); + let telemetry = telemetry_json(&access_log.telemetry); + tracing::info!( + event = "http_access", + gateway = %access_log.gateway, + method = %access_log.method, + path = %access_log.path, + host = %access_log.host, + authority = %access_log.authority, + client_ip = %access_log.client_ip, + x_forwarded_for = %access_log.x_forwarded_for, + user_agent = %access_log.user_agent, + downstream_remote_address = %access_log.downstream_remote_address, + route_name = %access_log.route_name, + upstream_host = %access_log.upstream_host, + trace_id = %access_log.trace_id, + protocol_name = %access_log.protocol_name, + protocol_version = %access_log.protocol_version, + status_code = access_log.status_code, + request_id = %access_log.request_id, + peer_addr = %access_log.peer_addr, + duration_ms = access_log.duration_ms, + bytes_received = ?access_log.request_body_size, + bytes_sent = ?access_log.response_body_size, + request_body_size = ?access_log.request_body_size, + response_body_size = ?access_log.response_body_size, + telemetry = %telemetry, + "http access log" + ); + record_http_server_active_request(active_request_labels, -1); + Ok(resp) } - tracing::trace!(latency = ?enter_time.elapsed(), "request finished"); - Ok(resp) - }) + .instrument(span), + ) } } @@ -179,4 +320,25 @@ impl ArcHyperService { pub fn https(self, tls_config: rustls::ServerConfig) -> Https { Https::new(self, tls_config) } + pub fn http_with_gateway_name(self, gateway_name: Arc) -> Http { + Http::with_gateway_name(self, gateway_name) + } + pub fn https_with_gateway_name(self, tls_config: rustls::ServerConfig, gateway_name: Arc) -> Https { + Https::with_gateway_name(self, tls_config, gateway_name) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn hyper_service_adapter_keeps_gateway_name_from_listener() { + let service = hyper::service::service_fn(|_req: Request| async { Ok::<_, Infallible>(Response::new(SgBody::empty())) }); + let peer = "127.0.0.1:12345".parse().expect("peer"); + + let adapter = HyperServiceAdapter::with_gateway_name(service, peer, Arc::::from("gw-a")); + + assert_eq!(adapter.gateway_name(), "gw-a"); + } } diff --git a/crates/kernel/src/service/http_gateway.rs b/crates/kernel/src/service/http_gateway.rs index fc629cac..d1c9c8cd 100644 --- a/crates/kernel/src/service/http_gateway.rs +++ b/crates/kernel/src/service/http_gateway.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, ops::Index, sync::Arc}; use crate::{ backend_service::ArcHyperService, - extension::{GatewayName, MatchedSgRouter}, + extension::{GatewayName, MatchedSgRouter, Reflect, RouteName}, helper_layers::{ map_request::{add_extension::add_extension, MapRequestLayer}, reload::Reloader, @@ -96,6 +96,7 @@ impl Router for GatewayRouter { if let Some(ref matches) = matches { for m in matches.as_ref() { if m.match_request(req) { + insert_route_name(req, self.routers.as_ref()[*route_index].name.clone()); req.extensions_mut().insert(MatchedSgRouter(m.clone())); tracing::trace!("matches {m:?} [{route_index},{idx1}:{_p}]"); if let Err(e) = m.rewrite(req) { @@ -107,6 +108,7 @@ impl Router for GatewayRouter { } continue; } else { + insert_route_name(req, self.routers.as_ref()[*route_index].name.clone()); tracing::trace!("matches wildcard [{route_index},{idx1}:{_p}]"); return Some(index); } @@ -150,6 +152,7 @@ pub fn create_http_router<'a>(routes: impl Iterator, fallb } services.push(rules_services); routers.push(HttpRouter { + name: route.name.clone().into(), hostnames: route.hostnames.clone().into(), rules: rules_router.into_iter().map(|x| x.map(|v| v.into_iter().map(Arc::new).collect::>())).collect(), ext: route.ext.clone(), @@ -169,3 +172,11 @@ pub fn create_http_router<'a>(routes: impl Iterator, fallb fallback, ) } + +fn insert_route_name(req: &mut Request, name: Arc) { + let route_name = RouteName::new(name); + if let Some(reflect) = req.extensions_mut().get_mut::() { + reflect.insert(route_name.clone()); + } + req.extensions_mut().insert(route_name); +} diff --git a/crates/kernel/src/service/http_route.rs b/crates/kernel/src/service/http_route.rs index 3fac72fd..5f97d552 100644 --- a/crates/kernel/src/service/http_route.rs +++ b/crates/kernel/src/service/http_route.rs @@ -44,6 +44,7 @@ impl HttpRoute { } #[derive(Debug, Clone)] pub struct HttpRouter { + pub name: Arc, pub hostnames: Arc<[String]>, pub rules: Arc<[Option]>>]>, pub ext: hyper::http::Extensions, diff --git a/crates/model/src/constants.rs b/crates/model/src/constants.rs index c8c36a50..88e46005 100644 --- a/crates/model/src/constants.rs +++ b/crates/model/src/constants.rs @@ -8,6 +8,16 @@ pub const GATEWAY_ANNOTATION_LOG_LEVEL: &str = "log_level"; pub const GATEWAY_ANNOTATION_LANGUAGE: &str = "lang"; pub const GATEWAY_ANNOTATION_IGNORE_TLS_VERIFICATION: &str = "ignore_tls_verification"; pub const GATEWAY_ANNOTATION_ENABLE_X_REQUEST_ID: &str = "enable_x_request_id"; +pub const GATEWAY_ANNOTATION_OTEL_ENABLED: &str = "spacegate.io/otel-enabled"; +pub const GATEWAY_ANNOTATION_OTEL_SERVICE_NAME: &str = "spacegate.io/otel-service-name"; +pub const GATEWAY_ANNOTATION_OTEL_ENDPOINT: &str = "spacegate.io/otel-endpoint"; +pub const GATEWAY_ANNOTATION_OTEL_PROTOCOL: &str = "spacegate.io/otel-protocol"; +pub const GATEWAY_ANNOTATION_OTEL_TRACES_ENABLED: &str = "spacegate.io/otel-traces-enabled"; +pub const GATEWAY_ANNOTATION_OTEL_TRACES_SAMPLE_RATIO: &str = "spacegate.io/otel-traces-sample-ratio"; +pub const GATEWAY_ANNOTATION_OTEL_METRICS_ENABLED: &str = "spacegate.io/otel-metrics-enabled"; +pub const GATEWAY_ANNOTATION_OTEL_METRICS_EXPORT_INTERVAL_MS: &str = "spacegate.io/otel-metrics-export-interval-ms"; +pub const GATEWAY_ANNOTATION_OTEL_LOGS_ENABLED: &str = "spacegate.io/otel-logs-enabled"; +pub const GATEWAY_ANNOTATION_OTEL_LOGS_LEVEL: &str = "spacegate.io/otel-logs-level"; pub const SG_FILTER_KIND: &str = "sgfilter"; pub const DEFAULT_NAMESPACE: &str = "default"; diff --git a/crates/model/src/ext/k8s/crd.rs b/crates/model/src/ext/k8s/crd.rs index 3cd76b8e..5364f7f2 100644 --- a/crates/model/src/ext/k8s/crd.rs +++ b/crates/model/src/ext/k8s/crd.rs @@ -1,2 +1,3 @@ pub mod http_spaceroute; pub mod sg_filter; +pub mod wasm_plugin; diff --git a/crates/model/src/ext/k8s/crd/wasm_plugin.rs b/crates/model/src/ext/k8s/crd/wasm_plugin.rs new file mode 100644 index 00000000..df7b072d --- /dev/null +++ b/crates/model/src/ext/k8s/crd/wasm_plugin.rs @@ -0,0 +1,71 @@ +use k8s_openapi::schemars::JsonSchema; +use kube::CustomResource; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(CustomResource, Deserialize, Serialize, Clone, Debug, JsonSchema)] +#[serde(rename_all = "camelCase")] +#[kube(kind = "WasmPlugin", group = "extensions.higress.io", version = "v1alpha1", namespaced, status = "HigressWasmPluginStatus")] +pub struct HigressWasmPluginSpec { + /// Higress-compatible wasm URL. Spacegate runtime supports local paths, + /// `file://`, `http(s)://`, and OCI image URLs such as `oci://registry/repo:tag`. + pub url: String, + /// Optional plugin name exposed to proxy-wasm guests. + #[serde(default)] + pub plugin_name: Option, + /// Optional SHA-256 digest for the wasm bytes. Accepts either plain hex or `sha256:`. + #[serde(default, alias = "sha256")] + pub sha256: Option, + /// Higress phase is kept for ordering/compatibility. Spacegate currently maps order by priority. + #[serde(default)] + pub phase: Option, + /// Higher priority plugins are placed earlier in the generated Spacegate plugin list. + #[serde(default)] + pub priority: Option, + /// `Always` disables Spacegate's in-process wasm module cache for this plugin. + #[serde(default)] + pub image_pull_policy: Option, + /// Optional Kubernetes Secret used for private OCI registries. + #[serde(default)] + pub image_pull_secret: Option, + /// Disable global/default config. Match rules can still enable per-rule configs. + #[serde(default)] + pub default_config_disable: bool, + /// Higress default plugin config. + #[serde(default)] + pub default_config: Value, + /// Optional match rules. These are passed through to Higress-style wasm plugins under `_rules_`. + #[serde(default)] + pub match_rules: Vec, + /// Optional fail strategy. `FAIL_OPEN`/`FAIL_CLOSE` and `fail_open`/`fail_close` are accepted. + #[serde(default)] + pub fail_strategy: Option, +} + +#[derive(Deserialize, Serialize, Clone, Debug, Default, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct HigressWasmPluginStatus { + #[serde(default)] + pub observed_generation: Option, + #[serde(default)] + pub phase: Option, + #[serde(default)] + pub digest: Option, + #[serde(default)] + pub message: Option, +} + +#[derive(Deserialize, Serialize, Clone, Debug, Default, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct HigressWasmPluginMatchRule { + #[serde(default)] + pub ingress: Vec, + #[serde(default)] + pub domain: Vec, + #[serde(default)] + pub service: Vec, + #[serde(default)] + pub config_disable: bool, + #[serde(default)] + pub config: Value, +} diff --git a/crates/model/src/gateway.rs b/crates/model/src/gateway.rs index 22f39461..cc2b1d16 100644 --- a/crates/model/src/gateway.rs +++ b/crates/model/src/gateway.rs @@ -2,7 +2,7 @@ use std::{fmt::Display, net::IpAddr}; use serde::{Deserialize, Serialize}; -use super::plugin::PluginInstanceId; +use super::{observability::ObservabilityConfig, plugin::PluginInstanceId}; /// Gateway represents an instance of a service-traffic handling infrastructure /// by binding Listeners to a set of IP addresses. @@ -69,6 +69,7 @@ pub struct SgParameters { #[serde(skip_serializing_if = "Option::is_none")] /// Add request id for every request pub enable_x_request_id: Option, + pub observability: ObservabilityConfig, } /// Listener embodies the concept of a logical endpoint where a Gateway accepts network connections. diff --git a/crates/model/src/lib.rs b/crates/model/src/lib.rs index 89985a76..7d99c043 100644 --- a/crates/model/src/lib.rs +++ b/crates/model/src/lib.rs @@ -3,6 +3,9 @@ use std::{collections::BTreeMap, fmt::Debug}; pub use plugin::*; +pub mod observability; +pub use observability::*; + pub mod gateway; pub use gateway::*; @@ -62,6 +65,7 @@ pub struct Config { #[cfg_attr(feature = "typegen", ts(as = "crate::plugin::PluginInstanceMapTs"))] pub plugins: PluginInstanceMap, pub api_port: Option, + pub observability: ObservabilityConfig, } #[allow(clippy::derivable_impls)] @@ -74,6 +78,7 @@ impl Default for Config { api_port: Some(crate::constants::DEFAULT_API_PORT), #[cfg(not(feature = "ext-axum"))] api_port: None, + observability: Default::default(), } } } diff --git a/crates/model/src/observability.rs b/crates/model/src/observability.rs new file mode 100644 index 00000000..1b1ee38e --- /dev/null +++ b/crates/model/src/observability.rs @@ -0,0 +1,109 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "typegen", derive(ts_rs::TS), ts(export))] +#[serde(default)] +pub struct ObservabilityConfig { + pub enabled: bool, + pub service_name: String, + pub otlp_endpoint: String, + pub protocol: OtlpProtocol, + pub traces: TraceConfig, + pub metrics: MetricConfig, + pub logs: LogConfig, +} + +impl Default for ObservabilityConfig { + fn default() -> Self { + Self { + enabled: false, + service_name: "spacegate".to_string(), + otlp_endpoint: "http://localhost:4317".to_string(), + protocol: OtlpProtocol::Grpc, + traces: TraceConfig::default(), + metrics: MetricConfig::default(), + logs: LogConfig::default(), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Default)] +#[cfg_attr(feature = "typegen", derive(ts_rs::TS), ts(export))] +#[serde(rename_all = "lowercase")] +pub enum OtlpProtocol { + #[default] + Grpc, + Http, +} + +impl std::fmt::Display for OtlpProtocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OtlpProtocol::Grpc => write!(f, "grpc"), + OtlpProtocol::Http => write!(f, "http"), + } + } +} + +impl std::str::FromStr for OtlpProtocol { + type Err = crate::BoxError; + + fn from_str(s: &str) -> Result { + match s { + "grpc" => Ok(OtlpProtocol::Grpc), + "http" | "http/protobuf" => Ok(OtlpProtocol::Http), + _ => Err(format!("invalid otlp protocol: {s}").into()), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[cfg_attr(feature = "typegen", derive(ts_rs::TS), ts(export))] +#[serde(default)] +pub struct TraceConfig { + pub enabled: bool, + pub sample_ratio: f64, +} + +impl Default for TraceConfig { + fn default() -> Self { + Self { + enabled: false, + sample_ratio: 1.0, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "typegen", derive(ts_rs::TS), ts(export))] +#[serde(default)] +pub struct MetricConfig { + pub enabled: bool, + pub export_interval_ms: u64, +} + +impl Default for MetricConfig { + fn default() -> Self { + Self { + enabled: false, + export_interval_ms: 60_000, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "typegen", derive(ts_rs::TS), ts(export))] +#[serde(default)] +pub struct LogConfig { + pub enabled: bool, + pub level: String, +} + +impl Default for LogConfig { + fn default() -> Self { + Self { + enabled: false, + level: "info".to_string(), + } + } +} diff --git a/crates/model/tests/test_parse_config.rs b/crates/model/tests/test_parse_config.rs index 2fa76a3f..e5e3e0bd 100644 --- a/crates/model/tests/test_parse_config.rs +++ b/crates/model/tests/test_parse_config.rs @@ -19,3 +19,134 @@ fn test_parse_config() { } } } + +#[test] +fn observability_defaults_to_disabled() { + let config = Config::default(); + + assert!(!config.observability.enabled); + assert_eq!(config.observability.service_name, "spacegate"); + assert_eq!(config.observability.otlp_endpoint, "http://localhost:4317"); + assert!(!config.observability.traces.enabled); + assert!(!config.observability.metrics.enabled); + assert!(!config.observability.logs.enabled); +} + +#[test] +fn observability_can_be_parsed_from_config() { + let file = r#" +[observability] +enabled = true +service_name = "spacegate-test" +otlp_endpoint = "http://collector:4317" +protocol = "grpc" + +[observability.traces] +enabled = true +sample_ratio = 0.5 + +[observability.metrics] +enabled = true +export_interval_ms = 10000 + +[observability.logs] +enabled = true +level = "warn" +"#; + + let config = toml::from_str::(file).expect("parse config"); + + assert!(config.observability.enabled); + assert_eq!(config.observability.service_name, "spacegate-test"); + assert_eq!(config.observability.otlp_endpoint, "http://collector:4317"); + assert_eq!(config.observability.protocol, spacegate_model::OtlpProtocol::Grpc); + assert!(config.observability.traces.enabled); + assert_eq!(config.observability.traces.sample_ratio, 0.5); + assert!(config.observability.metrics.enabled); + assert_eq!(config.observability.metrics.export_interval_ms, 10000); + assert!(config.observability.logs.enabled); + assert_eq!(config.observability.logs.level, "warn"); +} + +#[test] +fn local_otel_json_config_shape_can_be_parsed() { + let file = r#" +{ + "api_port": 9876, + "observability": { + "enabled": true, + "service_name": "spacegate-local-otel", + "otlp_endpoint": "http://127.0.0.1:4317", + "protocol": "grpc", + "traces": { + "enabled": true, + "sample_ratio": 1.0 + }, + "metrics": { + "enabled": true, + "export_interval_ms": 5000 + }, + "logs": { + "enabled": true, + "level": "info" + } + }, + "gateways": { + "local": { + "gateway": { + "name": "local", + "parameters": { + "enable_x_request_id": true + }, + "listeners": [ + { + "name": "http", + "ip": "0.0.0.0", + "port": 9000, + "protocol": { + "type": "http" + } + } + ] + }, + "routes": { + "root": { + "route_name": "root", + "rules": [ + { + "matches": [ + { + "path": { + "kind": "Prefix", + "value": "/" + } + } + ], + "backends": [ + { + "host": { + "kind": "Host", + "host": "127.0.0.1" + }, + "port": 18080, + "protocol": "http", + "weight": 1 + } + ] + } + ] + } + } + } + } +} +"#; + + let config = serde_json::from_str::(file).expect("parse local otel json config"); + let gateway = config.gateways.get("local").expect("local gateway"); + + assert_eq!(gateway.gateway.name, "local"); + assert_eq!(gateway.gateway.listeners.len(), 1); + assert_eq!(gateway.gateway.listeners[0].port, 9000); + assert!(gateway.routes.contains_key("root")); +} diff --git a/crates/plugin-wasm/Cargo.toml b/crates/plugin-wasm/Cargo.toml new file mode 100644 index 00000000..be1cf7b3 --- /dev/null +++ b/crates/plugin-wasm/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "spacegate-plugin-wasm" +version.workspace = true +authors.workspace = true +description = "Proxy-Wasm host integration for SpaceGate (wasmtime)" +edition.workspace = true +license.workspace = true +repository.workspace = true +readme = "../../../README.md" + +[lib] +name = "spacegate_plugin_wasm" +path = "src/lib.rs" + +[dependencies] +spacegate-plugin = { workspace = true } +spacegate-kernel = { workspace = true } +spacegate-model = { workspace = true } + +wasmtime = { version = "23", default-features = true, features = ["async", "cranelift"] } + +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +serde_yaml = "0.9" + +tracing = { workspace = true } +opentelemetry = { workspace = true } +thiserror = "1" +once_cell = "1.19" +sha2 = "0.10" + +# host fn: dispatch_http_call 走 reqwest 异步客户端(用 0.12 与 spacegate 的 http=1 对齐) +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] } +# 异步驱动 + 待回填 dispatch 状态 +tokio = { workspace = true, features = ["sync", "macros", "rt", "time"] } +# inner.call 拿到的是 hyper Response,host fn 操作 header 时要用 http 类型 +hyper = { workspace = true } +http = "1" +http-body-util = { workspace = true } +bytes = { workspace = true } + +# moka 同步缓存模块(Module 编译产物按 url 缓存) +moka = { version = "0.12", features = ["sync"] } + +# WASI random_get:用 OS 级 RNG(getrandom 是 rand 的底层,体积小) +getrandom = "0.2" + +[dev-dependencies] +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "net", "time"] } +# 集成测试 mock HTTP server / inner.call mock service 需要 hyper 1 + hyper-util。 +hyper-util = { workspace = true, features = ["tokio"] } +http-body-util = { workspace = true } diff --git a/crates/plugin-wasm/src/abi.rs b/crates/plugin-wasm/src/abi.rs new file mode 100644 index 00000000..5c94adf1 --- /dev/null +++ b/crates/plugin-wasm/src/abi.rs @@ -0,0 +1,393 @@ +//! proxy-wasm ABI v0.2.1 的基础类型与内存/编码工具。 +//! +//! 主要分三块: +//! 1. `Status` / `Action` / `MapType` / `BufferType` / `StreamType` / `MetricType` / `PeerType` / +//! `LogLevel` 枚举(按 spec 1:1 完整覆盖) +//! 2. `MemoryHelper`:通过 `wasmtime::Memory` 安全读写 guest 线性内存 +//! 3. `pairs`:proxy-wasm 头部 (k, v) 列表的二进制布局编解码 +//! +//! 所有越界访问统一转 `WasmHostError::MemoryOob`,避免 trap 撕裂 Store。 + +use crate::error::WasmHostError; +use wasmtime::{Caller, Memory, StoreContext, StoreContextMut}; + +// ───────────────────────────────────────────────────────── +// 枚举:proxy-wasm v0.2.1 spec §Types +// ───────────────────────────────────────────────────────── + +/// `proxy_status_t`:所有 host fn 的返回值(spec 完整 10 个值)。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +pub enum Status { + Ok = 0, + NotFound = 1, + BadArgument = 2, + SerializationFailure = 3, + ParseFailure = 4, + InvalidMemoryAccess = 6, + Empty = 7, + CasMismatch = 8, + InternalFailure = 10, + Unimplemented = 12, +} + +impl Status { + #[inline] + pub fn as_i32(self) -> i32 { + self as i32 + } +} + +/// `proxy_action_t`:guest 钩子返回。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u32)] +pub enum Action { + Continue = 0, + Pause = 1, +} + +impl Action { + pub fn from_u32(v: u32) -> Self { + match v { + 1 => Action::Pause, + _ => Action::Continue, + } + } +} + +/// `proxy_map_type_t`:头部映射的来源(spec §Types 完整 8 个值)。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MapType { + HttpRequestHeaders = 0, + HttpRequestTrailers = 1, + HttpResponseHeaders = 2, + HttpResponseTrailers = 3, + GrpcCallInitialMetadata = 4, + GrpcCallTrailingMetadata = 5, + HttpCallResponseHeaders = 6, + HttpCallResponseTrailers = 7, +} + +impl MapType { + pub fn from_i32(v: i32) -> Option { + Some(match v { + 0 => MapType::HttpRequestHeaders, + 1 => MapType::HttpRequestTrailers, + 2 => MapType::HttpResponseHeaders, + 3 => MapType::HttpResponseTrailers, + 4 => MapType::GrpcCallInitialMetadata, + 5 => MapType::GrpcCallTrailingMetadata, + 6 => MapType::HttpCallResponseHeaders, + 7 => MapType::HttpCallResponseTrailers, + _ => return None, + }) + } +} + +/// `proxy_buffer_type_t`:缓冲区来源(spec §Types 完整 9 个值)。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BufferType { + HttpRequestBody = 0, + HttpResponseBody = 1, + DownstreamData = 2, + UpstreamData = 3, + HttpCallResponseBody = 4, + GrpcCallMessage = 5, + VmConfiguration = 6, + PluginConfiguration = 7, + ForeignFunctionArguments = 8, +} + +impl BufferType { + pub fn from_i32(v: i32) -> Option { + Some(match v { + 0 => BufferType::HttpRequestBody, + 1 => BufferType::HttpResponseBody, + 2 => BufferType::DownstreamData, + 3 => BufferType::UpstreamData, + 4 => BufferType::HttpCallResponseBody, + 5 => BufferType::GrpcCallMessage, + 6 => BufferType::VmConfiguration, + 7 => BufferType::PluginConfiguration, + 8 => BufferType::ForeignFunctionArguments, + _ => return None, + }) + } +} + +/// `proxy_stream_type_t`:`proxy_continue_stream` / `proxy_close_stream` 参数。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamType { + HttpRequest = 0, + HttpResponse = 1, + Downstream = 2, + Upstream = 3, +} + +impl StreamType { + pub fn from_i32(v: i32) -> Option { + Some(match v { + 0 => StreamType::HttpRequest, + 1 => StreamType::HttpResponse, + 2 => StreamType::Downstream, + 3 => StreamType::Upstream, + _ => return None, + }) + } +} + +/// `proxy_metric_type_t`。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MetricType { + Counter = 0, + Gauge = 1, + Histogram = 2, +} + +impl MetricType { + pub fn from_i32(v: i32) -> Option { + Some(match v { + 0 => MetricType::Counter, + 1 => MetricType::Gauge, + 2 => MetricType::Histogram, + _ => return None, + }) + } +} + +/// `proxy_peer_type_t`(TCP 用,暂不调用但保留类型)。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(dead_code)] +pub enum PeerType { + Unknown = 0, + Local = 1, + Remote = 2, +} + +/// `proxy_log_level_t`。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +pub enum LogLevel { + Trace = 0, + Debug = 1, + Info = 2, + Warn = 3, + Error = 4, + Critical = 5, +} + +impl LogLevel { + pub fn as_i32(self) -> i32 { + self as i32 + } +} + +/// `proxy_log` 的 level(tracing 转换用)。 +pub fn log_level_to_tracing(level: i32) -> Option { + Some(match level { + 0 => tracing::Level::TRACE, + 1 => tracing::Level::DEBUG, + 2 => tracing::Level::INFO, + 3 => tracing::Level::WARN, + 4 | 5 => tracing::Level::ERROR, + _ => return None, + }) +} + +/// host tracing 最大级别 → proxy_log_level_t(用于 `proxy_get_log_level`)。 +pub fn host_max_log_level() -> LogLevel { + if tracing::enabled!(tracing::Level::TRACE) { + LogLevel::Trace + } else if tracing::enabled!(tracing::Level::DEBUG) { + LogLevel::Debug + } else if tracing::enabled!(tracing::Level::INFO) { + LogLevel::Info + } else if tracing::enabled!(tracing::Level::WARN) { + LogLevel::Warn + } else { + LogLevel::Error + } +} + +// ───────────────────────────────────────────────────────── +// WASI 常量子集 +// ───────────────────────────────────────────────────────── + +/// `wasi_errno_t`(spec §Types 中的子集)。 +pub mod wasi_errno { + pub const SUCCESS: i32 = 0; + pub const BADF: i32 = 8; + pub const FAULT: i32 = 21; + #[allow(dead_code)] + pub const INVAL: i32 = 28; + #[allow(dead_code)] + pub const NOTSUP: i32 = 58; +} + +/// `wasi_fd_id_t`:stdout / stderr。 +pub mod wasi_fd { + pub const STDOUT: i32 = 1; + pub const STDERR: i32 = 2; +} + +// ───────────────────────────────────────────────────────── +// MemoryHelper:guest 内存读写(按 host fn 单次调用的生命周期使用) +// ───────────────────────────────────────────────────────── + +pub struct MemoryHelper { + memory: Memory, +} + +impl MemoryHelper { + pub fn new(memory: Memory) -> Self { + Self { memory } + } + + /// 从 caller 中拿到 `memory` export 的 helper(在每个 host fn 起始处调用)。 + pub fn from_caller(caller: &mut Caller<'_, T>) -> Result { + let Some(mem) = caller.get_export("memory").and_then(|e| e.into_memory()) else { + return Err(WasmHostError::AbiViolation("guest module has no `memory` export".to_string())); + }; + Ok(Self { memory: mem }) + } + + /// 读取 guest 线性内存 `[ptr, ptr+len)` 的字节切片。 + pub fn read_bytes(&self, store: StoreContext<'_, T>, ptr: u32, len: u32) -> Result, WasmHostError> { + let data = self.memory.data(&store); + let start = ptr as usize; + let end = start.saturating_add(len as usize); + if end > data.len() { + return Err(WasmHostError::MemoryOob { ptr, len }); + } + Ok(data[start..end].to_vec()) + } + + /// 读 UTF-8 字符串;非法 UTF-8 用 lossy 转换,不报错。 + pub fn read_string_lossy(&self, store: StoreContext<'_, T>, ptr: u32, len: u32) -> Result { + let bytes = self.read_bytes(store, ptr, len)?; + Ok(String::from_utf8_lossy(&bytes).into_owned()) + } + + /// 把 host 数据写入 guest 已经分配好的 `ptr` 处。 + pub fn write_bytes(&self, mut store: StoreContextMut<'_, T>, ptr: u32, data: &[u8]) -> Result<(), WasmHostError> { + let mem = self.memory.data_mut(&mut store); + let start = ptr as usize; + let end = start.saturating_add(data.len()); + if end > mem.len() { + return Err(WasmHostError::MemoryOob { ptr, len: data.len() as u32 }); + } + mem[start..end].copy_from_slice(data); + Ok(()) + } + + /// 读 little-endian u32。 + pub fn read_u32(&self, store: StoreContext<'_, T>, ptr: u32) -> Result { + let bytes = self.read_bytes(store, ptr, 4)?; + let arr: [u8; 4] = bytes.as_slice().try_into().map_err(|_| WasmHostError::MemoryOob { ptr, len: 4 })?; + Ok(u32::from_le_bytes(arr)) + } + + /// 写入一个 little-endian u32 到 guest 内存。 + pub fn write_u32(&self, store: StoreContextMut<'_, T>, ptr: u32, value: u32) -> Result<(), WasmHostError> { + self.write_bytes(store, ptr, &value.to_le_bytes()) + } + + /// 写入一个 little-endian u64 到 guest 内存。 + pub fn write_u64(&self, store: StoreContextMut<'_, T>, ptr: u32, value: u64) -> Result<(), WasmHostError> { + self.write_bytes(store, ptr, &value.to_le_bytes()) + } +} + +// ───────────────────────────────────────────────────────── +// header / call pairs 的二进制布局编解码 +// ───────────────────────────────────────────────────────── +// +// proxy-wasm header pairs 序列化结构(little-endian): +// ``` +// u32 count +// repeat count: u32 key_size, u32 value_size +// repeat count: key_bytes, \0, value_bytes, \0 +// ``` +// `\0` 是为 C 互操作而保留的尾字节;rust 解码端会忽略它。 +// 编码侧也按规范追加 `\0`。 + +pub fn encode_pairs(pairs: &[(&[u8], &[u8])]) -> Vec { + let count = pairs.len() as u32; + let mut cap: usize = 4 + pairs.len() * 8; + for (k, v) in pairs { + cap += k.len() + 1 + v.len() + 1; + } + let mut out = Vec::with_capacity(cap); + out.extend_from_slice(&count.to_le_bytes()); + for (k, v) in pairs { + out.extend_from_slice(&(k.len() as u32).to_le_bytes()); + out.extend_from_slice(&(v.len() as u32).to_le_bytes()); + } + for (k, v) in pairs { + out.extend_from_slice(k); + out.push(0); + out.extend_from_slice(v); + out.push(0); + } + out +} + +/// 解码 `proxy_set_header_map_pairs` 写入的字节流为 (key, value) 列表。 +/// +/// 严格按编码格式校验长度;不合法直接返回 `None`,由 host 端转 BadArgument。 +/// 空 map 允许两种编码:空 buf(`size=0`)或单 `0x00` 字节(spec §Serialization)。 +pub fn decode_pairs(bytes: &[u8]) -> Option, Vec)>> { + if bytes.is_empty() { + return Some(Vec::new()); + } + if bytes == [0u8] { + return Some(Vec::new()); + } + if bytes.len() < 4 { + return None; + } + let mut pos = 0; + let count = u32_from_slice(bytes, pos)? as usize; + pos += 4; + if bytes.len() < 4 + count * 8 { + return None; + } + let mut sizes = Vec::with_capacity(count); + for _ in 0..count { + let k = u32_from_slice(bytes, pos)? as usize; + pos += 4; + let v = u32_from_slice(bytes, pos)? as usize; + pos += 4; + sizes.push((k, v)); + } + let mut out = Vec::with_capacity(count); + for (ks, vs) in sizes { + if pos + ks + 1 + vs + 1 > bytes.len() { + return None; + } + let key = bytes[pos..pos + ks].to_vec(); + pos += ks + 1; + let val = bytes[pos..pos + vs].to_vec(); + pos += vs + 1; + out.push((key, val)); + } + Some(out) +} + +#[inline] +fn u32_from_slice(bytes: &[u8], pos: usize) -> Option { + let s = bytes.get(pos..pos + 4)?; + let arr: [u8; 4] = s.try_into().ok()?; + Some(u32::from_le_bytes(arr)) +} + +/// 把 property path(`\0` 分割的多段字节流)拆成 segments。 +/// +/// spec §Serialization: "Host implementations should tolerate a NULL character at the end". +pub fn decode_property_path(bytes: &[u8]) -> Vec<&[u8]> { + let trimmed = if bytes.last().copied() == Some(0) { &bytes[..bytes.len() - 1] } else { bytes }; + if trimmed.is_empty() { + return Vec::new(); + } + trimmed.split(|b| *b == 0u8).collect() +} diff --git a/crates/plugin-wasm/src/config.rs b/crates/plugin-wasm/src/config.rs new file mode 100644 index 00000000..1bedec50 --- /dev/null +++ b/crates/plugin-wasm/src/config.rs @@ -0,0 +1,206 @@ +//! `WasmPluginShell` 的 JSON spec(与演进文档 §5 对齐)。 + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum FailStrategy { + #[default] + FailOpen, + FailClose, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct WasmLimits { + /// 单个 VM 的线性内存页数上限(1 page = 64KiB)。 + #[serde(default)] + pub max_memory_pages: Option, + /// 每次 guest hook 调用前补充的 fuel;默认不配置时使用近似无限预算。 + #[serde(default)] + pub fuel_per_call: Option, + /// 每次 guest hook 的 epoch 超时窗口,单位毫秒;依赖 host 的 1ms epoch ticker。 + #[serde(default)] + pub epoch_timeout_millis: Option, + /// host 需要物化 body 时允许的最大字节数,覆盖请求 body、响应 body、dispatch 请求/响应 body。 + #[serde(default)] + pub max_body_bytes: Option, + /// 单个 VM 同时允许的未完成 `proxy_http_call` 数量。 + #[serde(default)] + pub max_pending_calls: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct OciAuthConfig { + /// Optional registry hint, for example `registry.cn-hangzhou.aliyuncs.com`. + #[serde(default)] + pub registry: Option, + /// Basic-auth username used for registry token exchange or direct registry auth. + #[serde(default)] + pub username: Option, + /// Basic-auth password used for registry token exchange or direct registry auth. + #[serde(default)] + pub password: Option, + /// Pre-issued bearer token for registries that do not need a token challenge exchange. + #[serde(default)] + pub bearer_token: Option, + /// Docker config `identitytoken`; treated as a bearer token by the registry client. + #[serde(default)] + pub identity_token: Option, +} + +fn default_use_cache() -> bool { + true +} + +fn default_vm_pool_size() -> usize { + 1 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct WasmPluginShellConfig { + /// `file://`、`http(s)://`、OCI 镜像 URL 或本地路径。 + pub url: String, + /// Optional OCI registry auth. Usually populated from Higress `imagePullSecret`. + #[serde(default)] + pub oci_auth: Option, + /// 可选 SHA-256 校验值,支持裸 hex 或 `sha256:`。 + /// + /// 配置该字段后,host 会在编译前校验拉取到的 wasm 字节;字段变化也会自动让模块缓存失效。 + #[serde(default)] + pub sha256: Option, + /// 可选模块缓存键。 + /// + /// 默认按 `url` 加 `sha256` 复用编译产物;当远端同 URL 发布新版本且未配置 sha256 时, + /// 可以把这里设置成版本号/etag/digest 来强制重新拉取并编译。 + #[serde(default)] + pub module_cache_key: Option, + /// 是否复用进程内 wasm Module 缓存。 + /// + /// 默认开启;关闭后每次创建/更新插件实例都会重新拉取并编译,适合开发调试。 + #[serde(default = "default_use_cache")] + pub use_cache: bool, + /// 传给 guest `proxy_on_configure` 的配置:可为 JSON 对象;序列化为 YAML 字节给 hai 系插件。 + #[serde(default)] + pub plugin_config: serde_json::Value, + #[serde(default)] + pub fail_strategy: FailStrategy, + /// `dispatch_http_call` 时 guest 传入的 cluster 名 → 真实 HTTP base URL。 + /// + /// 兼容 hai 的 Higress cluster 写法 `outbound|||.`: + /// 若直接命中则用配置 base,否则 host 会回退用 `:authority` header(hai 已带)发起请求。 + #[serde(default)] + pub clusters: HashMap, + #[serde(default)] + pub limits: WasmLimits, + /// 创建时是否尝试用占位 linker 实例化一次(尽早发现链接错误)。当前实现已弃用,保留兼容字段。 + #[serde(default = "default_validate")] + pub validate_on_create: bool, + /// 暴露给 guest 的 `plugin_name` well-known property(spec §Properties §Proxy-Wasm properties)。 + #[serde(default)] + pub plugin_name: String, + /// 暴露给 guest 的 `plugin_root_id` well-known property。 + #[serde(default)] + pub plugin_root_id: String, + /// 暴露给 guest 的 `plugin_vm_id` well-known property;同时用于 `proxy_resolve_shared_queue`。 + #[serde(default = "default_vm_id")] + pub plugin_vm_id: String, + /// 同一个 wasm 插件实例内创建的 VM 数量。 + /// + /// 默认 1,保持单 VM 串行语义;设置为大于 1 后,多个独立 VM 共享同一个已编译 Module, + /// 请求按 try-lock + round-robin 分发,用于降低长时间 `dispatch_http_call` 对后续请求的阻塞。 + #[serde(default = "default_vm_pool_size")] + pub vm_pool_size: usize, + /// wait 策略专用 VM 池大小。 + /// + /// 默认 0,表示不启用分类调度,所有请求都进入普通 VM 池。设置为大于 0 后, + /// 带 `X-RateLimit-Policy: wait` 的请求会进入独立 wait 池,避免长等待请求占满普通池。 + #[serde(default)] + pub wait_vm_pool_size: usize, +} + +fn default_vm_id() -> String { + "default".to_string() +} + +fn default_validate() -> bool { + false +} + +impl Default for WasmPluginShellConfig { + fn default() -> Self { + Self { + url: String::new(), + oci_auth: None, + sha256: None, + module_cache_key: None, + use_cache: default_use_cache(), + plugin_config: serde_json::Value::Null, + fail_strategy: FailStrategy::FailOpen, + clusters: HashMap::new(), + limits: WasmLimits::default(), + validate_on_create: false, + plugin_name: String::new(), + plugin_root_id: String::new(), + plugin_vm_id: default_vm_id(), + vm_pool_size: default_vm_pool_size(), + wait_vm_pool_size: 0, + } + } +} + +impl WasmPluginShellConfig { + pub fn normalized_vm_pool_size(&self) -> usize { + self.vm_pool_size.clamp(1, 64) + } + + pub fn normalized_wait_vm_pool_size(&self) -> usize { + self.wait_vm_pool_size.min(64) + } + + pub fn max_memory_bytes(&self) -> Option { + self.limits.max_memory_pages.map(|pages| pages as usize * 64 * 1024) + } + + pub fn guest_fuel_per_call(&self) -> u64 { + self.limits.fuel_per_call.unwrap_or(u64::MAX / 4).max(1) + } + + pub fn guest_epoch_deadline_ticks(&self) -> u64 { + // epoch ticker 以 1ms 为一跳;默认给一个很大的窗口,相当于不主动超时。 + self.limits.epoch_timeout_millis.unwrap_or(24 * 60 * 60 * 1000).clamp(1, 24 * 60 * 60 * 1000) + } + + /// 把 `plugin_config`(任意 JSON)转换为 hai 风格 YAML 字节流。 + /// + /// hai-process-mix 在 `on_configure` 内是 `serde_yaml::from_slice::(&bytes)`, + /// 所以无论上层用 JSON 还是 YAML 写,传给 guest 的都必须是 YAML 序列化结果。 + pub fn configuration_bytes(&self) -> Vec { + if self.plugin_config.is_null() { + return Vec::new(); + } + serde_yaml::to_string(&self.plugin_config).unwrap_or_default().into_bytes() + } + + /// 给定 guest 传来的 cluster 字符串,返回基础 URL(`http://host:port`)。 + /// + /// 优先精确匹配配置 map;其次尝试解析 Envoy/Higress 习惯写法 + /// `outbound|||` -> `http://:`; + /// 都不命中返回 `None`。 + pub fn resolve_cluster(&self, cluster: &str) -> Option { + if let Some(v) = self.clusters.get(cluster) { + return Some(v.clone()); + } + if let Some(rest) = cluster.strip_prefix("outbound|") { + let mut parts = rest.splitn(2, "||"); + let port = parts.next()?.trim(); + let host = parts.next()?.trim(); + if host.is_empty() || port.is_empty() { + return None; + } + return Some(format!("http://{host}:{port}")); + } + None + } +} diff --git a/crates/plugin-wasm/src/engine.rs b/crates/plugin-wasm/src/engine.rs new file mode 100644 index 00000000..0d5973cb --- /dev/null +++ b/crates/plugin-wasm/src/engine.rs @@ -0,0 +1,48 @@ +//! 共享 `wasmtime::Engine`:同进程内所有 wasm 插件实例共用。 +//! +//! **同步模式**:host fn 是 sync,故不能开 `async_support`——否则 host fn 内 +//! 调 guest 的 `proxy_on_memory_allocate` 会 panic「must use `call_async` with async stores」。 +//! `proxy_http_call` 的异步语义通过 `tokio::spawn` + mpsc channel 实现, +//! 不需要把整个 store 切到 async。 +//! +use once_cell::sync::OnceCell; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; +use wasmtime::{Config, Engine}; + +static ENGINE: OnceCell = OnceCell::new(); +static EPOCH_TICKER_STARTED: AtomicBool = AtomicBool::new(false); + +/// 进程级单例 Engine(multi-memory 开,async 关)。 +pub fn shared_engine() -> &'static Engine { + ENGINE.get_or_init(|| { + let mut cfg = Config::new(); + cfg.wasm_multi_memory(true); + cfg.consume_fuel(true); + cfg.epoch_interruption(true); + cfg.async_support(false); + Engine::new(&cfg).expect("wasmtime Engine::new") + }) +} + +/// 启动一个进程级 epoch ticker。每 1ms 递增一次 Engine epoch,配合 Store epoch deadline +/// 给同步 guest hook 提供粗粒度墙钟超时保护。 +pub fn ensure_epoch_ticker_started() { + if EPOCH_TICKER_STARTED.load(Ordering::Acquire) { + return; + } + let Ok(handle) = tokio::runtime::Handle::try_current() else { + return; + }; + if EPOCH_TICKER_STARTED.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire).is_err() { + return; + } + let engine = shared_engine().clone(); + handle.spawn(async move { + let mut interval = tokio::time::interval(Duration::from_millis(1)); + loop { + interval.tick().await; + engine.increment_epoch(); + } + }); +} diff --git a/crates/plugin-wasm/src/error.rs b/crates/plugin-wasm/src/error.rs new file mode 100644 index 00000000..ae6303ea --- /dev/null +++ b/crates/plugin-wasm/src/error.rs @@ -0,0 +1,31 @@ +//! WASM 插件宿主侧错误类型(实现 `std::error::Error`,可自动装箱为 `BoxError`)。 + +#[derive(Debug, thiserror::Error)] +pub enum WasmHostError { + #[error("fetch wasm: {0}")] + Fetch(String), + #[error("wasmtime: {0}")] + Wasmtime(#[from] wasmtime::Error), + #[error("instantiation failed: {0}")] + Instantiate(String), + #[error("guest abi violation: {0}")] + AbiViolation(String), + #[error("memory oob: ptr={ptr} len={len}")] + MemoryOob { ptr: u32, len: u32 }, + #[error("wasm guest trap during {hook}: {source}")] + GuestTrap { hook: &'static str, source: wasmtime::Error }, + #[error("dispatch_http_call: {0}")] + Dispatch(String), + #[error("body too large: {actual} bytes exceeds limit {limit} bytes")] + BodyTooLarge { actual: usize, limit: usize }, + #[error("resource limit: {0}")] + ResourceLimit(String), + #[error("config: {0}")] + Config(String), +} + +impl WasmHostError { + pub fn requires_vm_rebuild(&self) -> bool { + matches!(self, Self::GuestTrap { .. } | Self::Wasmtime(_) | Self::Dispatch(_) | Self::ResourceLimit(_)) + } +} diff --git a/crates/plugin-wasm/src/fetch.rs b/crates/plugin-wasm/src/fetch.rs new file mode 100644 index 00000000..d3519046 --- /dev/null +++ b/crates/plugin-wasm/src/fetch.rs @@ -0,0 +1,365 @@ +//! 同步拉取 WASM 字节(在 `Plugin::create` 同步上下文中使用)。 +//! +//! 支持:`file://...`、裸文件系统路径、`http(s)://...` 与 OCI 镜像 URL。 +//! 网络拉取通过临时线程运行 async reqwest,避免在 `Plugin::create` 这条同步路径里嵌套 tokio runtime。 + +use crate::config::OciAuthConfig; +use crate::error::WasmHostError; +use reqwest::header::{ACCEPT, WWW_AUTHENTICATE}; +use serde::Deserialize; +use std::{collections::HashMap, time::Duration}; + +const OCI_MANIFEST_ACCEPT: &str = "application/vnd.oci.image.manifest.v1+json, application/vnd.docker.distribution.manifest.v2+json, application/vnd.oci.image.index.v1+json, application/vnd.docker.distribution.manifest.list.v2+json, application/vnd.oci.artifact.manifest.v1+json"; +const OCI_BLOB_ACCEPT: &str = "application/vnd.module.wasm.content.layer.v1+wasm, application/wasm, application/vnd.wasm.content.layer.v1+wasm, application/octet-stream"; + +fn fetch_http_wasm_bytes_sync(url: &str) -> Result, WasmHostError> { + let url = url.to_string(); + std::thread::Builder::new() + .name("spacegate-wasm-fetch".to_string()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().map_err(|e| WasmHostError::Fetch(format!("build fetch runtime: {e}")))?; + rt.block_on(async move { + let client = reqwest::Client::builder().timeout(Duration::from_secs(30)).build().map_err(|e| WasmHostError::Fetch(format!("build http client: {e}")))?; + let resp = client + .get(&url) + .send() + .await + .map_err(|e| WasmHostError::Fetch(format!("GET {url}: {e}")))? + .error_for_status() + .map_err(|e| WasmHostError::Fetch(format!("GET {url}: {e}")))?; + let bytes = resp.bytes().await.map_err(|e| WasmHostError::Fetch(format!("read {url} body: {e}")))?; + Ok(bytes.to_vec()) + }) + }) + .map_err(|e| WasmHostError::Fetch(format!("spawn fetch thread: {e}")))? + .join() + .map_err(|_| WasmHostError::Fetch("fetch thread panicked".to_string()))? +} + +fn fetch_oci_wasm_bytes_sync(url: &str, auth: Option) -> Result, WasmHostError> { + let url = url.to_string(); + std::thread::Builder::new() + .name("spacegate-wasm-oci-fetch".to_string()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().map_err(|e| WasmHostError::Fetch(format!("build OCI fetch runtime: {e}")))?; + rt.block_on(async move { + let reference = OciReference::parse(&url)?; + if let Some(auth_registry) = auth.as_ref().and_then(|a| a.registry.as_deref()).filter(|v| !v.trim().is_empty()) { + if !auth_registry.eq_ignore_ascii_case(&reference.registry) { + return Err(WasmHostError::Fetch(format!( + "OCI auth registry `{auth_registry}` does not match image registry `{}`", + reference.registry + ))); + } + } + let client = reqwest::Client::builder().timeout(Duration::from_secs(60)).build().map_err(|e| WasmHostError::Fetch(format!("build OCI client: {e}")))?; + fetch_oci_wasm_bytes(&client, &reference, auth.as_ref()).await + }) + }) + .map_err(|e| WasmHostError::Fetch(format!("spawn OCI fetch thread: {e}")))? + .join() + .map_err(|_| WasmHostError::Fetch("OCI fetch thread panicked".to_string()))? +} + +pub fn fetch_wasm_bytes_sync(url_or_path: &str) -> Result, WasmHostError> { + fetch_wasm_bytes_sync_with_auth(url_or_path, None) +} + +pub fn fetch_wasm_bytes_sync_with_auth(url_or_path: &str, oci_auth: Option<&OciAuthConfig>) -> Result, WasmHostError> { + let trim = url_or_path.trim(); + if let Some(rest) = trim.strip_prefix("file://") { + return std::fs::read(rest).map_err(|e| WasmHostError::Fetch(format!("read file {rest}: {e}"))); + } + if trim.starts_with("http://") || trim.starts_with("https://") { + return fetch_http_wasm_bytes_sync(trim); + } + if is_oci_url(trim) { + return fetch_oci_wasm_bytes_sync(trim, oci_auth.cloned()); + } + std::fs::read(trim).map_err(|e| WasmHostError::Fetch(format!("read path {trim}: {e}"))) +} + +pub fn is_oci_url(url: &str) -> bool { + let lower = url.to_ascii_lowercase(); + lower.starts_with("oci://") || lower.starts_with("docker://") || lower.starts_with("image://") || lower.starts_with("oci+http://") +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct OciReference { + scheme: &'static str, + registry: String, + repository: String, + reference: String, +} + +impl OciReference { + fn parse(url: &str) -> Result { + let trim = url.trim(); + let (scheme, rest) = if let Some(rest) = trim.strip_prefix("oci+http://") { + ("http", rest) + } else if let Some(rest) = trim.strip_prefix("oci://") { + (default_oci_scheme(rest), rest) + } else if let Some(rest) = trim.strip_prefix("docker://") { + (default_oci_scheme(rest), rest) + } else if let Some(rest) = trim.strip_prefix("image://") { + (default_oci_scheme(rest), rest) + } else { + return Err(WasmHostError::Fetch(format!("unsupported OCI URL scheme: {trim}"))); + }; + + let Some((registry, image)) = rest.split_once('/') else { + return Err(WasmHostError::Fetch(format!("OCI URL must include registry and repository: {trim}"))); + }; + if registry.trim().is_empty() || image.trim().is_empty() { + return Err(WasmHostError::Fetch(format!("OCI URL must include registry and repository: {trim}"))); + } + + let (repository, reference) = if let Some((repository, digest)) = image.rsplit_once('@') { + (repository, digest) + } else if let Some((repository, tag)) = split_tag(image) { + (repository, tag) + } else { + (image, "latest") + }; + if repository.trim().is_empty() || reference.trim().is_empty() { + return Err(WasmHostError::Fetch(format!("OCI URL must include repository and tag/digest: {trim}"))); + } + + Ok(Self { + scheme, + registry: registry.to_string(), + repository: repository.to_string(), + reference: reference.to_string(), + }) + } + + fn manifest_url(&self, reference: &str) -> String { + format!("{}://{}/v2/{}/manifests/{}", self.scheme, self.registry, self.repository, reference) + } + + fn blob_url(&self, digest: &str) -> String { + format!("{}://{}/v2/{}/blobs/{}", self.scheme, self.registry, self.repository, digest) + } +} + +fn default_oci_scheme(rest: &str) -> &'static str { + let registry = rest.split('/').next().unwrap_or_default(); + if registry.starts_with("localhost") || registry.starts_with("127.0.0.1") || registry.starts_with("[::1]") { + "http" + } else { + "https" + } +} + +fn split_tag(image: &str) -> Option<(&str, &str)> { + let slash = image.rfind('/').map(|idx| idx + 1).unwrap_or(0); + let colon = image[slash..].rfind(':').map(|idx| slash + idx)?; + Some((&image[..colon], &image[colon + 1..])) +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct OciManifest { + #[serde(default)] + media_type: Option, + #[serde(default)] + manifests: Vec, + #[serde(default)] + layers: Vec, + #[serde(default)] + blobs: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct OciDescriptor { + media_type: String, + digest: String, + #[serde(default)] + platform: Option, +} + +#[derive(Debug, Deserialize)] +struct OciPlatform { + #[serde(default)] + architecture: Option, + #[serde(default)] + os: Option, +} + +async fn fetch_oci_wasm_bytes(client: &reqwest::Client, reference: &OciReference, auth: Option<&OciAuthConfig>) -> Result, WasmHostError> { + let manifest = fetch_oci_manifest(client, reference, &reference.reference, auth).await?; + let manifest = if is_index_manifest(&manifest) { + let child = select_manifest_descriptor(&manifest.manifests)?; + fetch_oci_manifest(client, reference, &child.digest, auth).await? + } else { + manifest + }; + let layer = select_wasm_descriptor(&manifest)?; + registry_get_bytes(client, &reference.blob_url(&layer.digest), OCI_BLOB_ACCEPT, reference, auth).await +} + +async fn fetch_oci_manifest(client: &reqwest::Client, reference: &OciReference, manifest_ref: &str, auth: Option<&OciAuthConfig>) -> Result { + let url = reference.manifest_url(manifest_ref); + let bytes = registry_get_bytes(client, &url, OCI_MANIFEST_ACCEPT, reference, auth).await?; + serde_json::from_slice(&bytes).map_err(|e| WasmHostError::Fetch(format!("parse OCI manifest {url}: {e}"))) +} + +async fn registry_get_bytes(client: &reqwest::Client, url: &str, accept: &str, reference: &OciReference, auth: Option<&OciAuthConfig>) -> Result, WasmHostError> { + let send = |token: Option<&str>| { + let req = client.get(url).header(ACCEPT, accept); + apply_registry_auth(req, auth, token) + }; + + let resp = send(None).send().await.map_err(|e| WasmHostError::Fetch(format!("GET {url}: {e}")))?; + let resp = if resp.status() == reqwest::StatusCode::UNAUTHORIZED { + let challenge = resp.headers().get(WWW_AUTHENTICATE).and_then(|v| v.to_str().ok()).unwrap_or_default(); + let token = fetch_bearer_token(client, challenge, reference, auth).await?; + send(Some(&token)).send().await.map_err(|e| WasmHostError::Fetch(format!("GET {url}: {e}")))? + } else { + resp + }; + + let status = resp.status(); + if !status.is_success() { + return Err(WasmHostError::Fetch(format!("GET {url}: {status}"))); + } + let bytes = resp.bytes().await.map_err(|e| WasmHostError::Fetch(format!("read OCI response {url}: {e}")))?; + Ok(bytes.to_vec()) +} + +fn apply_registry_auth(req: reqwest::RequestBuilder, auth: Option<&OciAuthConfig>, bearer_token: Option<&str>) -> reqwest::RequestBuilder { + if let Some(token) = bearer_token { + return req.bearer_auth(token); + } + let Some(auth) = auth else { + return req; + }; + if let Some(token) = auth.bearer_token.as_deref().or(auth.identity_token.as_deref()).filter(|v| !v.trim().is_empty()) { + return req.bearer_auth(token); + } + if let Some(username) = auth.username.as_deref().filter(|v| !v.trim().is_empty()) { + return req.basic_auth(username, auth.password.clone()); + } + req +} + +async fn fetch_bearer_token(client: &reqwest::Client, challenge: &str, reference: &OciReference, auth: Option<&OciAuthConfig>) -> Result { + let params = + parse_bearer_challenge(challenge).ok_or_else(|| WasmHostError::Fetch(format!("registry {} requires auth but did not return a Bearer challenge", reference.registry)))?; + let realm = params.get("realm").filter(|v| !v.trim().is_empty()).ok_or_else(|| WasmHostError::Fetch("Bearer auth challenge missing realm".to_string()))?; + let mut url = reqwest::Url::parse(realm).map_err(|e| WasmHostError::Fetch(format!("parse Bearer token realm {realm}: {e}")))?; + { + let mut query = url.query_pairs_mut(); + if let Some(service) = params.get("service").filter(|v| !v.trim().is_empty()) { + query.append_pair("service", service); + } + let scope = params.get("scope").cloned().unwrap_or_else(|| format!("repository:{}:pull", reference.repository)); + query.append_pair("scope", &scope); + } + + let req = apply_registry_auth(client.get(url.clone()), auth, None); + let resp = req.send().await.map_err(|e| WasmHostError::Fetch(format!("GET OCI token {url}: {e}")))?; + let status = resp.status(); + if !status.is_success() { + return Err(WasmHostError::Fetch(format!("GET OCI token {url}: {status}"))); + } + let bytes = resp.bytes().await.map_err(|e| WasmHostError::Fetch(format!("read OCI token {url}: {e}")))?; + let token: OciTokenResponse = serde_json::from_slice(&bytes).map_err(|e| WasmHostError::Fetch(format!("parse OCI token response {url}: {e}")))?; + token.token.or(token.access_token).filter(|v| !v.trim().is_empty()).ok_or_else(|| WasmHostError::Fetch(format!("OCI token response {url} did not include token"))) +} + +#[derive(Debug, Deserialize)] +struct OciTokenResponse { + #[serde(default)] + token: Option, + #[serde(default)] + access_token: Option, +} + +fn parse_bearer_challenge(header: &str) -> Option> { + let rest = header.trim().strip_prefix("Bearer ")?; + let mut params = HashMap::new(); + for part in split_quoted_commas(rest) { + let Some((key, value)) = part.split_once('=') else { + continue; + }; + params.insert(key.trim().to_ascii_lowercase(), value.trim().trim_matches('"').to_string()); + } + Some(params) +} + +fn split_quoted_commas(value: &str) -> Vec<&str> { + let mut parts = Vec::new(); + let mut start = 0; + let mut in_quotes = false; + for (idx, ch) in value.char_indices() { + match ch { + '"' => in_quotes = !in_quotes, + ',' if !in_quotes => { + parts.push(value[start..idx].trim()); + start = idx + 1; + } + _ => {} + } + } + parts.push(value[start..].trim()); + parts +} + +fn is_index_manifest(manifest: &OciManifest) -> bool { + manifest.media_type.as_deref().map(|mt| mt.contains("image.index") || mt.contains("manifest.list")).unwrap_or(false) || !manifest.manifests.is_empty() +} + +fn select_manifest_descriptor(manifests: &[OciDescriptor]) -> Result<&OciDescriptor, WasmHostError> { + manifests + .iter() + .find(|m| m.platform.as_ref().map(|p| p.architecture.as_deref() == Some("wasm") || p.os.as_deref() == Some("wasi")).unwrap_or(false)) + .or_else(|| manifests.first()) + .ok_or_else(|| WasmHostError::Fetch("OCI image index does not contain manifests".to_string())) +} + +fn select_wasm_descriptor(manifest: &OciManifest) -> Result<&OciDescriptor, WasmHostError> { + let descriptors = manifest.layers.iter().chain(manifest.blobs.iter()).collect::>(); + descriptors + .iter() + .copied() + .find(|layer| is_wasm_media_type(&layer.media_type)) + .or_else(|| (descriptors.len() == 1).then(|| descriptors[0])) + .ok_or_else(|| WasmHostError::Fetch("OCI image does not contain a wasm layer".to_string())) +} + +fn is_wasm_media_type(media_type: &str) -> bool { + matches!( + media_type, + "application/vnd.module.wasm.content.layer.v1+wasm" | "application/vnd.wasm.content.layer.v1+wasm" | "application/wasm" + ) || media_type.contains("wasm") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_oci_reference_tag_digest_and_default_tag() { + assert_eq!( + OciReference::parse("oci://registry.example.com/ns/plugin:v1").unwrap(), + OciReference { + scheme: "https", + registry: "registry.example.com".to_string(), + repository: "ns/plugin".to_string(), + reference: "v1".to_string(), + } + ); + assert_eq!(OciReference::parse("docker://localhost:5000/plugin").unwrap().reference, "latest"); + assert_eq!(OciReference::parse("image://registry.example.com/ns/plugin@sha256:abc").unwrap().reference, "sha256:abc"); + } + + #[test] + fn parses_bearer_challenge() { + let parsed = parse_bearer_challenge(r#"Bearer realm="https://auth.example/token",service="registry.example",scope="repository:ns/plugin:pull""#).unwrap(); + assert_eq!(parsed["realm"], "https://auth.example/token"); + assert_eq!(parsed["service"], "registry.example"); + assert_eq!(parsed["scope"], "repository:ns/plugin:pull"); + } +} diff --git a/crates/plugin-wasm/src/host_fn.rs b/crates/plugin-wasm/src/host_fn.rs new file mode 100644 index 00000000..79c494d5 --- /dev/null +++ b/crates/plugin-wasm/src/host_fn.rs @@ -0,0 +1,1351 @@ +//! 把 proxy-wasm v0.2.1 全部 host fn 注册到 `wasmtime::Linker`。 +//! +//! 实现策略: +//! +//! - 全部使用 **同步** `func_wrap`(host 端不需要 await)。 +//! - `proxy_http_call` 是唯一的"异步"——它**同步**返回 token,把真正的 HTTP 调用 `tokio::spawn` +//! 出去,结果通过 `dispatch_tx` 投递回 Vm 状态机;Vm 主循环 await。 +//! - gRPC / 外部函数:进程内不接 gRPC client / FFI 注册表,返回 `Unimplemented` / `NotFound`。 +//! - 命名与 proxy-wasm spec 完全一致;参数按 i32(线性内存偏移/长度均为 i32)。 + +use std::time::Duration; + +use bytes::Bytes; +use http::{HeaderMap, HeaderName, HeaderValue}; +use tracing::{debug, info, warn}; +use wasmtime::{AsContext, AsContextMut, Caller, Linker}; + +use crate::abi::{ + decode_pairs, decode_property_path, encode_pairs, host_max_log_level, log_level_to_tracing, BufferType, LogLevel, MapType, MemoryHelper, MetricType, Status, StreamType, +}; +use crate::host_state::{HostState, HttpCallResult, LocalResponse}; +use crate::shared::{ + metric_define, metric_get, metric_increment, metric_record, queue_dequeue, queue_enqueue, queue_register, queue_resolve, shared_data_get, shared_data_set, MetricOpResult, + QueueOpResult, SharedDataSetResult, +}; + +/// 把所有 proxy-wasm v0.2.1 host fn 注册到 linker。 +/// +/// `dispatch_tx` 用于把异步 HTTP 调用结果发送给 Vm 状态机。 +pub fn register_all(linker: &mut Linker, dispatch_tx: tokio::sync::mpsc::UnboundedSender<(u32, HttpCallResult)>) -> Result<(), wasmtime::Error> { + register_log(linker)?; + register_clock_and_tick(linker)?; + register_context_control(linker)?; + register_stream_control(linker)?; + register_buffer(linker)?; + register_headers(linker)?; + register_status_and_local_response(linker)?; + register_http_call(linker, dispatch_tx)?; + register_shared_data_and_queue(linker)?; + register_metrics(linker)?; + register_property(linker)?; + register_grpc_unimplemented(linker)?; + register_foreign_function(linker)?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Logging(spec §Logging) +// ───────────────────────────────────────────────────────── + +fn register_log(linker: &mut Linker) -> Result<(), wasmtime::Error> { + linker.func_wrap("env", "proxy_log", |mut caller: Caller<'_, HostState>, level: i32, msg_ptr: i32, msg_size: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let Ok(msg) = mem.read_string_lossy(caller.as_context(), msg_ptr as u32, msg_size as u32) else { + return Status::InvalidMemoryAccess.as_i32(); + }; + let Some(lvl) = log_level_to_tracing(level) else { + return Status::BadArgument.as_i32(); + }; + match lvl { + tracing::Level::TRACE => tracing::trace!(target: "spacegate_plugin_wasm::guest", "{msg}"), + tracing::Level::DEBUG => tracing::debug!(target: "spacegate_plugin_wasm::guest", "{msg}"), + tracing::Level::INFO => tracing::info!(target: "spacegate_plugin_wasm::guest", "{msg}"), + tracing::Level::WARN => tracing::warn!(target: "spacegate_plugin_wasm::guest", "{msg}"), + tracing::Level::ERROR => tracing::error!(target: "spacegate_plugin_wasm::guest", "{msg}"), + } + Status::Ok.as_i32() + })?; + + linker.func_wrap("env", "proxy_get_log_level", |mut caller: Caller<'_, HostState>, return_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let lvl: LogLevel = host_max_log_level(); + if mem.write_u32(caller.as_context_mut(), return_ptr as u32, lvl.as_i32() as u32).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + Status::Ok.as_i32() + })?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Clocks / Timers / Context control(spec §Clocks §Timers §Context lifecycle) +// ───────────────────────────────────────────────────────── + +fn register_clock_and_tick(linker: &mut Linker) -> Result<(), wasmtime::Error> { + linker.func_wrap("env", "proxy_get_current_time_nanoseconds", |mut caller: Caller<'_, HostState>, return_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let nanos = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).map(|d| d.as_nanos() as u64).unwrap_or(0); + if mem.write_u64(caller.as_context_mut(), return_ptr as u32, nanos).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + Status::Ok.as_i32() + })?; + + linker.func_wrap("env", "proxy_set_tick_period_milliseconds", |mut caller: Caller<'_, HostState>, period: i32| -> i32 { + caller.data_mut().tick_period_ms = if period > 0 { Some(period as u32) } else { None }; + Status::Ok.as_i32() + })?; + Ok(()) +} + +fn register_context_control(linker: &mut Linker) -> Result<(), wasmtime::Error> { + // proxy_set_effective_context(context_id) -> Status + linker.func_wrap("env", "proxy_set_effective_context", |mut caller: Caller<'_, HostState>, ctx_id: i32| -> i32 { + let cid = ctx_id as u32; + let st = caller.data_mut(); + if st.contexts.contains_key(&cid) || cid == st.root_context_id { + st.effective_context = cid; + Status::Ok.as_i32() + } else { + Status::BadArgument.as_i32() + } + })?; + + // proxy_done() -> Status + // + // spec §proxy_done:guest 在 `proxy_on_done` 返回 false 之后调本 hostcall 表示「确实做完了」。 + // host 据此结束等待,进入 on_log/on_delete(在 vm.rs 处理)。 + linker.func_wrap("env", "proxy_done", |mut caller: Caller<'_, HostState>| -> i32 { + let st = caller.data_mut(); + let cid = st.effective_context; + if let Some(ctx) = st.contexts.get_mut(&cid) { + if !ctx.awaiting_done { + return Status::NotFound.as_i32(); + } + ctx.done_marker = true; + ctx.awaiting_done = false; + Status::Ok.as_i32() + } else { + Status::NotFound.as_i32() + } + })?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Stream control(spec §Common HTTP and TCP stream operations) +// ───────────────────────────────────────────────────────── + +fn register_stream_control(linker: &mut Linker) -> Result<(), wasmtime::Error> { + // proxy_continue_stream(stream_type) -> Status + // + // 我们 host 端仅处理 HTTP_REQUEST/HTTP_RESPONSE 的 continue:把当前 ctx 的 + // continue_requested 置 true,Vm 状态机据此退出 await loop。Downstream/Upstream + // 我们不接 TCP 层 → 返回 UNIMPLEMENTED(spec 允许)。 + linker.func_wrap("env", "proxy_continue_stream", |mut caller: Caller<'_, HostState>, stream_type: i32| -> i32 { + let Some(st_kind) = StreamType::from_i32(stream_type) else { + return Status::BadArgument.as_i32(); + }; + match st_kind { + StreamType::HttpRequest | StreamType::HttpResponse => { + let st = caller.data(); + let ctx_id = st.effective_context; + if let Some(ctx) = caller.data_mut().contexts.get_mut(&ctx_id) { + ctx.continue_requested = true; + } + Status::Ok.as_i32() + } + StreamType::Downstream | StreamType::Upstream => Status::Unimplemented.as_i32(), + } + })?; + + // proxy_close_stream(stream_type) -> Status + linker.func_wrap("env", "proxy_close_stream", |_caller: Caller<'_, HostState>, stream_type: i32| -> i32 { + match StreamType::from_i32(stream_type) { + Some(StreamType::HttpRequest) | Some(StreamType::HttpResponse) => Status::Ok.as_i32(), + Some(StreamType::Downstream) | Some(StreamType::Upstream) => Status::Unimplemented.as_i32(), + None => Status::BadArgument.as_i32(), + } + })?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Buffers(spec §Buffers) +// ───────────────────────────────────────────────────────── + +/// 取 buffer 内容(克隆出一份,避免后续借用冲突)。 +fn read_buffer(state: &HostState, buf_type: BufferType) -> Option> { + match buf_type { + BufferType::PluginConfiguration | BufferType::VmConfiguration => Some(state.configuration.clone()), + BufferType::HttpRequestBody => state.current_context().and_then(|c| c.request_body.as_ref().map(|b| b.to_vec())), + BufferType::HttpResponseBody => state.current_context().and_then(|c| c.response_body.as_ref().map(|b| b.to_vec())), + BufferType::HttpCallResponseBody => state.current_context().map(|c| c.last_call_body.to_vec()), + // 未支持的(TCP / gRPC / FFI args):buffer 类型本身合法,但当前 host 无数据 → NotFound + BufferType::DownstreamData | BufferType::UpstreamData | BufferType::GrpcCallMessage | BufferType::ForeignFunctionArguments => None, + } +} + +fn register_buffer(linker: &mut Linker) -> Result<(), wasmtime::Error> { + // proxy_get_buffer_bytes(buffer_type, start, max_size, *return_data, *return_size) -> Status + linker.func_wrap( + "env", + "proxy_get_buffer_bytes", + |mut caller: Caller<'_, HostState>, buffer_type: i32, start: i32, max_size: i32, return_data_ptr: i32, return_size_ptr: i32| -> i32 { + let Some(buf_type) = BufferType::from_i32(buffer_type) else { + return Status::BadArgument.as_i32(); + }; + let bytes_opt = read_buffer(caller.data(), buf_type); + let Some(bytes) = bytes_opt else { + return Status::NotFound.as_i32(); + }; + let start = (start as u32) as usize; + let max_size = (max_size as u32) as usize; + if start > bytes.len() { + return Status::BadArgument.as_i32(); + } + let end = (start.saturating_add(max_size)).min(bytes.len()); + let slice = &bytes[start..end]; + match write_alloc_pair(&mut caller, slice, return_data_ptr as u32, return_size_ptr as u32) { + Ok(()) => Status::Ok.as_i32(), + Err(s) => s.as_i32(), + } + }, + )?; + + // proxy_get_buffer_status(buffer_type, *return_buffer_size, *return_unused) -> Status + linker.func_wrap( + "env", + "proxy_get_buffer_status", + |mut caller: Caller<'_, HostState>, buffer_type: i32, return_size_ptr: i32, return_unused_ptr: i32| -> i32 { + let Some(buf_type) = BufferType::from_i32(buffer_type) else { + return Status::BadArgument.as_i32(); + }; + let len = match read_buffer(caller.data(), buf_type) { + Some(b) => b.len() as u32, + None => return Status::NotFound.as_i32(), + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + if mem.write_u32(caller.as_context_mut(), return_size_ptr as u32, len).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + let _ = mem.write_u32(caller.as_context_mut(), return_unused_ptr as u32, 0); + Status::Ok.as_i32() + }, + )?; + + // proxy_set_buffer_bytes(buffer_type, start, size, *data, data_size) -> Status + // + // spec §Buffers proxy_set_buffer_bytes:可做 prepend / append / inject / replace。 + // start, size 解释为:用 (data, data_size) 替换 [start, start+size) 范围。 + linker.func_wrap( + "env", + "proxy_set_buffer_bytes", + |mut caller: Caller<'_, HostState>, buffer_type: i32, start: i32, size: i32, data_ptr: i32, data_size: i32| -> i32 { + let Some(buf_type) = BufferType::from_i32(buffer_type) else { + return Status::BadArgument.as_i32(); + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let new_bytes = match mem.read_bytes(caller.as_context(), data_ptr as u32, data_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let ctx_id = caller.data().effective_context; + let st = caller.data_mut(); + let Some(ctx) = st.contexts.get_mut(&ctx_id) else { + return Status::NotFound.as_i32(); + }; + match buf_type { + BufferType::HttpRequestBody => { + let cur = ctx.request_body.take().unwrap_or_default(); + ctx.request_body = Some(splice_buffer(&cur, start as u32, size as u32, &new_bytes)); + Status::Ok.as_i32() + } + BufferType::HttpResponseBody => { + let cur = ctx.response_body.take().unwrap_or_default(); + ctx.response_body = Some(splice_buffer(&cur, start as u32, size as u32, &new_bytes)); + Status::Ok.as_i32() + } + // TCP / gRPC / 配置 / FFI args:本 host 不支持写 + BufferType::DownstreamData + | BufferType::UpstreamData + | BufferType::GrpcCallMessage + | BufferType::VmConfiguration + | BufferType::PluginConfiguration + | BufferType::HttpCallResponseBody + | BufferType::ForeignFunctionArguments => Status::BadArgument.as_i32(), + } + }, + )?; + + Ok(()) +} + +/// spec §proxy_set_buffer_bytes:用 `replacement` 替换 `cur[start..start+size]`。 +fn splice_buffer(cur: &Bytes, start: u32, size: u32, replacement: &[u8]) -> Bytes { + let cur_len = cur.len(); + let start = (start as usize).min(cur_len); + let size = (size as usize).min(cur_len.saturating_sub(start)); + let mut out = Vec::with_capacity(cur_len.saturating_add(replacement.len())); + out.extend_from_slice(&cur[..start]); + out.extend_from_slice(replacement); + out.extend_from_slice(&cur[start + size..]); + Bytes::from(out) +} + +// ───────────────────────────────────────────────────────── +// HTTP fields(spec §HTTP fields) +// ───────────────────────────────────────────────────────── + +fn register_headers(linker: &mut Linker) -> Result<(), wasmtime::Error> { + // proxy_get_header_map_size(map_type, *return_size) -> Status + linker.func_wrap( + "env", + "proxy_get_header_map_size", + |mut caller: Caller<'_, HostState>, map_type: i32, return_size_ptr: i32| -> i32 { + let Some(mt) = MapType::from_i32(map_type) else { + return Status::BadArgument.as_i32(); + }; + let pairs = collect_pairs(caller.data(), mt); + let buf = { + let refs: Vec<(&[u8], &[u8])> = pairs.iter().map(|(k, v)| (k.as_slice(), v.as_slice())).collect(); + encode_pairs(&refs) + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + if mem.write_u32(caller.as_context_mut(), return_size_ptr as u32, buf.len() as u32).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + Status::Ok.as_i32() + }, + )?; + + // proxy_get_header_map_pairs + linker.func_wrap( + "env", + "proxy_get_header_map_pairs", + |mut caller: Caller<'_, HostState>, map_type: i32, return_data_ptr: i32, return_size_ptr: i32| -> i32 { + let Some(mt) = MapType::from_i32(map_type) else { + return Status::BadArgument.as_i32(); + }; + let pairs = collect_pairs(caller.data(), mt); + let buf = { + let refs: Vec<(&[u8], &[u8])> = pairs.iter().map(|(k, v)| (k.as_slice(), v.as_slice())).collect(); + encode_pairs(&refs) + }; + match write_alloc_pair(&mut caller, &buf, return_data_ptr as u32, return_size_ptr as u32) { + Ok(()) => Status::Ok.as_i32(), + Err(s) => s.as_i32(), + } + }, + )?; + + // proxy_set_header_map_pairs + linker.func_wrap( + "env", + "proxy_set_header_map_pairs", + |mut caller: Caller<'_, HostState>, map_type: i32, data_ptr: i32, data_size: i32| -> i32 { + let Some(mt) = MapType::from_i32(map_type) else { + return Status::BadArgument.as_i32(); + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let raw = match mem.read_bytes(caller.as_context(), data_ptr as u32, data_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let Some(pairs) = decode_pairs(&raw) else { + return Status::SerializationFailure.as_i32(); + }; + let new_map = pairs_to_header_map(&pairs); + replace_map(caller.data_mut(), mt, new_map); + Status::Ok.as_i32() + }, + )?; + + // proxy_get_header_map_value + linker.func_wrap( + "env", + "proxy_get_header_map_value", + |mut caller: Caller<'_, HostState>, map_type: i32, key_ptr: i32, key_size: i32, return_data_ptr: i32, return_size_ptr: i32| -> i32 { + let Some(mt) = MapType::from_i32(map_type) else { + return Status::BadArgument.as_i32(); + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let key = match mem.read_string_lossy(caller.as_context(), key_ptr as u32, key_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let key_l = key.to_ascii_lowercase(); + let Some(value) = lookup_header(caller.data(), mt, &key_l) else { + return Status::NotFound.as_i32(); + }; + let bytes = value.into_bytes(); + match write_alloc_pair(&mut caller, &bytes, return_data_ptr as u32, return_size_ptr as u32) { + Ok(()) => Status::Ok.as_i32(), + Err(s) => s.as_i32(), + } + }, + )?; + + // proxy_add_header_map_value + linker.func_wrap( + "env", + "proxy_add_header_map_value", + |mut caller: Caller<'_, HostState>, map_type: i32, key_ptr: i32, key_size: i32, value_ptr: i32, value_size: i32| -> i32 { + let Some(mt) = MapType::from_i32(map_type) else { + return Status::BadArgument.as_i32(); + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let key = match mem.read_string_lossy(caller.as_context(), key_ptr as u32, key_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let value = match mem.read_string_lossy(caller.as_context(), value_ptr as u32, value_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + mutate_header(caller.data_mut(), mt, &key, HeaderMutation::Add(value)) + }, + )?; + + // proxy_replace_header_map_value + linker.func_wrap( + "env", + "proxy_replace_header_map_value", + |mut caller: Caller<'_, HostState>, map_type: i32, key_ptr: i32, key_size: i32, value_ptr: i32, value_size: i32| -> i32 { + let Some(mt) = MapType::from_i32(map_type) else { + return Status::BadArgument.as_i32(); + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let key = match mem.read_string_lossy(caller.as_context(), key_ptr as u32, key_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let value = match mem.read_string_lossy(caller.as_context(), value_ptr as u32, value_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + mutate_header(caller.data_mut(), mt, &key, HeaderMutation::Replace(value)) + }, + )?; + + // proxy_remove_header_map_value + linker.func_wrap( + "env", + "proxy_remove_header_map_value", + |mut caller: Caller<'_, HostState>, map_type: i32, key_ptr: i32, key_size: i32| -> i32 { + let Some(mt) = MapType::from_i32(map_type) else { + return Status::BadArgument.as_i32(); + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let key = match mem.read_string_lossy(caller.as_context(), key_ptr as u32, key_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + mutate_header(caller.data_mut(), mt, &key, HeaderMutation::Remove) + }, + )?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Local response / status(spec §HTTP streams §proxy_send_local_response) +// ───────────────────────────────────────────────────────── + +fn register_status_and_local_response(linker: &mut Linker) -> Result<(), wasmtime::Error> { + // proxy_get_status(*return_status_code, **msg_data, *msg_size) -> Status + // + // spec §proxy_get_status:在 on_http_call_response 中返回该次 HTTP 调用的 status; + // 其它时机我们返回当前响应 status。 + linker.func_wrap( + "env", + "proxy_get_status", + |mut caller: Caller<'_, HostState>, status_code_ptr: i32, msg_data_ptr: i32, msg_size_ptr: i32| -> i32 { + let (code, msg): (u32, String) = match caller.data().current_context() { + Some(c) => { + if c.last_call_status > 0 { + (c.last_call_status as u32, c.last_call_status_message.clone()) + } else if let Some(rs) = c.response_status { + (rs as u32, c.response_status_message.clone()) + } else { + (0, String::new()) + } + } + None => (0, String::new()), + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + if mem.write_u32(caller.as_context_mut(), status_code_ptr as u32, code).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + let bytes = msg.into_bytes(); + match write_alloc_pair(&mut caller, &bytes, msg_data_ptr as u32, msg_size_ptr as u32) { + Ok(()) => Status::Ok.as_i32(), + Err(s) => s.as_i32(), + } + }, + )?; + + // proxy_send_local_response(status, *status_text, status_text_size, *body, body_size, *headers, headers_size, grpc_status) + linker.func_wrap( + "env", + "proxy_send_local_response", + |mut caller: Caller<'_, HostState>, + status: i32, + _status_text_data: i32, + _status_text_size: i32, + body_data: i32, + body_size: i32, + headers_data: i32, + headers_size: i32, + _grpc_status: i32| + -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let body = if body_size > 0 { + match mem.read_bytes(caller.as_context(), body_data as u32, body_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + } + } else { + Vec::new() + }; + let headers_bytes = if headers_size > 0 { + match mem.read_bytes(caller.as_context(), headers_data as u32, headers_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + } + } else { + Vec::new() + }; + let pairs = decode_pairs(&headers_bytes).unwrap_or_default(); + let map = pairs_to_header_map(&pairs); + let ctx_id = caller.data().effective_context; + if let Some(ctx) = caller.data_mut().contexts.get_mut(&ctx_id) { + ctx.local_response = Some(LocalResponse { + status: status as u16, + headers: map, + body: Bytes::from(body), + }); + debug!(target: "spacegate_plugin_wasm", ctx_id, status, "guest send_local_response captured"); + } else { + warn!(target: "spacegate_plugin_wasm", ctx_id, "send_local_response on unknown ctx"); + } + Status::Ok.as_i32() + }, + )?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// proxy_http_call(spec §HTTP calls) +// ───────────────────────────────────────────────────────── + +fn register_http_call(linker: &mut Linker, dispatch_tx: tokio::sync::mpsc::UnboundedSender<(u32, HttpCallResult)>) -> Result<(), wasmtime::Error> { + linker.func_wrap( + "env", + "proxy_http_call", + move |mut caller: Caller<'_, HostState>, + upstream_data: i32, + upstream_size: i32, + headers_data: i32, + headers_size: i32, + body_data: i32, + body_size: i32, + _trailers_data: i32, + _trailers_size: i32, + timeout_ms: i32, + return_token_ptr: i32| + -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let cluster = match mem.read_string_lossy(caller.as_context(), upstream_data as u32, upstream_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let headers_bytes = mem.read_bytes(caller.as_context(), headers_data as u32, headers_size as u32).unwrap_or_default(); + let body = if body_size > 0 { + if let Some(limit) = caller.data().shell_cfg.limits.max_body_bytes { + if body_size as usize > limit { + warn!(target: "spacegate_plugin_wasm", body_size, limit, "dispatch_http_call: request body exceeds max_body_bytes"); + return Status::BadArgument.as_i32(); + } + } + mem.read_bytes(caller.as_context(), body_data as u32, body_size as u32).unwrap_or_default() + } else { + Vec::new() + }; + let pairs = match decode_pairs(&headers_bytes) { + Some(p) => p, + None => return Status::SerializationFailure.as_i32(), + }; + let mut method = "GET".to_string(); + let mut path = "/".to_string(); + let mut authority = String::new(); + let mut others = Vec::with_capacity(pairs.len()); + for (k, v) in &pairs { + let key_str = String::from_utf8_lossy(k); + let val_str = String::from_utf8_lossy(v).into_owned(); + match key_str.as_ref() { + ":method" => method = val_str, + ":path" => path = val_str, + ":authority" => authority = val_str, + ":scheme" => {} + _ => others.push((key_str.to_string(), val_str)), + } + } + if method.is_empty() || path.is_empty() { + return Status::BadArgument.as_i32(); + } + let st = caller.data(); + let base = st.shell_cfg.resolve_cluster(&cluster).or_else(|| if !authority.is_empty() { Some(format!("http://{authority}")) } else { None }); + let Some(base) = base else { + warn!(target: "spacegate_plugin_wasm", cluster = %cluster, "dispatch_http_call: cluster not configured"); + return Status::BadArgument.as_i32(); + }; + if let Some(limit) = caller.data().shell_cfg.limits.max_pending_calls { + if caller.data().pending_calls.len() >= limit { + warn!( + target: "spacegate_plugin_wasm", + pending_calls = caller.data().pending_calls.len(), + limit, + "dispatch_http_call: max_pending_calls reached" + ); + return Status::InternalFailure.as_i32(); + } + } + let url = format!("{}{}", base.trim_end_matches('/'), path); + let token = caller.data_mut().next_dispatch_token(); + let source_ctx = caller.data().effective_context; + caller.data_mut().pending_calls.insert( + token, + crate::host_state::PendingCall { + waker: None, + source_context_id: source_ctx, + }, + ); + let client = caller.data().http_client.clone(); + let max_body_bytes = caller.data().shell_cfg.limits.max_body_bytes; + let timeout = Duration::from_millis(timeout_ms.max(1) as u64); + let tx = dispatch_tx.clone(); + tokio::spawn(async move { + debug!(target: "spacegate_plugin_wasm", %url, %method, "dispatch_http_call begin"); + let parsed_method = method.parse::().unwrap_or(reqwest::Method::GET); + let mut req = client.request(parsed_method, &url); + for (k, v) in others { + if k.starts_with(':') { + continue; + } + if let (Ok(name), Ok(val)) = (HeaderName::try_from(k.as_str()), HeaderValue::try_from(v.as_str())) { + req = req.header(name, val); + } + } + if !body.is_empty() { + req = req.body(body); + } + req = req.timeout(timeout); + let result = match req.send().await { + Ok(resp) => { + let status = resp.status().as_u16(); + let status_message = resp.status().canonical_reason().unwrap_or("").to_string(); + let mut hdrs = HeaderMap::new(); + for (k, v) in resp.headers().iter() { + if let (Ok(name), Ok(val)) = (HeaderName::try_from(k.as_str()), HeaderValue::from_bytes(v.as_bytes())) { + hdrs.append(name, val); + } + } + let body_bytes = resp.bytes().await.unwrap_or_default(); + if let Some(limit) = max_body_bytes { + if body_bytes.len() > limit { + warn!( + target: "spacegate_plugin_wasm", + %url, + body_len = body_bytes.len(), + limit, + "dispatch_http_call response exceeds max_body_bytes" + ); + HttpCallResult { + status: 0, + status_message: format!("dispatch_http_call response body too large: {} > {limit}", body_bytes.len()), + headers: HeaderMap::new(), + body: Bytes::new(), + } + } else { + HttpCallResult { + status, + status_message, + headers: hdrs, + body: body_bytes, + } + } + } else { + HttpCallResult { + status, + status_message, + headers: hdrs, + body: body_bytes, + } + } + } + Err(e) => { + warn!(target: "spacegate_plugin_wasm", %url, error = %e, "dispatch_http_call failed"); + HttpCallResult { + status: 0, + status_message: format!("{e}"), + headers: HeaderMap::new(), + body: Bytes::new(), + } + } + }; + debug!(target: "spacegate_plugin_wasm", token, status = result.status, "dispatch_http_call done"); + let _ = tx.send((token, result)); + }); + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + if mem.write_u32(caller.as_context_mut(), return_token_ptr as u32, token).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + info!(target: "spacegate_plugin_wasm", token, cluster = %cluster, "dispatch_http_call enqueued"); + Status::Ok.as_i32() + }, + )?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Shared Data / Shared Queues(spec §Shared Key-Value Store §Shared Queues) +// ───────────────────────────────────────────────────────── + +fn register_shared_data_and_queue(linker: &mut Linker) -> Result<(), wasmtime::Error> { + // proxy_get_shared_data(*k, k_size, **v, *v_size, *cas) -> Status + linker.func_wrap( + "env", + "proxy_get_shared_data", + |mut caller: Caller<'_, HostState>, k_ptr: i32, k_size: i32, v_data_ptr: i32, v_size_ptr: i32, cas_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let key = match mem.read_bytes(caller.as_context(), k_ptr as u32, k_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let Some((value, cas)) = shared_data_get(&key) else { + return Status::NotFound.as_i32(); + }; + if let Err(s) = write_alloc_pair(&mut caller, &value, v_data_ptr as u32, v_size_ptr as u32) { + return s.as_i32(); + } + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + if mem.write_u32(caller.as_context_mut(), cas_ptr as u32, cas).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + Status::Ok.as_i32() + }, + )?; + + // proxy_set_shared_data(*k, k_size, *v, v_size, cas) -> Status + linker.func_wrap( + "env", + "proxy_set_shared_data", + |mut caller: Caller<'_, HostState>, k_ptr: i32, k_size: i32, v_ptr: i32, v_size: i32, cas: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let key = match mem.read_bytes(caller.as_context(), k_ptr as u32, k_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let value = if v_size > 0 { + match mem.read_bytes(caller.as_context(), v_ptr as u32, v_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + } + } else { + Vec::new() + }; + match shared_data_set(&key, &value, cas as u32) { + SharedDataSetResult::Ok => Status::Ok.as_i32(), + SharedDataSetResult::CasMismatch => Status::CasMismatch.as_i32(), + } + }, + )?; + + // proxy_register_shared_queue(*n, n_size, *return_qid) -> Status + linker.func_wrap( + "env", + "proxy_register_shared_queue", + |mut caller: Caller<'_, HostState>, n_ptr: i32, n_size: i32, return_qid_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let name = match mem.read_string_lossy(caller.as_context(), n_ptr as u32, n_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let vm_id = caller.data().plugin_vm_id.clone(); + let qid = queue_register(&vm_id, &name); + if mem.write_u32(caller.as_context_mut(), return_qid_ptr as u32, qid).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + Status::Ok.as_i32() + }, + )?; + + // proxy_resolve_shared_queue(*vm_id, vm_id_size, *n, n_size, *return_qid) -> Status + linker.func_wrap( + "env", + "proxy_resolve_shared_queue", + |mut caller: Caller<'_, HostState>, vid_ptr: i32, vid_size: i32, n_ptr: i32, n_size: i32, return_qid_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let vid = match mem.read_string_lossy(caller.as_context(), vid_ptr as u32, vid_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let name = match mem.read_string_lossy(caller.as_context(), n_ptr as u32, n_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let Some(qid) = queue_resolve(&vid, &name) else { + return Status::NotFound.as_i32(); + }; + if mem.write_u32(caller.as_context_mut(), return_qid_ptr as u32, qid).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + Status::Ok.as_i32() + }, + )?; + + // proxy_enqueue_shared_queue(qid, *v, v_size) -> Status + linker.func_wrap( + "env", + "proxy_enqueue_shared_queue", + |mut caller: Caller<'_, HostState>, qid: i32, v_ptr: i32, v_size: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let bytes = if v_size > 0 { + match mem.read_bytes(caller.as_context(), v_ptr as u32, v_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + } + } else { + Vec::new() + }; + match queue_enqueue(qid as u32, &bytes) { + QueueOpResult::Ok => Status::Ok.as_i32(), + QueueOpResult::NotFound => Status::NotFound.as_i32(), + QueueOpResult::Empty => Status::Empty.as_i32(), + } + }, + )?; + + // proxy_dequeue_shared_queue(qid, **v, *v_size) -> Status + linker.func_wrap( + "env", + "proxy_dequeue_shared_queue", + |mut caller: Caller<'_, HostState>, qid: i32, v_data_ptr: i32, v_size_ptr: i32| -> i32 { + match queue_dequeue(qid as u32) { + (QueueOpResult::Ok, Some(bytes)) => match write_alloc_pair(&mut caller, &bytes, v_data_ptr as u32, v_size_ptr as u32) { + Ok(()) => Status::Ok.as_i32(), + Err(s) => s.as_i32(), + }, + (QueueOpResult::NotFound, _) => Status::NotFound.as_i32(), + _ => Status::Empty.as_i32(), + } + }, + )?; + + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Metrics(spec §Metrics) +// ───────────────────────────────────────────────────────── + +fn register_metrics(linker: &mut Linker) -> Result<(), wasmtime::Error> { + // proxy_define_metric(metric_type, *name, name_size, *return_mid) -> Status + linker.func_wrap( + "env", + "proxy_define_metric", + |mut caller: Caller<'_, HostState>, metric_type: i32, name_ptr: i32, name_size: i32, return_mid_ptr: i32| -> i32 { + let Some(kind) = MetricType::from_i32(metric_type) else { + return Status::BadArgument.as_i32(); + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let name = match mem.read_string_lossy(caller.as_context(), name_ptr as u32, name_size as u32) { + Ok(s) => s, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let id = metric_define(kind, &name); + if mem.write_u32(caller.as_context_mut(), return_mid_ptr as u32, id).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + Status::Ok.as_i32() + }, + )?; + + // proxy_record_metric(mid, value: u64) -> Status + linker.func_wrap("env", "proxy_record_metric", |_caller: Caller<'_, HostState>, mid: i32, value: i64| -> i32 { + match metric_record(mid as u32, value as u64) { + MetricOpResult::Ok => Status::Ok.as_i32(), + MetricOpResult::NotFound => Status::NotFound.as_i32(), + MetricOpResult::BadArgument => Status::BadArgument.as_i32(), + } + })?; + + // proxy_increment_metric(mid, delta: i64) -> Status + linker.func_wrap("env", "proxy_increment_metric", |_caller: Caller<'_, HostState>, mid: i32, delta: i64| -> i32 { + match metric_increment(mid as u32, delta) { + MetricOpResult::Ok => Status::Ok.as_i32(), + MetricOpResult::NotFound => Status::NotFound.as_i32(), + MetricOpResult::BadArgument => Status::BadArgument.as_i32(), + } + })?; + + // proxy_get_metric(mid, *return_value) -> Status + linker.func_wrap("env", "proxy_get_metric", |mut caller: Caller<'_, HostState>, mid: i32, return_ptr: i32| -> i32 { + let Some(v) = metric_get(mid as u32) else { + return Status::NotFound.as_i32(); + }; + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + if mem.write_u64(caller.as_context_mut(), return_ptr as u32, v).is_err() { + return Status::InvalidMemoryAccess.as_i32(); + } + Status::Ok.as_i32() + })?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Properties(spec §Properties) +// ───────────────────────────────────────────────────────── + +fn register_property(linker: &mut Linker) -> Result<(), wasmtime::Error> { + // proxy_get_property(*path, path_size, **v, *v_size) -> Status + linker.func_wrap( + "env", + "proxy_get_property", + |mut caller: Caller<'_, HostState>, path_ptr: i32, path_size: i32, return_data_ptr: i32, return_size_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let raw = match mem.read_bytes(caller.as_context(), path_ptr as u32, path_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let segments = decode_property_path(&raw); + if segments.is_empty() { + return Status::NotFound.as_i32(); + } + // 1. 用户通过 proxy_set_property 写入的优先(spec 允许 host 自行决定) + let canonical_key = canonicalize_path(&segments); + if let Some(v) = caller.data().user_properties.get(&canonical_key).cloned() { + return match write_alloc_pair(&mut caller, &v, return_data_ptr as u32, return_size_ptr as u32) { + Ok(()) => Status::Ok.as_i32(), + Err(s) => s.as_i32(), + }; + } + // 2. well-known + let value = resolve_well_known(caller.data(), &segments); + let Some(value) = value else { + return Status::NotFound.as_i32(); + }; + match write_alloc_pair(&mut caller, &value, return_data_ptr as u32, return_size_ptr as u32) { + Ok(()) => Status::Ok.as_i32(), + Err(s) => s.as_i32(), + } + }, + )?; + + // proxy_set_property(*path, path_size, *v, v_size) -> Status + linker.func_wrap( + "env", + "proxy_set_property", + |mut caller: Caller<'_, HostState>, path_ptr: i32, path_size: i32, v_ptr: i32, v_size: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let raw_path = match mem.read_bytes(caller.as_context(), path_ptr as u32, path_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + }; + let segments = decode_property_path(&raw_path); + if segments.is_empty() { + return Status::BadArgument.as_i32(); + } + let canonical_key = canonicalize_path(&segments); + let value = if v_size > 0 { + match mem.read_bytes(caller.as_context(), v_ptr as u32, v_size as u32) { + Ok(b) => b, + Err(_) => return Status::InvalidMemoryAccess.as_i32(), + } + } else { + Vec::new() + }; + caller.data_mut().user_properties.insert(canonical_key, value); + Status::Ok.as_i32() + }, + )?; + Ok(()) +} + +fn canonicalize_path(segments: &[&[u8]]) -> Vec { + let mut out = Vec::new(); + for (i, s) in segments.iter().enumerate() { + if i > 0 { + out.push(0); + } + out.extend_from_slice(s); + } + out +} + +/// spec §Properties §Well-known properties:内置覆盖最常用的几个,其它返回 None。 +fn resolve_well_known(state: &HostState, segments: &[&[u8]]) -> Option> { + let path_str: Vec<&str> = segments.iter().filter_map(|s| std::str::from_utf8(s).ok()).collect(); + let joined = path_str.join("."); + match joined.as_str() { + // Proxy-Wasm + "plugin_name" => Some(state.plugin_name.as_bytes().to_vec()), + "plugin_root_id" => Some(state.plugin_root_id.as_bytes().to_vec()), + "plugin_vm_id" => Some(state.plugin_vm_id.as_bytes().to_vec()), + // Downstream connection + "source.address" => state.source_addr.map(|s| s.to_string().into_bytes()).or_else(|| { + // 退路:从 :authority 推导 + state.current_context().map(|c| c.request_pseudo.authority.clone().into_bytes()).filter(|b| !b.is_empty()) + }), + "source.port" => state.source_addr.map(|s| s.port().to_string().into_bytes()), + "destination.address" => state.destination_addr.map(|s| s.to_string().into_bytes()), + "destination.port" => state.destination_addr.map(|s| s.port().to_string().into_bytes()), + // HTTP request + "request.protocol" => state.current_context().map(|c| c.request_protocol.as_bytes().to_vec()).filter(|b| !b.is_empty()), + "request.size" => state.current_context().map(|c| c.request_size.to_string().into_bytes()), + "request.total_size" => state.current_context().map(|c| { + let hdr_bytes = approx_header_bytes(&c.request_headers); + (c.request_size + hdr_bytes as u64).to_string().into_bytes() + }), + // HTTP response + "response.size" => state.current_context().map(|c| c.response_size.to_string().into_bytes()), + "response.total_size" => state.current_context().map(|c| { + let hdr_bytes = approx_header_bytes(&c.response_headers); + (c.response_size + hdr_bytes as u64).to_string().into_bytes() + }), + _ => None, + } +} + +fn approx_header_bytes(map: &HeaderMap) -> usize { + let mut sum = 0; + for (k, v) in map.iter() { + sum += k.as_str().len() + 2 + v.as_bytes().len() + 2; + } + sum +} + +// ───────────────────────────────────────────────────────── +// gRPC(spec §gRPC calls)→ 全部返回 UNIMPLEMENTED +// ───────────────────────────────────────────────────────── + +fn register_grpc_unimplemented(linker: &mut Linker) -> Result<(), wasmtime::Error> { + linker.func_wrap( + "env", + "proxy_grpc_call", + |_caller: Caller<'_, HostState>, _a: i32, _b: i32, _c: i32, _d: i32, _e: i32, _f: i32, _g: i32, _h: i32, _i: i32, _j: i32, _k: i32, _l: i32| -> i32 { + Status::Unimplemented.as_i32() + }, + )?; + linker.func_wrap( + "env", + "proxy_grpc_stream", + |_caller: Caller<'_, HostState>, _a: i32, _b: i32, _c: i32, _d: i32, _e: i32, _f: i32, _g: i32, _h: i32, _i: i32| -> i32 { Status::Unimplemented.as_i32() }, + )?; + linker.func_wrap("env", "proxy_grpc_cancel", |_caller: Caller<'_, HostState>, _t: i32| -> i32 { + Status::Unimplemented.as_i32() + })?; + linker.func_wrap("env", "proxy_grpc_close", |_caller: Caller<'_, HostState>, _t: i32| -> i32 { + Status::Unimplemented.as_i32() + })?; + linker.func_wrap("env", "proxy_grpc_send", |_caller: Caller<'_, HostState>, _t: i32, _m: i32, _ms: i32, _eos: i32| -> i32 { + Status::Unimplemented.as_i32() + })?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// Foreign function(spec §FFI)→ 没有注册表 → NotFound +// ───────────────────────────────────────────────────────── + +fn register_foreign_function(linker: &mut Linker) -> Result<(), wasmtime::Error> { + linker.func_wrap( + "env", + "proxy_call_foreign_function", + |_caller: Caller<'_, HostState>, _a: i32, _b: i32, _c: i32, _d: i32, _e: i32, _f: i32| -> i32 { Status::NotFound.as_i32() }, + )?; + Ok(()) +} + +// ───────────────────────────────────────────────────────── +// 辅助:alloc + 写 (data, size) pair;lookup / mutate / collect +// ───────────────────────────────────────────────────────── + +/// 在 guest 侧分配一段内存、写入 `payload`、然后把 (guest_ptr, len) 回写到 +/// `return_data_ptr` / `return_size_ptr`。 +/// +/// 空 payload:写 (0, 0)。 +fn write_alloc_pair(caller: &mut Caller<'_, HostState>, payload: &[u8], return_data_ptr: u32, return_size_ptr: u32) -> Result<(), Status> { + let mem = MemoryHelper::from_caller(caller).map_err(|_| Status::InvalidMemoryAccess)?; + if payload.is_empty() { + mem.write_u32(caller.as_context_mut(), return_data_ptr, 0).map_err(|_| Status::InvalidMemoryAccess)?; + mem.write_u32(caller.as_context_mut(), return_size_ptr, 0).map_err(|_| Status::InvalidMemoryAccess)?; + return Ok(()); + } + let alloc = caller.data().alloc.clone().ok_or(Status::InternalFailure)?; + let guest_ptr = alloc.call(&mut *caller, payload.len() as u32).map_err(|_| Status::InternalFailure)?; + let mem = MemoryHelper::from_caller(caller).map_err(|_| Status::InvalidMemoryAccess)?; + mem.write_bytes(caller.as_context_mut(), guest_ptr, payload).map_err(|_| Status::InvalidMemoryAccess)?; + mem.write_u32(caller.as_context_mut(), return_data_ptr, guest_ptr).map_err(|_| Status::InvalidMemoryAccess)?; + mem.write_u32(caller.as_context_mut(), return_size_ptr, payload.len() as u32).map_err(|_| Status::InvalidMemoryAccess)?; + Ok(()) +} + +fn lookup_header(state: &HostState, mt: MapType, key_lower: &str) -> Option { + let ctx = state.current_context()?; + let map = match mt { + MapType::HttpRequestHeaders => &ctx.request_headers, + MapType::HttpRequestTrailers => &ctx.request_trailers, + MapType::HttpResponseHeaders => &ctx.response_headers, + MapType::HttpResponseTrailers => &ctx.response_trailers, + MapType::HttpCallResponseHeaders => &ctx.last_call_headers, + MapType::HttpCallResponseTrailers => &ctx.last_call_trailers, + MapType::GrpcCallInitialMetadata | MapType::GrpcCallTrailingMetadata => return None, + }; + if let Some(value) = pseudo_lookup(ctx, mt, key_lower) { + return Some(value); + } + let name = HeaderName::try_from(key_lower).ok()?; + let val = map.get(&name)?; + val.to_str().ok().map(|s| s.to_string()) +} + +fn pseudo_lookup(ctx: &crate::host_state::RequestContext, mt: MapType, key: &str) -> Option { + match (mt, key) { + (MapType::HttpRequestHeaders, ":method") => Some(ctx.request_pseudo.method.clone()), + (MapType::HttpRequestHeaders, ":path") => Some(ctx.request_pseudo.path.clone()), + (MapType::HttpRequestHeaders, ":authority") => Some(ctx.request_pseudo.authority.clone()), + (MapType::HttpRequestHeaders, ":scheme") => Some(ctx.request_pseudo.scheme.clone()), + (MapType::HttpResponseHeaders, ":status") => ctx.response_status.map(|s| s.to_string()), + (MapType::HttpCallResponseHeaders, ":status") => { + if ctx.last_call_status > 0 { + Some(ctx.last_call_status.to_string()) + } else { + None + } + } + _ => None, + } +} + +enum HeaderMutation { + Add(String), + Replace(String), + Remove, +} + +fn mutate_header(state: &mut HostState, mt: MapType, key: &str, m: HeaderMutation) -> i32 { + if matches!(mt, MapType::GrpcCallInitialMetadata | MapType::GrpcCallTrailingMetadata) { + return Status::Unimplemented.as_i32(); + } + let ctx_id = state.effective_context; + let Some(ctx) = state.contexts.get_mut(&ctx_id) else { + return Status::NotFound.as_i32(); + }; + if key.starts_with(':') { + let new_val = match &m { + HeaderMutation::Add(v) | HeaderMutation::Replace(v) => Some(v.clone()), + HeaderMutation::Remove => None, + }; + match (mt, key) { + (MapType::HttpRequestHeaders, ":path") => { + ctx.request_pseudo.path = new_val.unwrap_or_default(); + } + (MapType::HttpRequestHeaders, ":method") => { + ctx.request_pseudo.method = new_val.unwrap_or_default(); + } + (MapType::HttpRequestHeaders, ":authority") => { + ctx.request_pseudo.authority = new_val.unwrap_or_default(); + } + (MapType::HttpRequestHeaders, ":scheme") => { + ctx.request_pseudo.scheme = new_val.unwrap_or_default(); + } + (MapType::HttpResponseHeaders, ":status") => { + if let Some(v) = new_val { + ctx.response_status = v.parse().ok(); + } + } + _ => {} + } + return Status::Ok.as_i32(); + } + let Ok(name) = HeaderName::try_from(key) else { + return Status::BadArgument.as_i32(); + }; + let map = match mt { + MapType::HttpRequestHeaders => &mut ctx.request_headers, + MapType::HttpRequestTrailers => &mut ctx.request_trailers, + MapType::HttpResponseHeaders => &mut ctx.response_headers, + MapType::HttpResponseTrailers => &mut ctx.response_trailers, + MapType::HttpCallResponseHeaders => &mut ctx.last_call_headers, + MapType::HttpCallResponseTrailers => &mut ctx.last_call_trailers, + MapType::GrpcCallInitialMetadata | MapType::GrpcCallTrailingMetadata => return Status::Unimplemented.as_i32(), + }; + match m { + HeaderMutation::Add(v) => { + if let Ok(val) = HeaderValue::try_from(v) { + map.append(name, val); + } + } + HeaderMutation::Replace(v) => { + if let Ok(val) = HeaderValue::try_from(v) { + map.insert(name, val); + } + } + HeaderMutation::Remove => { + map.remove(name); + } + } + Status::Ok.as_i32() +} + +fn collect_pairs(state: &HostState, mt: MapType) -> Vec<(Vec, Vec)> { + let Some(ctx) = state.current_context() else { + return Vec::new(); + }; + let map = match mt { + MapType::HttpRequestHeaders => &ctx.request_headers, + MapType::HttpRequestTrailers => &ctx.request_trailers, + MapType::HttpResponseHeaders => &ctx.response_headers, + MapType::HttpResponseTrailers => &ctx.response_trailers, + MapType::HttpCallResponseHeaders => &ctx.last_call_headers, + MapType::HttpCallResponseTrailers => &ctx.last_call_trailers, + MapType::GrpcCallInitialMetadata | MapType::GrpcCallTrailingMetadata => return Vec::new(), + }; + let mut out: Vec<(Vec, Vec)> = Vec::with_capacity(map.len() + 4); + match mt { + MapType::HttpRequestHeaders => { + if !ctx.request_pseudo.method.is_empty() { + out.push((b":method".to_vec(), ctx.request_pseudo.method.as_bytes().to_vec())); + } + if !ctx.request_pseudo.path.is_empty() { + out.push((b":path".to_vec(), ctx.request_pseudo.path.as_bytes().to_vec())); + } + if !ctx.request_pseudo.authority.is_empty() { + out.push((b":authority".to_vec(), ctx.request_pseudo.authority.as_bytes().to_vec())); + } + if !ctx.request_pseudo.scheme.is_empty() { + out.push((b":scheme".to_vec(), ctx.request_pseudo.scheme.as_bytes().to_vec())); + } + } + MapType::HttpResponseHeaders => { + if let Some(s) = ctx.response_status { + out.push((b":status".to_vec(), s.to_string().into_bytes())); + } + } + MapType::HttpCallResponseHeaders => { + if ctx.last_call_status > 0 { + out.push((b":status".to_vec(), ctx.last_call_status.to_string().into_bytes())); + } + } + _ => {} + } + for (k, v) in map.iter() { + out.push((k.as_str().as_bytes().to_vec(), v.as_bytes().to_vec())); + } + out +} + +fn pairs_to_header_map(pairs: &[(Vec, Vec)]) -> HeaderMap { + let mut out = HeaderMap::new(); + for (k, v) in pairs { + let Ok(key) = HeaderName::try_from(k.as_slice()) else { + continue; + }; + let Ok(val) = HeaderValue::from_bytes(v.as_slice()) else { + continue; + }; + out.append(key, val); + } + out +} + +fn replace_map(state: &mut HostState, mt: MapType, new_map: HeaderMap) { + let ctx_id = state.effective_context; + let Some(ctx) = state.contexts.get_mut(&ctx_id) else { + return; + }; + match mt { + MapType::HttpRequestHeaders => ctx.request_headers = new_map, + MapType::HttpRequestTrailers => ctx.request_trailers = new_map, + MapType::HttpResponseHeaders => ctx.response_headers = new_map, + MapType::HttpResponseTrailers => ctx.response_trailers = new_map, + _ => {} + } +} diff --git a/crates/plugin-wasm/src/host_state.rs b/crates/plugin-wasm/src/host_state.rs new file mode 100644 index 00000000..63f315a6 --- /dev/null +++ b/crates/plugin-wasm/src/host_state.rs @@ -0,0 +1,220 @@ +//! 传给 `wasmtime::Store` 的宿主状态。 +//! +//! - 顶层 [`HostState`] 承载:进程级 reqwest 客户端、shell 配置、序列化后的 plugin_config 字节、 +//! memory / 分配器 export、所有 HTTP 上下文、未完结的 `proxy_http_call` 句柄等。 +//! - 每个 HTTP 请求建一个 [`RequestContext`],由 `vm.rs` 在调 `proxy_on_*` 钩子前后维护。 +//! - host fn 通过 `caller.data() / data_mut()` 读写 `HostState`,并以 +//! `effective_context` 字段定位「当前是哪个上下文」(spec §Effective context changes, +//! guest 通过 `proxy_set_effective_context` 切换)。 + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; + +use bytes::Bytes; +use http::HeaderMap; +use wasmtime::{Memory, ResourceLimiter, TypedFunc}; + +use crate::config::WasmPluginShellConfig; + +/// 约定的 root context id:proxy-wasm 默认从 1 开始。 +pub const ROOT_CONTEXT_ID: u32 = 1; + +/// HTTP 上下文在生命周期中处于的阶段(vm.rs 调钩子时打标记,host fn 据此判断)。 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ContextStage { + #[default] + Init, + RequestHeaders, + RequestBody, + RequestTrailers, + ResponseHeaders, + ResponseBody, + ResponseTrailers, + Log, +} + +/// HTTP/2 风格的伪头(`:method` 等)。proxy-wasm guest 通过 header_map 拿到它们, +/// 我们额外用一个结构体专门承载,便于在 `inner.call` 前重建 `Uri`。 +#[derive(Debug, Clone, Default)] +pub struct PseudoHeaders { + pub method: String, + pub path: String, + pub authority: String, + pub scheme: String, +} + +/// guest 调 `proxy_send_local_response` 时 host 捕获的结构。 +#[derive(Debug)] +pub struct LocalResponse { + pub status: u16, + pub headers: HeaderMap, + pub body: Bytes, +} + +/// `proxy_http_call` 的异步结果:spawn 出去的 reqwest 任务通过 channel 把它送回 Vm。 +#[derive(Debug, Default)] +pub struct HttpCallResult { + pub status: u16, + pub status_message: String, + pub headers: HeaderMap, + pub body: Bytes, +} + +/// 一次未完结的 `proxy_http_call`。`source_context_id` 指明该 token 是哪个 ctx 发起的, +/// 这样 Vm 状态机在拿到结果时能恢复到正确的 effective_context 再调 `proxy_on_http_call_response`。 +#[derive(Debug)] +pub struct PendingCall { + #[allow(dead_code)] + pub waker: Option, + pub source_context_id: u32, +} + +#[derive(Debug, Clone)] +pub struct HostResourceLimiter { + max_memory_bytes: Option, +} + +impl HostResourceLimiter { + pub fn new(shell_cfg: &WasmPluginShellConfig) -> Self { + Self { + max_memory_bytes: shell_cfg.max_memory_bytes(), + } + } +} + +impl ResourceLimiter for HostResourceLimiter { + fn memory_growing(&mut self, _current: usize, desired: usize, _maximum: Option) -> wasmtime::Result { + Ok(self.max_memory_bytes.map(|max| desired <= max).unwrap_or(true)) + } + + fn table_growing(&mut self, _current: u32, _desired: u32, _maximum: Option) -> wasmtime::Result { + Ok(true) + } +} + +/// 单个 HTTP 请求的所有状态(请求/响应头 / body / 上次 dispatch 结果 / 本地响应 / 短路标记)。 +#[derive(Debug, Default)] +pub struct RequestContext { + pub parent_id: u32, + pub stage: ContextStage, + pub request_pseudo: PseudoHeaders, + pub request_headers: HeaderMap, + pub request_trailers: HeaderMap, + pub request_body: Option, + pub response_status: Option, + pub response_status_message: String, + pub response_headers: HeaderMap, + pub response_trailers: HeaderMap, + pub response_body: Option, + /// 上次 `proxy_http_call` 回调时由 host 注入;guest 通过 + /// `get_http_call_response_*` 读它。 + pub last_call_headers: HeaderMap, + pub last_call_trailers: HeaderMap, + pub last_call_body: Bytes, + /// 最近一次 dispatch_http_call 返回的状态码(hai 用 `:status` 伪头读取)。 + pub last_call_status: u16, + pub last_call_status_message: String, + /// guest 显式 `resume_http_request()` 后置 true;Vm 退出 Pause 等待循环。 + pub continue_requested: bool, + /// guest 调 `send_local_response` 后写入;Vm 据此短路返回。 + pub local_response: Option, + /// HTTP 协议版本字符串(spec well-known property `request.protocol`)。 + pub request_protocol: String, + /// 收到的 request body 已知字节数。 + pub request_size: u64, + /// 输出的 response body 已知字节数。 + pub response_size: u64, + /// 通过 `proxy_done` 显式标记的 done 阶段(spec §proxy_done / §proxy_on_done)。 + pub done_marker: bool, + /// guest 上一次 `proxy_on_done` 返回值;false 表示要等 `proxy_done` 才能进 on_log/on_delete。 + pub awaiting_done: bool, +} + +/// 进程内传给 wasmtime `Store` 的状态。生命周期与一次 Vm 实例一致。 +/// +/// 不 derive Debug 因为 `TypedFunc` 不实现 Debug。 +pub struct HostState { + pub shell_cfg: Arc, + /// guest `proxy_on_configure` 读取的字节(来自 shell_cfg.plugin_config 序列化)。 + pub configuration: Vec, + /// guest 导出的线性内存(vm.rs 实例化完成后填)。 + pub memory: Option, + /// guest 导出 `proxy_on_memory_allocate(size) -> ptr` 或 deprecated `malloc(size) -> ptr`。 + pub alloc: Option>, + pub root_context_id: u32, + /// 当前 hostcall 关联的上下文 id(由 vm.rs 在每次钩子前设置, + /// 也可被 guest 的 `proxy_set_effective_context` 覆盖)。 + pub effective_context: u32, + pub contexts: HashMap, + /// guest 调用 `proxy_set_tick_period_milliseconds` 后存这里。 + /// `WasmPluginShell` 的后台 tick 任务 50ms 颗粒度地轮询本字段,到点 → `Vm::tick()`。 + pub tick_period_ms: Option, + /// 未完结的 dispatch_http_call 句柄表。 + pub pending_calls: HashMap, + /// dispatch token 单调递增计数器。 + next_token: u32, + /// host 端 reqwest 客户端:所有 dispatch_http_call 复用一个,免去握手开销。 + pub http_client: reqwest::Client, + /// 用户通过 `proxy_set_property` 设置的自定义属性(key = `\0` 分割的 path 字节)。 + pub user_properties: HashMap, Vec>, + /// 客户端 socket 地址(spec well-known property `source.address` / `source.port`)。 + pub source_addr: Option, + /// 服务端 socket 地址(spec well-known property `destination.address` / `destination.port`)。 + pub destination_addr: Option, + /// 插件标识(spec well-known property `plugin_name` / `plugin_root_id` / `plugin_vm_id`)。 + pub plugin_name: String, + pub plugin_root_id: String, + pub plugin_vm_id: String, + /// Wasmtime 资源限制器:控制单 VM 线性内存增长。 + pub resource_limiter: HostResourceLimiter, +} + +impl HostState { + pub fn new(shell_cfg: Arc) -> Self { + let configuration = shell_cfg.configuration_bytes(); + let http_client = reqwest::Client::builder().pool_max_idle_per_host(8).build().unwrap_or_else(|_| reqwest::Client::new()); + let plugin_name = shell_cfg.plugin_name.clone(); + let plugin_root_id = shell_cfg.plugin_root_id.clone(); + let plugin_vm_id = shell_cfg.plugin_vm_id.clone(); + let resource_limiter = HostResourceLimiter::new(&shell_cfg); + Self { + shell_cfg, + configuration, + memory: None, + alloc: None, + root_context_id: ROOT_CONTEXT_ID, + effective_context: ROOT_CONTEXT_ID, + contexts: HashMap::new(), + tick_period_ms: None, + pending_calls: HashMap::new(), + next_token: 1, + http_client, + user_properties: HashMap::new(), + source_addr: None, + destination_addr: None, + plugin_name, + plugin_root_id, + plugin_vm_id, + resource_limiter, + } + } + + /// 取当前生效的 ctx 的不可变引用(host fn 大量使用)。 + pub fn current_context(&self) -> Option<&RequestContext> { + self.contexts.get(&self.effective_context) + } + + /// 取当前生效的 ctx 的可变引用。 + #[allow(dead_code)] + pub fn current_context_mut(&mut self) -> Option<&mut RequestContext> { + self.contexts.get_mut(&self.effective_context) + } + + /// 分配下一个 dispatch_http_call token;约定 0 保留,token 从 1 开始单调递增。 + pub fn next_dispatch_token(&mut self) -> u32 { + let t = self.next_token; + self.next_token = self.next_token.wrapping_add(1).max(1); + t + } +} diff --git a/crates/plugin-wasm/src/lib.rs b/crates/plugin-wasm/src/lib.rs new file mode 100644 index 00000000..7fa11892 --- /dev/null +++ b/crates/plugin-wasm/src/lib.rs @@ -0,0 +1,94 @@ +//! SpaceGate **proxy-wasm (wasmtime) 宿主** crate。 +//! +//! **集成方式**:不要从 `spacegate-plugin` 依赖本 crate(会形成循环依赖),应启用 `spacegate-shell` +//! 的 `plugin-wasm` feature;`spacegate_shell::startup` 会在网关启动时调用 [`register`]。 +//! +//! # 与 [proxy-wasm/spec v0.2.1](https://github.com/proxy-wasm/spec) 的覆盖情况 +//! +//! ## Host functions (env) +//! +//! - **Integration / Memory management**:guest 导出 `_initialize` 优先,否则回退 `_start`; +//! allocator 优先 `proxy_on_memory_allocate`,否则回退 `malloc`。 +//! - **Logging**:`proxy_log` / `proxy_get_log_level` 完整实现(host tracing 级别映射)。 +//! - **Clocks**:`proxy_get_current_time_nanoseconds` + `wasi_snapshot_preview1.clock_time_get`。 +//! - **Timers**:`proxy_set_tick_period_milliseconds` 完整生效;`shell.rs` 为每个 Vm 起一条 50ms 颗粒度的 +//! 后台 tokio 任务,到点 → `Vm::tick()` → guest `proxy_on_tick`。这要求 `Plugin::create` +//! 时存在 tokio runtime(spacegate-shell 的标准启动路径);无 runtime 时降级为不驱动。 +//! - **Randomness**:`wasi_snapshot_preview1.random_get` 走 `getrandom`(OS RNG)。 +//! - **Environment**:`environ_*` 按 spec 全部返回 0/SUCCESS。 +//! - **Buffers**:`proxy_get_buffer_bytes` / `proxy_get_buffer_status` 覆盖 +//! HttpRequestBody / HttpResponseBody / HttpCallResponseBody / Vm/PluginConfiguration; +//! TCP / gRPC / FFI args 类型按 spec 返回 NotFound。 +//! `proxy_set_buffer_bytes` 实现 prepend / append / inject / replace 语义。 +//! - **HTTP fields**:`proxy_get_header_map_size/pairs/value` + add/replace/remove + set_pairs, +//! 覆盖 Request/Response/Trailers + HttpCallResponse Headers/Trailers;GRPC metadata 类型 +//! 按 spec 返回 Unimplemented。 +//! - **HTTP streams**:`proxy_send_local_response` / `proxy_continue_stream` / +//! `proxy_close_stream`(TCP downstream/upstream 按 spec 返回 Unimplemented)。 +//! - **HTTP calls**:`proxy_http_call`(reqwest 异步、`:method`/`:path`/`:authority` 校验, +//! 按 cluster map 或 `:authority` 兜底解析 URL)。 +//! - **Shared K/V**:`proxy_get/set_shared_data` 进程级 RwLock,含 CAS 比对。 +//! - **Shared queues**:`proxy_register/resolve/enqueue/dequeue_shared_queue` 进程级 Mutex VecDeque。 +//! - **Metrics**:`proxy_define/record/increment/get_metric` 进程级 Counter/Gauge/Histogram。 +//! - **Properties**:`proxy_get/set_property` 支持 well-known +//! (`plugin_name`/`plugin_root_id`/`plugin_vm_id`/`source.address`+`source.port`/ +//! `destination.address`+`destination.port`/`request.protocol`/`request.size`/`request.total_size`/ +//! `response.size`/`response.total_size`) 与用户自定义。 +//! - **gRPC**:按 spec 全部 `Unimplemented`。 +//! - **Foreign function**:按 spec `NotFound`(无注册表)。 +//! - **`proxy_done` / `proxy_set_effective_context`**:完整实现。 +//! - **资源隔离**:`limits.max_memory_pages` 通过 Wasmtime `ResourceLimiter` 限制线性内存增长; +//! `limits.fuel_per_call` 和 `limits.epoch_timeout_millis` 在每次 guest hook 前重置执行预算; +//! `limits.max_body_bytes` 限制 host 物化 request/response/dispatch body 的大小; +//! `limits.max_pending_calls` 限制单 VM 未完成 `proxy_http_call` 数。 +//! +//! ## Guest callbacks driven by host +//! +//! - 启动:每个 Vm 执行 `_initialize`/`_start` → `proxy_on_context_create(root,0)` → +//! `proxy_on_vm_start` → `proxy_on_configure`(由 `WasmPluginShell::create` 按 `vm_pool_size` 执行)。 +//! - 每请求:`proxy_on_context_create(http_id, root)` → `proxy_on_request_headers` → +//! (可选)`proxy_on_request_body` → (可选)`proxy_on_request_trailers` → +//! `inner.call` → `proxy_on_response_headers` → (可选)`proxy_on_response_body` → +//! (可选)`proxy_on_response_trailers` → `proxy_on_log` → `proxy_on_done` → `proxy_on_delete`。 +//! `WasmPluginShell` 默认持有 1 个 `Arc>`;配置 `vm_pool_size > 1` +//! 后会创建多个独立 root Vm,并通过 try-lock + round-robin 调度请求。 +//! 配置 `wait_vm_pool_size > 0` 后,带 `X-RateLimit-Policy: wait` 的请求会进入单独 wait VM 池, +//! 其余 `abandon`/`queue`/未标记请求仍进入普通 VM 池,避免长等待请求拖住普通限流路径。 +//! 每个 VM slot 会记录并输出 inflight tracing 字段;guest trap / 资源隔离错误 / dispatch 通道异常后, +//! shell 会在原 slot 内尝试重建 VM,避免异常 Store 长期留在池内。 +//! - 后台 `proxy_on_tick`:`shell.rs` 为每个 Vm 起 50ms 颗粒度的 tokio 任务驱动;guest 通过 +//! `proxy_set_tick_period_milliseconds` 改周期。 +//! - 异步 `proxy_on_http_call_response`:在 Pause 状态机里 await `dispatch_rx` 后回调。 +//! - Pause/Continue:`proxy_continue_stream` 同步解除 Pause;多次 dispatch 可串联。 +//! - Local response:`proxy_send_local_response` 任意 hook 都能短路。 +//! +//! ## 已知尚未驱动的回调(按设计取舍) +//! +//! - TCP 流回调(`proxy_on_new_connection`/`*_downstream_*`/`*_upstream_*`): +//! spacegate-kernel 当前是 HTTP-only,TCP 插件层不支持。 +//! - `proxy_on_queue_ready` / `proxy_on_grpc_*` / `proxy_on_foreign_function`: +//! 对应 host fn 已为 spec 合规返回值;guest 侧回调不会被触发。 + +#![deny(clippy::unwrap_used, clippy::dbg_macro)] + +pub mod abi; +pub mod config; +pub mod engine; +pub mod error; +pub mod fetch; +pub mod host_fn; +pub mod host_state; +pub mod runtime; +pub mod shared; +pub mod shell; +pub mod vm; + +pub use config::WasmPluginShellConfig; +pub use shell::WasmPluginShell; + +use spacegate_plugin::PluginRepository; + +/// 向仓库注册 `wasm` 插件类型(需在 `register_prelude` 或启动逻辑中调用一次)。 +pub fn register(repo: &PluginRepository) { + repo.register::(); +} diff --git a/crates/plugin-wasm/src/runtime.rs b/crates/plugin-wasm/src/runtime.rs new file mode 100644 index 00000000..4b828412 --- /dev/null +++ b/crates/plugin-wasm/src/runtime.rs @@ -0,0 +1,77 @@ +//! WASM 模块编译与按 URL 缓存(减少同一 `url` 重复编译)。 + +use std::sync::Arc; + +use moka::sync::Cache; +use once_cell::sync::OnceCell; +use sha2::{Digest, Sha256}; +use wasmtime::Module; + +use crate::config::WasmPluginShellConfig; +use crate::engine::shared_engine; +use crate::error::WasmHostError; +use crate::fetch::fetch_wasm_bytes_sync_with_auth; + +/// 进程内模块缓存(键:wasm `url` 字符串)。 +pub struct WasmModuleCache { + engine: &'static wasmtime::Engine, + inner: Cache>, +} + +impl WasmModuleCache { + pub fn new(max_entries: u64) -> Self { + Self { + engine: shared_engine(), + inner: Cache::new(max_entries), + } + } + + /// 拉取字节并编译;命中缓存则直接返回 `Arc`。 + pub fn get_or_compile(&self, cfg: &WasmPluginShellConfig) -> Result, WasmHostError> { + let key = module_cache_key(cfg); + if cfg.use_cache { + if let Some(m) = self.inner.get(&key) { + return Ok(m); + } + } + let bytes = fetch_wasm_bytes_sync_with_auth(cfg.url.trim(), cfg.oci_auth.as_ref())?; + verify_sha256(&bytes, cfg.sha256.as_deref())?; + let m = Arc::new(Module::new(self.engine, &bytes)?); + if cfg.use_cache { + self.inner.insert(key, m.clone()); + } + Ok(m) + } +} + +fn module_cache_key(cfg: &WasmPluginShellConfig) -> String { + let mut key = cfg.module_cache_key.as_deref().filter(|s| !s.trim().is_empty()).unwrap_or_else(|| cfg.url.trim()).to_string(); + if let Some(sha256) = cfg.sha256.as_deref().filter(|s| !s.trim().is_empty()) { + key.push_str("#sha256="); + key.push_str(normalize_sha256(sha256)); + } + key +} + +fn normalize_sha256(s: &str) -> &str { + s.trim().strip_prefix("sha256:").unwrap_or_else(|| s.trim()) +} + +fn verify_sha256(bytes: &[u8], expected: Option<&str>) -> Result<(), WasmHostError> { + let Some(expected) = expected.map(normalize_sha256).filter(|s| !s.is_empty()) else { + return Ok(()); + }; + let actual = format!("{:x}", Sha256::digest(bytes)); + if actual.eq_ignore_ascii_case(expected) { + Ok(()) + } else { + Err(WasmHostError::Fetch(format!("sha256 mismatch: expected {expected}, actual {actual}",))) + } +} + +static CACHE: OnceCell = OnceCell::new(); + +/// 默认缓存(容量 64);多实例同 URL 共享编译结果。 +pub fn default_module_cache() -> &'static WasmModuleCache { + CACHE.get_or_init(|| WasmModuleCache::new(64)) +} diff --git a/crates/plugin-wasm/src/shared.rs b/crates/plugin-wasm/src/shared.rs new file mode 100644 index 00000000..c350a136 --- /dev/null +++ b/crates/plugin-wasm/src/shared.rs @@ -0,0 +1,292 @@ +//! 进程级共享状态:spec §Shared Key-Value Store / §Shared Queues / §Metrics。 +//! +//! 这些设施按 proxy-wasm spec 必须在多个 VM / plugin 实例之间共享,因此放在进程级 `OnceCell` +//! + `RwLock` 之后;不依赖具体 `HostState`。 +//! +//! 实现要点: +//! - **Shared Data**:键值 + CAS(compare-and-swap)。每次成功 set 都使 cas 自增; +//! guest 传 `cas=0` 表示不校验。 +//! - **Shared Queues**:通过 `register_shared_queue(name)` / `resolve_shared_queue(vm_id, name)` 拿 qid; +//! `enqueue`/`dequeue` 操作 `VecDeque>`。 +//! - **Metrics**:Counter / Gauge / Histogram。Counter 不允许 decrement;Histogram 这里按 Gauge 处理 +//! (proxy-wasm 0.2.1 没有规定 histogram 的内部表示),足以满足 guest 的调用语义。 + +use std::collections::{HashMap, VecDeque}; +use std::sync::{Mutex, RwLock}; + +use once_cell::sync::Lazy; +use opentelemetry::global; + +use crate::abi::MetricType; + +// ───────────────────────────────────────────────────────── +// Shared Data(spec §Shared Key-Value Store) +// ───────────────────────────────────────────────────────── + +#[derive(Debug, Default, Clone)] +pub struct SharedDataEntry { + pub value: Vec, + pub cas: u32, +} + +#[derive(Debug, Default)] +struct SharedDataStore { + map: HashMap, SharedDataEntry>, +} + +static SHARED_DATA: Lazy> = Lazy::new(|| RwLock::new(SharedDataStore::default())); + +/// 读:返回 (value, cas);不存在返回 `None`。 +pub fn shared_data_get(key: &[u8]) -> Option<(Vec, u32)> { + let g = SHARED_DATA.read().ok()?; + g.map.get(key).map(|e| (e.value.clone(), e.cas)) +} + +#[derive(Debug, PartialEq, Eq)] +pub enum SharedDataSetResult { + Ok, + CasMismatch, +} + +/// 写:cas==0 表示不校验;非 0 必须等于当前 cas 才能成功。 +pub fn shared_data_set(key: &[u8], value: &[u8], cas: u32) -> SharedDataSetResult { + let Ok(mut g) = SHARED_DATA.write() else { + return SharedDataSetResult::CasMismatch; + }; + let entry = g.map.entry(key.to_vec()).or_default(); + if cas != 0 && cas != entry.cas { + return SharedDataSetResult::CasMismatch; + } + entry.value = value.to_vec(); + entry.cas = entry.cas.wrapping_add(1).max(1); + SharedDataSetResult::Ok +} + +// ───────────────────────────────────────────────────────── +// Shared Queues(spec §Shared Queues) +// ───────────────────────────────────────────────────────── + +#[derive(Debug, Default)] +struct SharedQueueRegistry { + by_id: HashMap>>, + by_name: HashMap<(String, String), u32>, // (vm_id, name) -> qid + next_id: u32, +} + +static SHARED_QUEUES: Lazy> = Lazy::new(|| Mutex::new(SharedQueueRegistry::default())); + +/// 注册(或打开已存在)一个共享队列;返回 qid。 +/// +/// `vm_id` 取本 VM 的 plugin_vm_id(按 spec 是 host 实现细节;这里用 "default")。 +pub fn queue_register(vm_id: &str, name: &str) -> u32 { + let mut g = match SHARED_QUEUES.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + let key = (vm_id.to_string(), name.to_string()); + if let Some(qid) = g.by_name.get(&key).copied() { + return qid; + } + g.next_id = g.next_id.wrapping_add(1).max(1); + let qid = g.next_id; + g.by_id.insert(qid, VecDeque::new()); + g.by_name.insert(key, qid); + qid +} + +/// 解析已存在的队列;不存在返回 None。 +pub fn queue_resolve(vm_id: &str, name: &str) -> Option { + let g = SHARED_QUEUES.lock().ok()?; + g.by_name.get(&(vm_id.to_string(), name.to_string())).copied() +} + +#[derive(Debug, PartialEq, Eq)] +pub enum QueueOpResult { + Ok, + NotFound, + Empty, +} + +pub fn queue_enqueue(qid: u32, value: &[u8]) -> QueueOpResult { + let mut g = match SHARED_QUEUES.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + match g.by_id.get_mut(&qid) { + Some(q) => { + q.push_back(value.to_vec()); + QueueOpResult::Ok + } + None => QueueOpResult::NotFound, + } +} + +pub fn queue_dequeue(qid: u32) -> (QueueOpResult, Option>) { + let mut g = match SHARED_QUEUES.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + match g.by_id.get_mut(&qid) { + Some(q) => match q.pop_front() { + Some(v) => (QueueOpResult::Ok, Some(v)), + None => (QueueOpResult::Empty, None), + }, + None => (QueueOpResult::NotFound, None), + } +} + +// ───────────────────────────────────────────────────────── +// Metrics(spec §Metrics) +// ───────────────────────────────────────────────────────── + +#[derive(Debug)] +struct MetricEntry { + kind: MetricType, + value: u64, + instrument: OtelMetricInstrument, +} + +#[derive(Debug)] +enum OtelMetricInstrument { + Counter(opentelemetry::metrics::Counter), + Gauge(opentelemetry::metrics::Gauge), + Histogram(opentelemetry::metrics::Histogram), +} + +#[derive(Debug, Default)] +struct MetricRegistry { + by_id: HashMap, + by_name: HashMap, + next_id: u32, +} + +static METRICS: Lazy> = Lazy::new(|| Mutex::new(MetricRegistry::default())); + +#[derive(Debug, PartialEq, Eq)] +pub enum MetricOpResult { + Ok, + NotFound, + BadArgument, +} + +pub fn metric_define(kind: MetricType, name: &str) -> u32 { + let mut g = match METRICS.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + if let Some(id) = g.by_name.get(name).copied() { + return id; + } + g.next_id = g.next_id.wrapping_add(1).max(1); + let id = g.next_id; + let meter = global::meter("spacegate_plugin_wasm"); + let instrument = match kind { + MetricType::Counter => OtelMetricInstrument::Counter(meter.u64_counter(name.to_string()).build()), + MetricType::Gauge => OtelMetricInstrument::Gauge(meter.i64_gauge(name.to_string()).build()), + MetricType::Histogram => OtelMetricInstrument::Histogram(meter.u64_histogram(name.to_string()).build()), + }; + g.by_id.insert(id, MetricEntry { kind, value: 0, instrument }); + g.by_name.insert(name.to_string(), id); + id +} + +pub fn metric_record(id: u32, value: u64) -> MetricOpResult { + let mut g = match METRICS.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + match g.by_id.get_mut(&id) { + Some(m) => { + m.value = value; + match &m.instrument { + OtelMetricInstrument::Counter(counter) => counter.add(value, &[]), + OtelMetricInstrument::Gauge(gauge) => gauge.record(value as i64, &[]), + OtelMetricInstrument::Histogram(histogram) => histogram.record(value, &[]), + } + MetricOpResult::Ok + } + None => MetricOpResult::NotFound, + } +} + +pub fn metric_increment(id: u32, delta: i64) -> MetricOpResult { + let mut g = match METRICS.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + let Some(m) = g.by_id.get_mut(&id) else { + return MetricOpResult::NotFound; + }; + if matches!(m.kind, MetricType::Counter) && delta < 0 { + return MetricOpResult::BadArgument; + } + if delta >= 0 { + m.value = m.value.saturating_add(delta as u64); + } else { + m.value = m.value.saturating_sub((-delta) as u64); + } + match &m.instrument { + OtelMetricInstrument::Counter(counter) => counter.add(delta.max(0) as u64, &[]), + OtelMetricInstrument::Gauge(gauge) => gauge.record(m.value as i64, &[]), + OtelMetricInstrument::Histogram(histogram) => histogram.record(m.value, &[]), + } + MetricOpResult::Ok +} + +pub fn metric_get(id: u32) -> Option { + let g = METRICS.lock().ok()?; + g.by_id.get(&id).map(|m| m.value) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn shared_data_cas_roundtrip() { + let key = b"shared_data_cas_roundtrip_key"; + assert_eq!(shared_data_set(key, b"v1", 0), SharedDataSetResult::Ok); + let (v, cas1) = shared_data_get(key).unwrap(); + assert_eq!(v, b"v1"); + assert!(cas1 > 0); + assert_eq!(shared_data_set(key, b"v2", 99), SharedDataSetResult::CasMismatch); + assert_eq!(shared_data_set(key, b"v2", cas1), SharedDataSetResult::Ok); + let (v, cas2) = shared_data_get(key).unwrap(); + assert_eq!(v, b"v2"); + assert!(cas2 > cas1); + } + + #[test] + fn shared_queue_roundtrip() { + let qid = queue_register("default", "shared_queue_roundtrip_q"); + assert_eq!(queue_enqueue(qid, b"a"), QueueOpResult::Ok); + assert_eq!(queue_enqueue(qid, b"b"), QueueOpResult::Ok); + let (s, v) = queue_dequeue(qid); + assert_eq!(s, QueueOpResult::Ok); + assert_eq!(v.as_deref(), Some(b"a".as_slice())); + let (s, v) = queue_dequeue(qid); + assert_eq!(s, QueueOpResult::Ok); + assert_eq!(v.as_deref(), Some(b"b".as_slice())); + let (s, _) = queue_dequeue(qid); + assert_eq!(s, QueueOpResult::Empty); + } + + #[test] + fn metric_counter_increment_only() { + let id = metric_define(MetricType::Counter, "metric_counter_increment_only"); + assert_eq!(metric_increment(id, 3), MetricOpResult::Ok); + assert_eq!(metric_get(id), Some(3)); + assert_eq!(metric_increment(id, -1), MetricOpResult::BadArgument); + assert_eq!(metric_get(id), Some(3)); + } + + #[test] + fn metric_gauge_bidirectional() { + let id = metric_define(MetricType::Gauge, "metric_gauge_bidirectional"); + assert_eq!(metric_increment(id, 5), MetricOpResult::Ok); + assert_eq!(metric_increment(id, -2), MetricOpResult::Ok); + assert_eq!(metric_get(id), Some(3)); + assert_eq!(metric_record(id, 100), MetricOpResult::Ok); + assert_eq!(metric_get(id), Some(100)); + } +} diff --git a/crates/plugin-wasm/src/shell.rs b/crates/plugin-wasm/src/shell.rs new file mode 100644 index 00000000..228e6038 --- /dev/null +++ b/crates/plugin-wasm/src/shell.rs @@ -0,0 +1,282 @@ +//! `Plugin` 实现:实例化一个或多个长生命 Vm,后续请求复用,并为每个 Vm 起一条后台 tick 任务驱动 `proxy_on_tick`。 +//! +//! 与「每请求新建 Vm」相比的取舍: +//! +//! - 优点:guest 的 root context 可保留状态;`proxy_on_tick` 可真正按 `proxy_set_tick_period_milliseconds` 周期触发; +//! `on_vm_start` / `on_configure` 仅跑一次,热路径少几毫秒。 +//! - 单个 `Vm` 内仍通过 `tokio::sync::Mutex` 串行化处理(wasmtime `Store` 是 !Sync); +//! 配置 `vm_pool_size > 1` 时,通过多个独立 `Store + Instance` 提供插件实例内并发。 +//! - 配置 `wait_vm_pool_size > 0` 时,`X-RateLimit-Policy: wait` 请求会进入独立 wait 池; +//! 其他请求继续走普通池。 + +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::time::{Duration, Instant}; + +use spacegate_kernel::{SgBody, SgRequest, SgResponse}; +use spacegate_plugin::{BoxError, Inner, Plugin, PluginConfig}; +use tokio::sync::{Mutex as AsyncMutex, MutexGuard}; + +use crate::config::{FailStrategy, WasmPluginShellConfig}; +use crate::runtime::default_module_cache; +use crate::vm::Vm; + +/// Drop 时 abort 关联的 tokio 任务;保证后台 tick 不会在 shell 析构后继续持有 Vm 引用。 +struct AbortOnDrop(tokio::task::JoinHandle<()>); +impl Drop for AbortOnDrop { + fn drop(&mut self) { + self.0.abort(); + } +} + +#[derive(Clone)] +struct VmSlot { + vm: Arc>, + inflight: Arc, +} + +impl VmSlot { + fn new(vm: Vm) -> Self { + Self { + vm: Arc::new(AsyncMutex::new(vm)), + inflight: Arc::new(AtomicUsize::new(0)), + } + } +} + +struct InflightGuard { + inflight: Arc, + pool_name: &'static str, + vm_index: usize, +} + +impl InflightGuard { + fn new(slot: &VmSlot, pool_name: &'static str, vm_index: usize) -> Self { + let current = slot.inflight.fetch_add(1, Ordering::AcqRel) + 1; + tracing::debug!(target: "spacegate_plugin_wasm", vm_pool = pool_name, vm_index, inflight = current, "VM inflight incremented"); + Self { + inflight: slot.inflight.clone(), + pool_name, + vm_index, + } + } +} + +impl Drop for InflightGuard { + fn drop(&mut self) { + let current = self.inflight.fetch_sub(1, Ordering::AcqRel).saturating_sub(1); + tracing::debug!(target: "spacegate_plugin_wasm", vm_pool = self.pool_name, vm_index = self.vm_index, inflight = current, "VM inflight decremented"); + } +} + +/// Proxy-Wasm 宿主壳插件(`CODE = "wasm"`)。 +pub struct WasmPluginShell { + cfg: Arc, + #[allow(dead_code)] + module: Arc, + vms: Vec, + wait_vms: Vec, + next_vm: AtomicUsize, + next_wait_vm: AtomicUsize, + /// 后台 tick 任务句柄;shell drop 时自动 abort。 + /// `None` 表示创建时没有 tokio runtime 上下文(非测试常见路径),tick 退化为不驱动。 + _tick_tasks: Vec, +} + +impl Plugin for WasmPluginShell { + const CODE: &'static str = "wasm"; + + fn call(&self, req: SgRequest, inner: Inner) -> impl std::future::Future> + Send { + let cfg = self.cfg.clone(); + let use_wait_pool = is_wait_policy(&req) && !self.wait_vms.is_empty(); + let pool_name = if use_wait_pool { "wait" } else { "normal" }; + let slots = if use_wait_pool { self.wait_vms.clone() } else { self.vms.clone() }; + let module = self.module.clone(); + let start_index = if use_wait_pool { + self.next_wait_vm.fetch_add(1, Ordering::Relaxed) + } else { + self.next_vm.fetch_add(1, Ordering::Relaxed) + }; + async move { + tracing::info!( + target: "spacegate_plugin_wasm", + method = %req.method(), + uri = %req.uri(), + vm_pool = pool_name, + "wasm plugin shell: request entered plugin layer" + ); + + if slots.is_empty() { + let mut resp = SgResponse::new(SgBody::full(format!("wasm plugin error: empty {pool_name} VM pool"))); + *resp.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; + return Ok(resp); + } + + for offset in 0..slots.len() { + let index = start_index.wrapping_add(offset) % slots.len(); + if let Ok(guard) = slots[index].vm.try_lock() { + let _inflight = InflightGuard::new(&slots[index], pool_name, index); + return process_with_vm(module, cfg, req, inner, guard, pool_name, index).await; + } + } + + let index = start_index % slots.len(); + let _inflight = InflightGuard::new(&slots[index], pool_name, index); + let guard = slots[index].vm.lock().await; + process_with_vm(module, cfg, req, inner, guard, pool_name, index).await + } + } + + fn create(plugin_config: PluginConfig) -> Result { + let raw_spec = plugin_config.spec.clone(); + let cfg: WasmPluginShellConfig = serde_json::from_value(plugin_config.spec).map_err(|e| -> BoxError { format!("wasm spec: {e}").into() })?; + if cfg.url.trim().is_empty() { + return Err("wasm plugin: missing or empty `url`".into()); + } + tracing::info!( + target: "spacegate_plugin_wasm", + url = %cfg.url, + plugin_config_kind = %if cfg.plugin_config.is_null() { "null" } else { "object" }, + plugin_config_keys = ?cfg.plugin_config.as_object().map(|o| o.keys().collect::>()), + clusters = ?cfg.clusters.keys().collect::>(), + raw_keys = ?raw_spec.as_object().map(|o| o.keys().collect::>()), + "wasm plugin: create with config" + ); + let cache = default_module_cache(); + let module = cache.get_or_compile(&cfg).map_err(|e| -> BoxError { format!("compile wasm: {e}").into() })?; + let cfg = Arc::new(cfg); + let pool_size = cfg.normalized_vm_pool_size(); + let wait_pool_size = cfg.normalized_wait_vm_pool_size(); + let mut vms = Vec::with_capacity(pool_size); + let mut wait_vms = Vec::with_capacity(wait_pool_size); + let mut tick_tasks = Vec::with_capacity(pool_size + wait_pool_size); + for index in 0..pool_size { + let vm = Vm::new(&module, cfg.clone()).map_err(|e| -> BoxError { format!("Vm::new[{index}]: {e}").into() })?; + let slot = VmSlot::new(vm); + if let Some(task) = spawn_tick_loop("normal", index, &slot.vm) { + tick_tasks.push(task); + } + vms.push(slot); + } + for index in 0..wait_pool_size { + let vm = Vm::new(&module, cfg.clone()).map_err(|e| -> BoxError { format!("Vm::new[wait:{index}]: {e}").into() })?; + let slot = VmSlot::new(vm); + if let Some(task) = spawn_tick_loop("wait", index, &slot.vm) { + tick_tasks.push(task); + } + wait_vms.push(slot); + } + tracing::info!( + target: "spacegate_plugin_wasm", + pool_size, + wait_pool_size, + "wasm plugin: VM pools created" + ); + Ok(Self { + cfg, + module, + vms, + wait_vms, + next_vm: AtomicUsize::new(0), + next_wait_vm: AtomicUsize::new(0), + _tick_tasks: tick_tasks, + }) + } +} + +fn is_wait_policy(req: &SgRequest) -> bool { + req.headers().get("x-ratelimit-policy").and_then(|value| value.to_str().ok()).map(|value| value.trim().eq_ignore_ascii_case("wait")).unwrap_or(false) +} + +async fn process_with_vm( + module: Arc, + cfg: Arc, + req: SgRequest, + inner: Inner, + mut guard: MutexGuard<'_, Vm>, + pool_name: &'static str, + vm_index: usize, +) -> Result { + match guard.process(req, inner).await { + Ok(resp) => { + tracing::info!(target: "spacegate_plugin_wasm", vm_pool = pool_name, vm_index, status = %resp.status(), "Vm::process ok"); + Ok(resp) + } + Err(e) => { + tracing::error!(target: "spacegate_plugin_wasm", vm_pool = pool_name, vm_index, error = %e, "wasm plugin failed"); + if e.requires_vm_rebuild() { + match Vm::new(&module, cfg.clone()) { + Ok(new_vm) => { + *guard = new_vm; + tracing::warn!(target: "spacegate_plugin_wasm", vm_pool = pool_name, vm_index, "VM rebuilt after abnormal failure"); + } + Err(rebuild_err) => { + tracing::error!( + target: "spacegate_plugin_wasm", + vm_pool = pool_name, + vm_index, + error = %rebuild_err, + "VM rebuild failed after abnormal failure" + ); + } + } + } + let status = if matches!(cfg.fail_strategy, FailStrategy::FailOpen) { + http::StatusCode::BAD_GATEWAY + } else { + http::StatusCode::INTERNAL_SERVER_ERROR + }; + let mut resp = SgResponse::new(SgBody::full(format!("wasm plugin error: {e}"))); + *resp.status_mut() = status; + Ok(resp) + } + } +} + +/// 起一条 50ms 粒度的轮询任务:每个 tick 看一眼 `Vm::tick_period_ms()`,到点了就 `Vm::tick()`。 +/// +/// - 粒度 50ms 是工程取舍:spec 没有规定 tick 必须精确,envoy 也是大颗粒度;如果 guest 设置 < 50ms 的周期, +/// 实际触发率会被压到 50ms 一次——记入 `lib.rs` 顶部已知限制。 +/// - 任务持有 `Arc>`,shell drop 时 `AbortOnDrop` 立刻 abort,不存在悬挂任务。 +/// - 若 `proxy_on_tick` trap,记 error 后退出循环(防止热循环 panic)。 +fn spawn_tick_loop(pool_name: &'static str, vm_index: usize, vm: &Arc>) -> Option { + let handle = tokio::runtime::Handle::try_current().ok()?; + let vm = vm.clone(); + let task = handle.spawn(async move { + const POLL_GRANULARITY: Duration = Duration::from_millis(50); + let mut interval = tokio::time::interval(POLL_GRANULARITY); + // 首次 tick 立刻就绪——跳过它,避免一启动就触发 on_tick。 + interval.tick().await; + let mut last_tick: Option = None; + loop { + interval.tick().await; + let mut guard = vm.lock().await; + let period = guard.tick_period_ms(); + if period == 0 { + last_tick = None; + continue; + } + let due = match last_tick { + Some(t) => t.elapsed().as_millis() as u64 >= period as u64, + None => true, + }; + if !due { + continue; + } + if let Err(e) = guard.tick() { + tracing::error!( + target: "spacegate_plugin_wasm", + vm_pool = pool_name, + vm_index, + error = %e, + "proxy_on_tick failed; stopping tick task" + ); + return; + } + last_tick = Some(Instant::now()); + } + }); + Some(AbortOnDrop(task)) +} diff --git a/crates/plugin-wasm/src/vm.rs b/crates/plugin-wasm/src/vm.rs new file mode 100644 index 00000000..02aba8a9 --- /dev/null +++ b/crates/plugin-wasm/src/vm.rs @@ -0,0 +1,734 @@ +//! `Vm`:单个 wasm 实例 + per-request 驱动状态机。 +//! +//! 一个 `Vm` 包含: +//! +//! - `wasmtime::Store`:宿主状态 + linear memory +//! - `wasmtime::Instance`:实例化后的 wasm guest +//! - 缓存的 guest exports(避免每次按名查找) +//! +//! 关键流程([`Vm::process`]): +//! +//! 1. `proxy_on_context_create(http_ctx_id, root_id)` +//! 2. `proxy_on_request_headers` → 解析 Action(end_of_stream=false 当有 body 时) +//! 3. 若 Pause → `drive_until_continue`:循环 await dispatch_http_call 结果 → +//! 调 `proxy_on_http_call_response` → 直至 guest `continue_stream(HTTP_REQUEST)` 或 `send_local_response` +//! 4. 若 guest 导出 `proxy_on_request_body`:收齐 body,调 hook,可能再次 Pause;body 写回 SgRequest +//! 5. 若 guest 导出 `proxy_on_request_trailers`:用空 trailer map 调一次(spacegate 暂不暴露 trailer) +//! 6. 若 guest 写了 `local_response`:直接返回它(短路 inner.call) +//! 7. 否则:把 ctx 内的 headers/body 同步回 `SgRequest`,调 `inner.call` +//! 8. `proxy_on_response_headers` / `proxy_on_response_body` / `proxy_on_response_trailers` / `proxy_on_log` / +//! `proxy_on_done` (spec 要求 false 时等 `proxy_done`) / `proxy_on_delete` +//! 9. ctx 清理 + +use std::sync::Arc; + +use bytes::Bytes; +use http::HeaderMap; +use http_body_util::{BodyExt, Limited}; +use spacegate_kernel::{SgBody, SgRequest, SgResponse}; +use tracing::{debug, info, warn}; +use wasmtime::{AsContext, AsContextMut, Instance, Linker, Store, TypedFunc}; + +use crate::abi::{Action, MemoryHelper}; +use crate::config::{FailStrategy, WasmPluginShellConfig}; +use crate::engine::{ensure_epoch_ticker_started, shared_engine}; +use crate::error::WasmHostError; +use crate::host_fn::register_all; +use crate::host_state::{ContextStage, HostState, HttpCallResult, PseudoHeaders, RequestContext}; + +/// 长生命 Vm:插件 `create` 时实例化一次,之后被多次请求复用,再加一条 +/// 后台 tick 任务用来驱动 `proxy_on_tick`。`store` !Sync,所以共享时必须 +/// 套 `tokio::sync::Mutex`(见 `shell.rs`)。 +pub struct Vm { + store: Store, + #[allow(dead_code)] + instance: Instance, + root_id: u32, + next_ctx_id: u32, + dispatch_rx: tokio::sync::mpsc::UnboundedReceiver<(u32, HttpCallResult)>, + fail_strategy: FailStrategy, + fn_on_context_create: TypedFunc<(u32, u32), ()>, + fn_on_vm_start: Option>, + fn_on_configure: TypedFunc<(u32, u32), u32>, + fn_on_request_headers: TypedFunc<(u32, u32, u32), u32>, + fn_on_request_body: Option>, + fn_on_request_trailers: Option>, + fn_on_response_headers: TypedFunc<(u32, u32, u32), u32>, + fn_on_response_body: Option>, + fn_on_response_trailers: Option>, + fn_on_http_call_response: TypedFunc<(u32, u32, u32, u32, u32), ()>, + fn_on_log: Option>, + fn_on_done: Option>, + fn_on_delete: Option>, + fn_on_tick: Option>, +} + +impl Vm { + /// 创建并启动一个 Vm:实例化 → 缓存 exports → 跑 vm_start/configure。 + /// + /// 这是同步函数。整个过程不涉及 `await`(wasmtime 编译/实例化、guest `_initialize`、 + /// `on_vm_start` / `on_configure` 全部是同步调用),所以 `WasmPluginShell::create` + /// 这种 sync 上下文也能直接构造。 + pub fn new(module: &wasmtime::Module, shell_cfg: Arc) -> Result { + ensure_epoch_ticker_started(); + let engine = shared_engine(); + let host = HostState::new(shell_cfg.clone()); + let mut store: Store = Store::new(engine, host); + store.limiter(|state| &mut state.resource_limiter); + prepare_store_for_guest_call(&mut store)?; + let mut linker: Linker = Linker::new(engine); + let (dispatch_tx, dispatch_rx) = tokio::sync::mpsc::unbounded_channel::<(u32, HttpCallResult)>(); + register_all(&mut linker, dispatch_tx).map_err(|e| WasmHostError::Instantiate(format!("register host fn: {e}")))?; + + register_wasi_stubs(&mut linker)?; + + let instance = linker.instantiate(&mut store, module).map_err(|e| WasmHostError::Instantiate(format!("instantiate: {e}")))?; + + let memory = instance.get_memory(&mut store, "memory").ok_or_else(|| WasmHostError::AbiViolation("no `memory` export".into()))?; + if let Some(max_pages) = shell_cfg.limits.max_memory_pages { + let current_pages = memory.size(&store) as u32; + if current_pages > max_pages { + return Err(WasmHostError::ResourceLimit(format!( + "initial memory pages {current_pages} exceeds max_memory_pages {max_pages}" + ))); + } + } + store.data_mut().memory = Some(memory); + // spec §Memory management:优先 `proxy_on_memory_allocate`,否则回退 `malloc`。 + if let Ok(alloc) = instance.get_typed_func::(&mut store, "proxy_on_memory_allocate") { + store.data_mut().alloc = Some(alloc); + } else if let Ok(alloc) = instance.get_typed_func::(&mut store, "malloc") { + store.data_mut().alloc = Some(alloc); + } else { + return Err(WasmHostError::AbiViolation("no memory allocator export (proxy_on_memory_allocate or malloc)".into())); + } + + // spec §Integration:先 `_initialize`;若不存在尝试 `_start`。 + if let Ok(init) = instance.get_typed_func::<(), ()>(&mut store, "_initialize") { + prepare_store_for_guest_call(&mut store)?; + init.call(&mut store, ()).map_err(|e| WasmHostError::Instantiate(format!("_initialize: {e}")))?; + } else if let Ok(start) = instance.get_typed_func::<(), ()>(&mut store, "_start") { + prepare_store_for_guest_call(&mut store)?; + start.call(&mut store, ()).map_err(|e| WasmHostError::Instantiate(format!("_start: {e}")))?; + } + + let fn_on_context_create = instance + .get_typed_func::<(u32, u32), ()>(&mut store, "proxy_on_context_create") + .map_err(|e| WasmHostError::AbiViolation(format!("get proxy_on_context_create: {e}")))?; + let fn_on_vm_start = instance.get_typed_func::<(u32, u32), u32>(&mut store, "proxy_on_vm_start").ok(); + let fn_on_configure = + instance.get_typed_func::<(u32, u32), u32>(&mut store, "proxy_on_configure").map_err(|e| WasmHostError::AbiViolation(format!("get proxy_on_configure: {e}")))?; + let fn_on_request_headers = instance + .get_typed_func::<(u32, u32, u32), u32>(&mut store, "proxy_on_request_headers") + .map_err(|e| WasmHostError::AbiViolation(format!("get proxy_on_request_headers: {e}")))?; + let fn_on_request_body = instance.get_typed_func::<(u32, u32, u32), u32>(&mut store, "proxy_on_request_body").ok(); + let fn_on_request_trailers = instance.get_typed_func::<(u32, u32), u32>(&mut store, "proxy_on_request_trailers").ok(); + let fn_on_response_headers = instance + .get_typed_func::<(u32, u32, u32), u32>(&mut store, "proxy_on_response_headers") + .map_err(|e| WasmHostError::AbiViolation(format!("get proxy_on_response_headers: {e}")))?; + let fn_on_response_body = instance.get_typed_func::<(u32, u32, u32), u32>(&mut store, "proxy_on_response_body").ok(); + let fn_on_response_trailers = instance.get_typed_func::<(u32, u32), u32>(&mut store, "proxy_on_response_trailers").ok(); + let fn_on_http_call_response = instance + .get_typed_func::<(u32, u32, u32, u32, u32), ()>(&mut store, "proxy_on_http_call_response") + .map_err(|e| WasmHostError::AbiViolation(format!("get proxy_on_http_call_response: {e}")))?; + let fn_on_log = instance.get_typed_func::(&mut store, "proxy_on_log").ok(); + let fn_on_done = instance.get_typed_func::(&mut store, "proxy_on_done").ok(); + let fn_on_delete = instance.get_typed_func::(&mut store, "proxy_on_delete").ok(); + let fn_on_tick = instance.get_typed_func::(&mut store, "proxy_on_tick").ok(); + + let root_id = store.data().root_context_id; + let next_ctx_id = root_id + 1; + let fail_strategy = shell_cfg.fail_strategy; + + let mut vm = Self { + store, + instance, + root_id, + next_ctx_id, + dispatch_rx, + fail_strategy, + fn_on_context_create, + fn_on_vm_start, + fn_on_configure, + fn_on_request_headers, + fn_on_request_body, + fn_on_request_trailers, + fn_on_response_headers, + fn_on_response_body, + fn_on_response_trailers, + fn_on_http_call_response, + fn_on_log, + fn_on_done, + fn_on_delete, + fn_on_tick, + }; + + // 启动序:on_context_create(root, 0) → on_vm_start → on_configure + vm.store.data_mut().contexts.insert(root_id, RequestContext::default()); + vm.create_context(root_id, 0)?; + if let Some(ref f) = vm.fn_on_vm_start { + vm.store.data_mut().effective_context = root_id; + let cfg_len = vm.store.data().configuration.len() as u32; + prepare_store_for_guest_call(&mut vm.store)?; + let ok = f.call(&mut vm.store, (root_id, cfg_len)).map_err(|e| WasmHostError::GuestTrap { hook: "on_vm_start", source: e })?; + if ok == 0 { + return Err(WasmHostError::Instantiate("guest on_vm_start returned 0 (=invalid VM configuration)".into())); + } + } + vm.store.data_mut().effective_context = root_id; + let cfg_len = vm.store.data().configuration.len() as u32; + tracing::info!(target: "spacegate_plugin_wasm", cfg_len, "calling proxy_on_configure"); + let configure_fn = vm.fn_on_configure.clone(); + prepare_store_for_guest_call(&mut vm.store)?; + let ok = configure_fn.call(&mut vm.store, (root_id, cfg_len)).map_err(|e| WasmHostError::GuestTrap { hook: "on_configure", source: e })?; + if ok == 0 { + warn!(target: "spacegate_plugin_wasm", "guest on_configure returned 0 (=invalid config)"); + } + Ok(vm) + } + + fn create_context(&mut self, ctx_id: u32, parent_id: u32) -> Result<(), WasmHostError> { + self.store.data_mut().effective_context = ctx_id; + let f = self.fn_on_context_create.clone(); + self.prepare_guest_call()?; + f.call(&mut self.store, (ctx_id, parent_id)).map_err(|e| WasmHostError::GuestTrap { + hook: "on_context_create", + source: e, + })?; + Ok(()) + } + + /// 完整跑一遍:on_request_headers → 可能多次 dispatch → on_request_body → inner.call → on_response_* + pub async fn process(&mut self, req: SgRequest, inner: spacegate_plugin::Inner) -> Result { + // 跨请求清理:上一次请求若提前 `send_local_response` 短路,可能留下未消费的 + // dispatch 结果和 pending token,不清掉会让本请求的 `drive_until_continue` + // 把陈旧响应误当成自己的(spec §proxy_http_call 不要求 host 持久化)。 + while self.dispatch_rx.try_recv().is_ok() {} + self.store.data_mut().pending_calls.clear(); + + let http_ctx_id = self.next_ctx_id; + self.next_ctx_id = self.next_ctx_id.wrapping_add(1); + + let (parts, body) = req.into_parts(); + let method = parts.method.clone(); + let uri = parts.uri.clone(); + let version = parts.version; + let path = uri.path_and_query().map(|p| p.to_string()).unwrap_or_else(|| "/".to_string()); + let authority = uri.authority().map(|a| a.to_string()).unwrap_or_else(|| parts.headers.get(http::header::HOST).and_then(|h| h.to_str().ok()).unwrap_or("").to_string()); + let scheme = uri.scheme_str().unwrap_or("http").to_string(); + let headers = parts.headers.clone(); + let pseudo = PseudoHeaders { + method: method.as_str().to_string(), + path: path.clone(), + authority: authority.clone(), + scheme, + }; + let request_protocol = format!("{:?}", version); + + let want_request_body = self.fn_on_request_body.is_some(); + + let root_id = self.root_id; + self.create_context(http_ctx_id, root_id)?; + { + let st = self.store.data_mut(); + let ctx = st.contexts.entry(http_ctx_id).or_default(); + ctx.parent_id = root_id; + ctx.stage = ContextStage::RequestHeaders; + ctx.request_pseudo = pseudo; + ctx.request_headers = headers.clone(); + ctx.continue_requested = false; + ctx.request_protocol = request_protocol; + st.effective_context = http_ctx_id; + } + + // 调 on_request_headers + let num_headers = (self.store.data().contexts[&http_ctx_id].request_headers.len() + 4) as u32; + let end_of_stream_for_headers: u32 = if want_request_body { 0 } else { 1 }; + let on_req_hdr = self.fn_on_request_headers.clone(); + self.prepare_guest_call()?; + let action_raw = on_req_hdr.call(&mut self.store, (http_ctx_id, num_headers, end_of_stream_for_headers)).map_err(|e| WasmHostError::GuestTrap { + hook: "on_request_headers", + source: e, + })?; + let action = Action::from_u32(action_raw); + debug!(target: "spacegate_plugin_wasm", http_ctx_id, ?action, "on_request_headers returned"); + + // Guest 可能在 headers 阶段 Pause 以等待 on_request_body(尚未 dispatch_http_call); + // 此时 pending_calls 为空,不能进入 drive_until_continue,否则会永久阻塞在 dispatch_rx。 + if action == Action::Pause && !self.store.data().pending_calls.is_empty() { + self.drive_until_continue(http_ctx_id).await?; + } + + if let Some(local) = self.store.data_mut().contexts.get_mut(&http_ctx_id).and_then(|c| c.local_response.take()) { + info!(target: "spacegate_plugin_wasm", http_ctx_id, status = local.status, "guest local response (after headers)"); + self.invoke_log_done_delete(http_ctx_id)?; + return Ok(build_local_response(local)); + } + + // ─── on_request_body:把请求 body 物化后喂给 guest(仅当 guest 导出该 hook)─── + let (new_req_for_inner, collected_body_after_hook) = if want_request_body { + // collect body + let collected = collect_body_limited(body, self.store.data().shell_cfg.limits.max_body_bytes).await?; + let body_size = collected.len() as u32; + { + let st = self.store.data_mut(); + if let Some(ctx) = st.contexts.get_mut(&http_ctx_id) { + ctx.request_body = Some(collected.clone()); + ctx.stage = ContextStage::RequestBody; + ctx.continue_requested = false; + ctx.request_size = collected.len() as u64; + st.effective_context = http_ctx_id; + } + } + let on_req_body = self.fn_on_request_body.clone().expect("guarded by want_request_body"); + self.prepare_guest_call()?; + let action_raw = on_req_body.call(&mut self.store, (http_ctx_id, body_size, 1)).map_err(|e| WasmHostError::GuestTrap { + hook: "on_request_body", + source: e, + })?; + if Action::from_u32(action_raw) == Action::Pause { + self.drive_until_continue(http_ctx_id).await?; + } + if let Some(local) = self.store.data_mut().contexts.get_mut(&http_ctx_id).and_then(|c| c.local_response.take()) { + info!(target: "spacegate_plugin_wasm", http_ctx_id, status = local.status, "guest local response (after request body)"); + self.invoke_log_done_delete(http_ctx_id)?; + return Ok(build_local_response(local)); + } + let final_body = self.store.data().contexts.get(&http_ctx_id).and_then(|c| c.request_body.clone()).unwrap_or(collected); + (None, Some(final_body)) + } else { + (Some(body), None) + }; + + // ─── on_request_trailers:spacegate 当前不感知 trailers,给 guest 一个空 trailer 入参 ─── + if let Some(f) = self.fn_on_request_trailers.clone() { + self.store.data_mut().effective_context = http_ctx_id; + if let Some(ctx) = self.store.data_mut().contexts.get_mut(&http_ctx_id) { + ctx.stage = ContextStage::RequestTrailers; + ctx.continue_requested = false; + } + self.prepare_guest_call()?; + let action_raw = f.call(&mut self.store, (http_ctx_id, 0)).map_err(|e| WasmHostError::GuestTrap { + hook: "on_request_trailers", + source: e, + })?; + if Action::from_u32(action_raw) == Action::Pause { + self.drive_until_continue(http_ctx_id).await?; + } + if let Some(local) = self.store.data_mut().contexts.get_mut(&http_ctx_id).and_then(|c| c.local_response.take()) { + info!(target: "spacegate_plugin_wasm", http_ctx_id, status = local.status, "guest local response (after request trailers)"); + self.invoke_log_done_delete(http_ctx_id)?; + return Ok(build_local_response(local)); + } + } + + // 把 ctx 内可能被 guest 改过的 method/path/headers 写回 SgRequest + let (new_headers, new_pseudo) = self + .store + .data() + .contexts + .get(&http_ctx_id) + .map(|c| (c.request_headers.clone(), c.request_pseudo.clone())) + .unwrap_or_else(|| (HeaderMap::new(), PseudoHeaders::default())); + let new_uri = rebuild_uri(&new_pseudo.scheme, &new_pseudo.authority, &new_pseudo.path).unwrap_or(uri); + let mut new_parts = parts; + new_parts.method = new_pseudo.method.parse().unwrap_or(method); + new_parts.uri = new_uri; + new_parts.headers = new_headers; + new_parts.version = version; + let new_body = match (new_req_for_inner, collected_body_after_hook) { + (Some(b), _) => b, + (None, Some(bytes)) => SgBody::full(bytes), + (None, None) => SgBody::empty(), + }; + let new_req = SgRequest::from_parts(new_parts, new_body); + + let resp = inner.call(new_req).await; + + // ─── on_response_headers ─── + let (resp_parts, resp_body) = resp.into_parts(); + let status = resp_parts.status.as_u16(); + let status_message = resp_parts.status.canonical_reason().unwrap_or("").to_string(); + let resp_headers = resp_parts.headers.clone(); + { + let st = self.store.data_mut(); + if let Some(ctx) = st.contexts.get_mut(&http_ctx_id) { + ctx.stage = ContextStage::ResponseHeaders; + ctx.response_status = Some(status); + ctx.response_status_message = status_message; + ctx.response_headers = resp_headers.clone(); + ctx.continue_requested = false; + st.effective_context = http_ctx_id; + } + } + let want_response_body = self.fn_on_response_body.is_some(); + let end_of_stream_for_resp_hdr: u32 = if want_response_body { 0 } else { 1 }; + let on_resp_hdr = self.fn_on_response_headers.clone(); + self.prepare_guest_call()?; + let action_raw = on_resp_hdr.call(&mut self.store, (http_ctx_id, (resp_headers.len() + 1) as u32, end_of_stream_for_resp_hdr)).map_err(|e| WasmHostError::GuestTrap { + hook: "on_response_headers", + source: e, + })?; + if Action::from_u32(action_raw) == Action::Pause { + self.drive_until_continue(http_ctx_id).await?; + } + if let Some(local) = self.store.data_mut().contexts.get_mut(&http_ctx_id).and_then(|c| c.local_response.take()) { + info!(target: "spacegate_plugin_wasm", http_ctx_id, status = local.status, "guest local response (after response headers)"); + self.invoke_log_done_delete(http_ctx_id)?; + return Ok(build_local_response(local)); + } + + // ─── on_response_body ─── + let (mut final_headers, final_body): (HeaderMap, SgBody) = if let Some(f) = self.fn_on_response_body.clone() { + let collected = collect_body_limited(resp_body, self.store.data().shell_cfg.limits.max_body_bytes).await?; + let body_size = collected.len() as u32; + { + let st = self.store.data_mut(); + if let Some(ctx) = st.contexts.get_mut(&http_ctx_id) { + ctx.response_body = Some(collected.clone()); + ctx.stage = ContextStage::ResponseBody; + ctx.continue_requested = false; + ctx.response_size = collected.len() as u64; + st.effective_context = http_ctx_id; + } + } + self.prepare_guest_call()?; + let action_raw = f.call(&mut self.store, (http_ctx_id, body_size, 1)).map_err(|e| WasmHostError::GuestTrap { + hook: "on_response_body", + source: e, + })?; + if Action::from_u32(action_raw) == Action::Pause { + self.drive_until_continue(http_ctx_id).await?; + } + let updated_body = self.store.data().contexts.get(&http_ctx_id).and_then(|c| c.response_body.clone()).unwrap_or(collected); + let updated_headers = self.store.data().contexts.get(&http_ctx_id).map(|c| c.response_headers.clone()).unwrap_or(resp_headers); + (updated_headers, SgBody::full(updated_body)) + } else { + (resp_headers, SgBody::new(resp_body)) + }; + + // ─── on_response_trailers ─── + if let Some(f) = self.fn_on_response_trailers.clone() { + self.store.data_mut().effective_context = http_ctx_id; + if let Some(ctx) = self.store.data_mut().contexts.get_mut(&http_ctx_id) { + ctx.stage = ContextStage::ResponseTrailers; + ctx.continue_requested = false; + } + self.prepare_guest_call()?; + let _ = f.call(&mut self.store, (http_ctx_id, 0)).map_err(|e| WasmHostError::GuestTrap { + hook: "on_response_trailers", + source: e, + })?; + // guest 可能改了 response_headers → 同步回 final_headers + if let Some(ctx) = self.store.data().contexts.get(&http_ctx_id) { + final_headers = ctx.response_headers.clone(); + } + } + + // ─── on_log + on_done + on_delete ─── + self.invoke_log_done_delete(http_ctx_id)?; + + let mut new_resp_parts = resp_parts; + new_resp_parts.headers = final_headers; + Ok(SgResponse::from_parts(new_resp_parts, final_body)) + } + + /// 在 guest 返回 Pause 之后,不停地 await dispatch_rx 来驱动状态机, + /// 直到 guest `continue_stream(HTTP_REQUEST/RESPONSE)` 或写了 `local_response`。 + async fn drive_until_continue(&mut self, ctx_id: u32) -> Result<(), WasmHostError> { + loop { + { + let st = self.store.data(); + let Some(ctx) = st.contexts.get(&ctx_id) else { + return Err(WasmHostError::AbiViolation(format!("ctx {ctx_id} gone"))); + }; + if ctx.local_response.is_some() { + return Ok(()); + } + if ctx.continue_requested && st.pending_calls.is_empty() { + return Ok(()); + } + // 无 outbound call 的 Pause(例如 defer 到 body hook)不应阻塞等待 dispatch 结果。 + if st.pending_calls.is_empty() { + return Ok(()); + } + } + let Some((token, result)) = self.dispatch_rx.recv().await else { + return Err(WasmHostError::Dispatch("dispatch channel closed".to_string())); + }; + let source_ctx_id = self.store.data_mut().pending_calls.remove(&token).map(|p| p.source_context_id).unwrap_or(ctx_id); + let header_count; + let body_len; + { + let st = self.store.data_mut(); + st.effective_context = source_ctx_id; + if let Some(ctx) = st.contexts.get_mut(&source_ctx_id) { + ctx.last_call_headers = result.headers.clone(); + ctx.last_call_status = result.status; + ctx.last_call_status_message = result.status_message.clone(); + ctx.last_call_body = result.body.clone(); + ctx.continue_requested = false; + } + header_count = result.headers.len() as u32 + 1; + body_len = result.body.len() as u32; + } + debug!(target: "spacegate_plugin_wasm", token, source_ctx_id, status = result.status, body_len, "fire proxy_on_http_call_response"); + let f = self.fn_on_http_call_response.clone(); + self.prepare_guest_call()?; + f.call(&mut self.store, (source_ctx_id, token, header_count, body_len, 0)).map_err(|e| WasmHostError::GuestTrap { + hook: "on_http_call_response", + source: e, + })?; + } + } + + fn invoke_log_done_delete(&mut self, ctx_id: u32) -> Result<(), WasmHostError> { + self.store.data_mut().effective_context = ctx_id; + if let Some(ctx) = self.store.data_mut().contexts.get_mut(&ctx_id) { + ctx.stage = ContextStage::Log; + } + if let Some(f) = self.fn_on_log.clone() { + self.prepare_guest_call()?; + let _ = f.call(&mut self.store, ctx_id); + } + if let Some(f) = self.fn_on_done.clone() { + // spec §proxy_on_done:返回 false 表示 plugin 还要再调 `proxy_done`。 + // 当前 http context 在请求结束时即刻销毁,host 没有"再等一会"的空间: + // 标记 awaiting_done 让 `proxy_done` 能 Ok 一次,guest 若在 on_log 里立刻 done 则完美; + // 否则强制完成并 warn。 + if let Some(ctx) = self.store.data_mut().contexts.get_mut(&ctx_id) { + ctx.awaiting_done = true; + } + self.prepare_guest_call()?; + let v = f.call(&mut self.store, ctx_id).unwrap_or(1); + let done = v != 0 || self.store.data().contexts.get(&ctx_id).map(|c| c.done_marker).unwrap_or(true); + if !done { + warn!( + target: "spacegate_plugin_wasm", + ctx_id, + "proxy_on_done returned false but http context cannot defer; forcing delete" + ); + } + } + if let Some(f) = self.fn_on_delete.clone() { + self.prepare_guest_call()?; + let _ = f.call(&mut self.store, ctx_id); + } + self.store.data_mut().contexts.remove(&ctx_id); + Ok(()) + } + + pub fn fail_strategy(&self) -> FailStrategy { + self.fail_strategy + } + + /// guest 当前请求的 `proxy_set_tick_period_milliseconds` 值;0 表示尚未配置 / 已停。 + pub fn tick_period_ms(&self) -> u32 { + self.store.data().tick_period_ms.unwrap_or(0) + } + + /// 在 root_context 上同步触发一次 `proxy_on_tick`。host 端后台任务调用本方法。 + /// + /// 失败要么是 guest trap(要么后台任务自停),要么是 guest 没导出 `proxy_on_tick`——后者直接 Ok。 + pub fn tick(&mut self) -> Result<(), WasmHostError> { + let Some(f) = self.fn_on_tick.clone() else { + return Ok(()); + }; + self.store.data_mut().effective_context = self.root_id; + self.prepare_guest_call()?; + f.call(&mut self.store, self.root_id).map_err(|e| WasmHostError::GuestTrap { hook: "on_tick", source: e })?; + Ok(()) + } + + fn prepare_guest_call(&mut self) -> Result<(), WasmHostError> { + prepare_store_for_guest_call(&mut self.store) + } +} + +fn prepare_store_for_guest_call(store: &mut Store) -> Result<(), WasmHostError> { + let fuel = store.data().shell_cfg.guest_fuel_per_call(); + store.set_fuel(fuel).map_err(|e| WasmHostError::ResourceLimit(format!("set fuel: {e}")))?; + let deadline = store.data().shell_cfg.guest_epoch_deadline_ticks(); + store.set_epoch_deadline(deadline); + store.epoch_deadline_trap(); + Ok(()) +} + +async fn collect_body_limited(body: SgBody, limit: Option) -> Result { + if let Some(limit) = limit { + let limited = Limited::new(body, limit); + let collected = limited.collect().await.map_err(|_| WasmHostError::BodyTooLarge { + actual: limit.saturating_add(1), + limit, + })?; + let bytes = collected.to_bytes(); + if bytes.len() > limit { + return Err(WasmHostError::BodyTooLarge { actual: bytes.len(), limit }); + } + Ok(bytes) + } else { + Ok(body.collect().await.map(|c| c.to_bytes()).unwrap_or_default()) + } +} + +fn rebuild_uri(scheme: &str, authority: &str, path: &str) -> Option { + let mut s = String::new(); + if !scheme.is_empty() && !authority.is_empty() { + s.push_str(scheme); + s.push_str("://"); + s.push_str(authority); + } + if !path.is_empty() { + s.push_str(path); + } else { + s.push('/'); + } + s.parse().ok() +} + +fn build_local_response(local: crate::host_state::LocalResponse) -> SgResponse { + let mut resp = SgResponse::new(SgBody::full(local.body)); + *resp.status_mut() = http::StatusCode::from_u16(local.status).unwrap_or(http::StatusCode::OK); + for (k, v) in local.headers.iter() { + resp.headers_mut().insert(k, v.clone()); + } + resp +} + +/// spec §Unimplemented WASI functions + §Logging §Clocks §Randomness:完整的 wasi_snapshot_preview1 子集。 +/// +/// - `random_get`:用 OS RNG(spec §Randomness)。 +/// - `clock_time_get`:spec §Clocks,REALTIME 用 SystemTime,MONOTONIC 用 Instant。 +/// - `environ_get` / `environ_sizes_get`:spec 明确不暴露 host env,全部 0。 +/// - `fd_write`:spec §Logging:fd=1→INFO,fd=2→ERROR;解析 iovec 提取 bytes。 +/// - `args_sizes_get` / `args_get`:spec §Unimplemented WASI,固定写 0。 +/// - `proc_exit`:spec §Unimplemented WASI,noop。 +pub fn register_wasi_stubs(linker: &mut Linker) -> Result<(), wasmtime::Error> { + use crate::abi::{wasi_errno, wasi_fd}; + + linker.func_wrap( + "wasi_snapshot_preview1", + "random_get", + |mut caller: wasmtime::Caller<'_, HostState>, ptr: i32, len: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return wasi_errno::FAULT, + }; + let mut buf = vec![0u8; len.max(0) as usize]; + if getrandom::getrandom(&mut buf).is_err() { + return wasi_errno::FAULT; + } + if mem.write_bytes(caller.as_context_mut(), ptr as u32, &buf).is_err() { + return wasi_errno::FAULT; + } + wasi_errno::SUCCESS + }, + )?; + linker.func_wrap( + "wasi_snapshot_preview1", + "clock_time_get", + |mut caller: wasmtime::Caller<'_, HostState>, clock_id: i32, _prec: i64, return_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return wasi_errno::FAULT, + }; + let nanos: u64 = match clock_id { + 0 => std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).map(|d| d.as_nanos() as u64).unwrap_or(0), + 1 => { + static EPOCH: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); + let epoch = EPOCH.get_or_init(std::time::Instant::now); + epoch.elapsed().as_nanos() as u64 + } + _ => return wasi_errno::NOTSUP, + }; + if mem.write_u64(caller.as_context_mut(), return_ptr as u32, nanos).is_err() { + return wasi_errno::FAULT; + } + wasi_errno::SUCCESS + }, + )?; + linker.func_wrap("wasi_snapshot_preview1", "environ_get", |_c: wasmtime::Caller<'_, HostState>, _a: i32, _b: i32| -> i32 { + wasi_errno::SUCCESS + })?; + linker.func_wrap( + "wasi_snapshot_preview1", + "environ_sizes_get", + |mut caller: wasmtime::Caller<'_, HostState>, count_ptr: i32, buf_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return wasi_errno::FAULT, + }; + if mem.write_u32(caller.as_context_mut(), count_ptr as u32, 0).is_err() { + return wasi_errno::FAULT; + } + if mem.write_u32(caller.as_context_mut(), buf_ptr as u32, 0).is_err() { + return wasi_errno::FAULT; + } + wasi_errno::SUCCESS + }, + )?; + linker.func_wrap("wasi_snapshot_preview1", "args_get", |_c: wasmtime::Caller<'_, HostState>, _a: i32, _b: i32| -> i32 { + wasi_errno::SUCCESS + })?; + linker.func_wrap( + "wasi_snapshot_preview1", + "args_sizes_get", + |mut caller: wasmtime::Caller<'_, HostState>, argc_ptr: i32, buf_size_ptr: i32| -> i32 { + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return wasi_errno::FAULT, + }; + if mem.write_u32(caller.as_context_mut(), argc_ptr as u32, 0).is_err() { + return wasi_errno::FAULT; + } + if mem.write_u32(caller.as_context_mut(), buf_size_ptr as u32, 0).is_err() { + return wasi_errno::FAULT; + } + wasi_errno::SUCCESS + }, + )?; + linker.func_wrap( + "wasi_snapshot_preview1", + "fd_write", + |mut caller: wasmtime::Caller<'_, HostState>, fd: i32, iovs: i32, iovs_len: i32, nwritten_ptr: i32| -> i32 { + // spec §Logging:fd=1→INFO,fd=2→ERROR;其它 fd → BADF。 + if fd != wasi_fd::STDOUT && fd != wasi_fd::STDERR { + return wasi_errno::BADF; + } + let mem = match MemoryHelper::from_caller(&mut caller) { + Ok(m) => m, + Err(_) => return wasi_errno::FAULT, + }; + // iovec[]:每项 (buf_ptr: u32, buf_len: u32),共 iovs_len 项。 + let mut total: u32 = 0; + let mut bytes_out: Vec = Vec::new(); + for i in 0..(iovs_len as u32) { + let entry_ptr = (iovs as u32) + i * 8; + let Ok(buf_ptr) = mem.read_u32(caller.as_context(), entry_ptr) else { + return wasi_errno::FAULT; + }; + let Ok(buf_len) = mem.read_u32(caller.as_context(), entry_ptr + 4) else { + return wasi_errno::FAULT; + }; + let Ok(chunk) = mem.read_bytes(caller.as_context(), buf_ptr, buf_len) else { + return wasi_errno::FAULT; + }; + bytes_out.extend_from_slice(&chunk); + total = total.saturating_add(buf_len); + } + let msg = String::from_utf8_lossy(&bytes_out); + let msg_trimmed = msg.trim_end_matches('\n'); + if fd == wasi_fd::STDOUT { + tracing::info!(target: "spacegate_plugin_wasm::guest::stdout", "{msg_trimmed}"); + } else { + tracing::error!(target: "spacegate_plugin_wasm::guest::stderr", "{msg_trimmed}"); + } + if mem.write_u32(caller.as_context_mut(), nwritten_ptr as u32, total).is_err() { + return wasi_errno::FAULT; + } + wasi_errno::SUCCESS + }, + )?; + linker.func_wrap("wasi_snapshot_preview1", "proc_exit", |_c: wasmtime::Caller<'_, HostState>, _code: i32| {})?; + Ok(()) +} diff --git a/crates/plugin-wasm/tests/http_call.rs b/crates/plugin-wasm/tests/http_call.rs new file mode 100644 index 00000000..b7022764 --- /dev/null +++ b/crates/plugin-wasm/tests/http_call.rs @@ -0,0 +1,299 @@ +//! 端到端验证 `proxy_http_call` → `proxy_on_http_call_response` 链路: +//! 模式来自 [`sdk_examples_guest`] 的 `auth_random`,guest 在 `on_request_headers` +//! 发起一次外呼,host 通过 reqwest 真正打到一个本地 mock HTTP server,server +//! 返回一段固定字节;guest 的 `on_http_call_response` 根据第一个字节决定 +//! `resume_http_request()` 放行 / `send_local_response(403)`。 +//! +//! 这条测试是 `proxy_http_call` 的唯一覆盖路径——host fn 注册、token 分配、 +//! reqwest spawn、UnboundedSender → drive_until_continue 状态机、effective_context +//! 切换、guest 通过 `get_http_call_response_body` 读 body —— 一次跑齐。 + +use std::convert::Infallible; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use bytes::Bytes; +use http_body_util::{BodyExt, Full}; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{Request as HyperRequest, Response}; +use hyper_util::rt::TokioIo; +use spacegate_kernel::backend_service::ArcHyperService; +use spacegate_kernel::helper_layers::function::Inner; +use spacegate_kernel::{SgBody, SgRequest, SgResponse}; +use spacegate_plugin::{Plugin, PluginConfig, PluginInstanceId, PluginInstanceName}; +use spacegate_plugin_wasm::config::WasmPluginShellConfig; +use spacegate_plugin_wasm::engine::shared_engine; +use spacegate_plugin_wasm::vm::Vm; +use spacegate_plugin_wasm::WasmPluginShell; +use tokio::net::TcpListener; +use wasmtime::Module; + +// ───────────────────────────────────────────────────────── +// 共用:定位 sdk_examples_guest.wasm(与 sdk_examples.rs 相同;故意复制以保持测试独立) +// ───────────────────────────────────────────────────────── + +fn guest_manifest_path() -> PathBuf { + let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + p.push("tests"); + p.push("sdk_examples_guest"); + p.push("Cargo.toml"); + p +} + +fn guest_wasm_path() -> PathBuf { + let manifest = guest_manifest_path(); + let out = std::process::Command::new(env!("CARGO")) + .args(["metadata", "--no-deps", "--format-version", "1", "--manifest-path"]) + .arg(&manifest) + .output() + .expect("cargo metadata: spawn"); + assert!(out.status.success(), "cargo metadata failed: {}", String::from_utf8_lossy(&out.stderr)); + let meta: serde_json::Value = serde_json::from_slice(&out.stdout).expect("parse cargo metadata json"); + let target_dir = meta["target_directory"].as_str().expect("target_directory missing"); + PathBuf::from(target_dir).join("wasm32-wasip1").join("release").join("sdk_examples_guest.wasm") +} + +fn ensure_guest_built() -> PathBuf { + let wasm = guest_wasm_path(); + if !wasm.exists() { + let status = std::process::Command::new(env!("CARGO")) + .args(["build", "--release", "--target", "wasm32-wasip1", "--manifest-path"]) + .arg(guest_manifest_path()) + .status() + .expect("cargo build: spawn"); + assert!(status.success(), "sdk_examples_guest build failed"); + assert!(wasm.exists(), "wasm still missing after build: {wasm:?}"); + } + wasm +} + +fn load_module() -> Arc { + let path = ensure_guest_built(); + let bytes = std::fs::read(&path).expect("read wasm"); + Arc::new(Module::new(shared_engine(), &bytes).expect("Module::new")) +} + +// ───────────────────────────────────────────────────────── +// mock HTTP server:返回单字节 body,用于驱动 auth_random 判断 +// ───────────────────────────────────────────────────────── + +async fn start_mock_server(body_byte: u8) -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = listener.local_addr().expect("local_addr"); + tokio::spawn(async move { + loop { + let (stream, _) = match listener.accept().await { + Ok(s) => s, + Err(_) => return, + }; + tokio::spawn(async move { + let svc = service_fn(move |_req: HyperRequest| async move { + let body = Bytes::from(vec![body_byte]); + let resp = Response::builder().status(200).body(Full::new(body)).expect("build resp"); + Ok::<_, Infallible>(resp) + }); + let _ = http1::Builder::new().serve_connection(TokioIo::new(stream), svc).await; + }); + } + }); + addr +} + +async fn start_delayed_mock_server(body_byte: u8, delay: Duration) -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = listener.local_addr().expect("local_addr"); + tokio::spawn(async move { + loop { + let (stream, _) = match listener.accept().await { + Ok(s) => s, + Err(_) => return, + }; + tokio::spawn(async move { + let svc = service_fn(move |_req: HyperRequest| async move { + tokio::time::sleep(delay).await; + let body = Bytes::from(vec![body_byte]); + let resp = Response::builder().status(200).body(Full::new(body)).expect("build resp"); + Ok::<_, Infallible>(resp) + }); + let _ = http1::Builder::new().serve_connection(TokioIo::new(stream), svc).await; + }); + } + }); + addr +} + +// ───────────────────────────────────────────────────────── +// mock inner.call:guest 放行后会下沉到这里,echo body 即可 +// ───────────────────────────────────────────────────────── + +fn echo_inner() -> Inner { + let svc = service_fn(|req: SgRequest| async move { + let (_, body) = req.into_parts(); + let bytes = body.collect().await.map(|c| c.to_bytes()).unwrap_or_default(); + let mut resp = SgResponse::new(SgBody::full(bytes)); + *resp.status_mut() = http::StatusCode::OK; + Ok::<_, Infallible>(resp) + }); + Inner::new(ArcHyperService::new(svc)) +} + +async fn full_body(resp: SgResponse) -> (SgResponse, Bytes) { + let (parts, body) = resp.into_parts(); + let bytes = body.collect().await.map(|c| c.to_bytes()).unwrap_or_default(); + (SgResponse::from_parts(parts, SgBody::full(bytes.clone())), bytes) +} + +async fn run(auth_byte: u8) -> (u16, Bytes) { + let addr = start_mock_server(auth_byte).await; + // 给 server 一个起跳的间隙;用 50ms 兜底(tokio 实际可即时 accept)。 + tokio::time::sleep(Duration::from_millis(20)).await; + + let module = load_module(); + let cfg = Arc::new(WasmPluginShellConfig { + url: "file://sdk_examples_guest".into(), + plugin_config: serde_json::json!({ + "mode": "auth_random", + "auth_cluster": "auth", + "auth_threshold": 128 + }), + clusters: [("auth".to_string(), format!("http://{addr}"))].into_iter().collect(), + ..Default::default() + }); + let mut vm = Vm::new(&module, cfg).expect("Vm::new"); + + let req = HyperRequest::builder() + .method("POST") + .uri("http://example.test/") + .header("host", "example.test") + .body(SgBody::full(Bytes::from_static(b"protected payload"))) + .expect("build req"); + + let resp = vm.process(req, echo_inner()).await.expect("process"); + let (resp, body) = full_body(resp).await; + (resp.status().as_u16(), body) +} + +fn protected_request() -> SgRequest { + protected_request_with_policy(None) +} + +fn protected_request_with_policy(policy: Option<&str>) -> SgRequest { + let mut builder = HyperRequest::builder().method("POST").uri("http://example.test/").header("host", "example.test"); + if let Some(policy) = policy { + builder = builder.header("x-ratelimit-policy", policy); + } + builder.body(SgBody::full(Bytes::from_static(b"protected payload"))).expect("build req") +} + +// ───────────────────────────────────────────────────────── +// auth byte < threshold → 放行;echo 回原 body +// ───────────────────────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn auth_random_allow() { + let (status, body) = run(50).await; + assert_eq!(status, 200, "expected allow → echo"); + assert_eq!(body, Bytes::from_static(b"protected payload")); +} + +// ───────────────────────────────────────────────────────── +// auth byte >= threshold → guest 短路 403 +// ───────────────────────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn auth_random_deny() { + let (status, body) = run(200).await; + assert_eq!(status, 403, "expected deny → 403"); + assert_eq!(body, Bytes::from_static(b"forbidden")); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn vm_pool_runs_slow_http_calls_concurrently() { + let wasm = ensure_guest_built(); + let addr = start_delayed_mock_server(50, Duration::from_millis(450)).await; + tokio::time::sleep(Duration::from_millis(20)).await; + + let shell = WasmPluginShell::create(PluginConfig { + id: PluginInstanceId { + code: "wasm".into(), + name: PluginInstanceName::named("vm-pool-test"), + }, + spec: serde_json::json!({ + "url": format!("file://{}", wasm.display()), + "plugin_config": { + "mode": "auth_random", + "auth_cluster": "auth", + "auth_threshold": 128 + }, + "clusters": { + "auth": format!("http://{addr}") + }, + "vm_pool_size": 2 + }), + }) + .expect("create wasm shell"); + + let started = Instant::now(); + let (resp1, resp2) = tokio::join!(shell.call(protected_request(), echo_inner()), shell.call(protected_request(), echo_inner())); + let elapsed = started.elapsed(); + + let (resp1, body1) = full_body(resp1.expect("resp1")).await; + let (resp2, body2) = full_body(resp2.expect("resp2")).await; + assert_eq!(resp1.status(), http::StatusCode::OK); + assert_eq!(resp2.status(), http::StatusCode::OK); + assert_eq!(body1, Bytes::from_static(b"protected payload")); + assert_eq!(body2, Bytes::from_static(b"protected payload")); + assert!( + elapsed < Duration::from_millis(800), + "expected two 450ms dispatches to overlap with vm_pool_size=2, elapsed={elapsed:?}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn wait_policy_uses_separate_vm_pool() { + let wasm = ensure_guest_built(); + let addr = start_delayed_mock_server(50, Duration::from_millis(450)).await; + tokio::time::sleep(Duration::from_millis(20)).await; + + let shell = WasmPluginShell::create(PluginConfig { + id: PluginInstanceId { + code: "wasm".into(), + name: PluginInstanceName::named("wait-vm-pool-test"), + }, + spec: serde_json::json!({ + "url": format!("file://{}", wasm.display()), + "plugin_config": { + "mode": "auth_random", + "auth_cluster": "auth", + "auth_threshold": 128 + }, + "clusters": { + "auth": format!("http://{addr}") + }, + "vm_pool_size": 1, + "wait_vm_pool_size": 1 + }), + }) + .expect("create wasm shell"); + + let started = Instant::now(); + let (wait_resp, normal_resp) = tokio::join!(shell.call(protected_request_with_policy(Some("wait")), echo_inner()), async { + tokio::time::sleep(Duration::from_millis(50)).await; + shell.call(protected_request(), echo_inner()).await + }); + let elapsed = started.elapsed(); + + let (wait_resp, wait_body) = full_body(wait_resp.expect("wait resp")).await; + let (normal_resp, normal_body) = full_body(normal_resp.expect("normal resp")).await; + assert_eq!(wait_resp.status(), http::StatusCode::OK); + assert_eq!(normal_resp.status(), http::StatusCode::OK); + assert_eq!(wait_body, Bytes::from_static(b"protected payload")); + assert_eq!(normal_body, Bytes::from_static(b"protected payload")); + assert!( + elapsed < Duration::from_millis(800), + "expected wait traffic to use wait_vm_pool and not block normal pool, elapsed={elapsed:?}" + ); +} diff --git a/crates/plugin-wasm/tests/on_tick.rs b/crates/plugin-wasm/tests/on_tick.rs new file mode 100644 index 00000000..0cc316ec --- /dev/null +++ b/crates/plugin-wasm/tests/on_tick.rs @@ -0,0 +1,104 @@ +//! 端到端验证 `proxy_set_tick_period_milliseconds` + `proxy_on_tick` ⇄ host VmPool +//! 后台 tick 任务的协同: +//! +//! 1. `WasmPluginShell::create` 后,shell 内部起一条 50ms 颗粒度的 tick 循环; +//! 2. guest 在 `on_vm_start` 把 tick 周期设为 50ms,`on_tick` 把 shared_data 计数原子 +1; +//! 3. 测试 sleep 几个 tick 周期后从 host 侧直接读 shared_data,断言至少 N 次 tick; +//! 4. `drop(shell)` 后再 sleep,确认计数不再继续增长(tick 任务随 shell 析构)。 + +use std::path::PathBuf; +use std::time::Duration; + +use spacegate_model::{PluginInstanceId, PluginInstanceName}; +use spacegate_plugin::{Plugin, PluginConfig}; +use spacegate_plugin_wasm::shared::{shared_data_get, shared_data_set}; +use spacegate_plugin_wasm::WasmPluginShell; + +// ───────────────────────────────────────────────────────── +// guest .wasm 定位/构建 +// ───────────────────────────────────────────────────────── + +fn guest_manifest_path() -> PathBuf { + let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + p.push("tests"); + p.push("on_tick_guest"); + p.push("Cargo.toml"); + p +} + +fn guest_wasm_path() -> PathBuf { + let manifest = guest_manifest_path(); + let out = std::process::Command::new(env!("CARGO")) + .args(["metadata", "--no-deps", "--format-version", "1", "--manifest-path"]) + .arg(&manifest) + .output() + .expect("cargo metadata: spawn"); + assert!(out.status.success(), "cargo metadata failed: {}", String::from_utf8_lossy(&out.stderr)); + let meta: serde_json::Value = serde_json::from_slice(&out.stdout).expect("parse cargo metadata json"); + let target_dir = meta["target_directory"].as_str().expect("target_directory missing"); + PathBuf::from(target_dir).join("wasm32-wasip1").join("release").join("on_tick_guest.wasm") +} + +fn ensure_guest_built() -> PathBuf { + let wasm = guest_wasm_path(); + if !wasm.exists() { + let status = std::process::Command::new(env!("CARGO")) + .args(["build", "--release", "--target", "wasm32-wasip1", "--manifest-path"]) + .arg(guest_manifest_path()) + .status() + .expect("cargo build: spawn"); + assert!(status.success(), "on_tick_guest build failed"); + assert!(wasm.exists(), "wasm still missing after build: {wasm:?}"); + } + wasm +} + +fn read_counter() -> u64 { + let (raw, _cas) = shared_data_get(b"on_tick.count").expect("counter present"); + std::str::from_utf8(&raw).ok().and_then(|s| s.parse::().ok()).unwrap_or(0) +} + +// ───────────────────────────────────────────────────────── +// 主测试 +// ───────────────────────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn proxy_on_tick_drives_background_ticks() { + // 把 shared_data 计数器清零(cas=0 = 无 CAS 期望)。 + let _ = shared_data_set(b"on_tick.count", b"0", 0); + + let wasm = ensure_guest_built(); + let plugin_config = PluginConfig { + id: PluginInstanceId { + code: "wasm".into(), + name: PluginInstanceName::named("on-tick-test"), + }, + spec: serde_json::json!({ + "url": format!("file://{}", wasm.display()), + "plugin_name": "on-tick-plugin", + "plugin_root_id": "on-tick-root", + "plugin_vm_id": "default", + }), + }; + let shell = WasmPluginShell::create(plugin_config).expect("Plugin::create"); + + // shell 内部已经 spawn 了 50ms 颗粒度的 tick 任务; + // 期间 guest `on_vm_start` 把 period 设成 50ms。 + // 等 450ms 至少 4 次 tick(保留 CI / 本地调度抖动余量)。 + tokio::time::sleep(Duration::from_millis(450)).await; + + let count = read_counter(); + assert!(count >= 4, "expected >= 4 ticks in 450ms, got {count}"); + tracing::info!("got {count} ticks"); + + // 取一次 snapshot,drop 之后再 sleep 同等时间,断言不再继续增长(允许 1 次余量: + // task abort 与正在执行中的同步 tick() 之间可能交叠一次)。 + let snapshot = count; + drop(shell); + tokio::time::sleep(Duration::from_millis(200)).await; + let after = read_counter(); + assert!( + after.saturating_sub(snapshot) <= 1, + "tick task should stop after shell drop: snapshot={snapshot}, after={after}", + ); +} diff --git a/crates/plugin-wasm/tests/on_tick_guest/.cargo/config.toml b/crates/plugin-wasm/tests/on_tick_guest/.cargo/config.toml new file mode 100644 index 00000000..6b509f5b --- /dev/null +++ b/crates/plugin-wasm/tests/on_tick_guest/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +target = "wasm32-wasip1" diff --git a/crates/plugin-wasm/tests/on_tick_guest/Cargo.toml b/crates/plugin-wasm/tests/on_tick_guest/Cargo.toml new file mode 100644 index 00000000..6fb9b844 --- /dev/null +++ b/crates/plugin-wasm/tests/on_tick_guest/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "on_tick_guest" +version = "0.0.0" +edition = "2021" +publish = false +description = "Tiny proxy-wasm guest exercising proxy_set_tick_period_milliseconds + proxy_on_tick; persists tick count into shared_data for host-side assertion." + +[workspace] + +[lib] +crate-type = ["cdylib"] + +[dependencies] +proxy-wasm = "0.2" + +[profile.release] +codegen-units = 1 +opt-level = "z" +lto = "fat" +strip = true +panic = "abort" diff --git a/crates/plugin-wasm/tests/on_tick_guest/src/lib.rs b/crates/plugin-wasm/tests/on_tick_guest/src/lib.rs new file mode 100644 index 00000000..f859de48 --- /dev/null +++ b/crates/plugin-wasm/tests/on_tick_guest/src/lib.rs @@ -0,0 +1,46 @@ +//! 验证 host 端 VmPool + 后台 tick 任务能正确驱动 `proxy_on_tick`。 +//! +//! - `on_vm_start`:设置 50ms 的 tick 周期; +//! - `on_tick`:把全局 tick 计数器(shared_data,key="on_tick.count")原子地 +1; +//! +//! 测试侧通过 `spacegate_plugin_wasm::shared::shared_data_get` 直接读 shared_data, +//! 等若干 tick 之后断言计数大于 0。 +//! +//! `set_shared_data` 用 cas-loop 保证多 VM / 后台并发也不会丢更新(虽然现在只有一条 tick 任务)。 + +use std::time::Duration; + +use proxy_wasm::hostcalls; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +const KEY: &str = "on_tick.count"; + +proxy_wasm::main! {{ + proxy_wasm::set_log_level(LogLevel::Info); + proxy_wasm::set_root_context(|_| -> Box { Box::new(TickRoot) }); +}} + +struct TickRoot; +impl Context for TickRoot {} +impl RootContext for TickRoot { + fn on_vm_start(&mut self, _: usize) -> bool { + // 50ms 周期:host 端默认 50ms 颗粒度的轮询正好能驱动。 + let _ = hostcalls::set_tick_period(Duration::from_millis(50)); + true + } + + fn on_tick(&mut self) { + // cas 循环:读 → +1 → 写;写失败 (CasMismatch) 重读重试。 + for _ in 0..8 { + let (cur, cas) = hostcalls::get_shared_data(KEY).unwrap_or((None, None)); + let next = cur.as_deref().and_then(|b| std::str::from_utf8(b).ok()).and_then(|s| s.parse::().ok()).unwrap_or(0) + 1; + let buf = next.to_string(); + match hostcalls::set_shared_data(KEY, Some(buf.as_bytes()), cas) { + Ok(()) => return, + Err(Status::CasMismatch) => continue, + Err(_) => return, + } + } + } +} diff --git a/crates/plugin-wasm/tests/runtime_fetch.rs b/crates/plugin-wasm/tests/runtime_fetch.rs new file mode 100644 index 00000000..2a445a53 --- /dev/null +++ b/crates/plugin-wasm/tests/runtime_fetch.rs @@ -0,0 +1,192 @@ +//! Covers wasm module loading concerns that sit below Proxy-Wasm execution: +//! remote fetch, digest verification, and cache invalidation. + +use std::collections::HashMap; +use std::convert::Infallible; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; + +use bytes::Bytes; +use http_body_util::Full; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; +use sha2::{Digest, Sha256}; +use spacegate_plugin_wasm::config::WasmPluginShellConfig; +use spacegate_plugin_wasm::fetch::fetch_wasm_bytes_sync; +use spacegate_plugin_wasm::runtime::WasmModuleCache; +use tokio::net::TcpListener; + +async fn start_bytes_server(body: Bytes) -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = listener.local_addr().expect("local_addr"); + tokio::spawn(async move { + loop { + let (stream, _) = match listener.accept().await { + Ok(s) => s, + Err(_) => return, + }; + let body = body.clone(); + tokio::spawn(async move { + let svc = service_fn(move |_req: Request| { + let body = body.clone(); + async move { Ok::<_, Infallible>(Response::new(Full::new(body))) } + }); + let _ = http1::Builder::new().serve_connection(TokioIo::new(stream), svc).await; + }); + } + }); + addr +} + +async fn start_oci_registry_server(wasm: Bytes) -> SocketAddr { + let digest = sha256_hex(&wasm); + let manifest = Bytes::from(format!( + r#"{{ + "schemaVersion": 2, + "mediaType": "application/vnd.oci.image.manifest.v1+json", + "config": {{"mediaType": "application/vnd.unknown.config.v1+json", "digest": "sha256:{}", "size": 2}}, + "layers": [ + {{"mediaType": "application/vnd.module.wasm.content.layer.v1+wasm", "digest": "sha256:{digest}", "size": {}}} + ] +}}"#, + "0".repeat(64), + wasm.len() + )); + let mut routes = HashMap::new(); + routes.insert("/v2/plugin/manifests/v1".to_string(), manifest); + routes.insert(format!("/v2/plugin/blobs/sha256:{digest}"), wasm); + let routes = Arc::new(routes); + + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = listener.local_addr().expect("local_addr"); + tokio::spawn(async move { + loop { + let (stream, _) = match listener.accept().await { + Ok(s) => s, + Err(_) => return, + }; + let routes = routes.clone(); + tokio::spawn(async move { + let svc = service_fn(move |req: Request| { + let routes = routes.clone(); + async move { + let path = req.uri().path().to_string(); + if let Some(body) = routes.get(&path) { + Ok::<_, Infallible>(Response::new(Full::new(body.clone()))) + } else { + let mut resp = Response::new(Full::new(Bytes::from_static(b"not found"))); + *resp.status_mut() = StatusCode::NOT_FOUND; + Ok(resp) + } + } + }); + let _ = http1::Builder::new().serve_connection(TokioIo::new(stream), svc).await; + }); + } + }); + addr +} + +fn guest_manifest_path() -> PathBuf { + let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + p.push("tests"); + p.push("on_tick_guest"); + p.push("Cargo.toml"); + p +} + +fn guest_wasm_path() -> PathBuf { + let manifest = guest_manifest_path(); + let out = std::process::Command::new(env!("CARGO")) + .args(["metadata", "--no-deps", "--format-version", "1", "--manifest-path"]) + .arg(&manifest) + .output() + .expect("cargo metadata: spawn"); + assert!(out.status.success(), "cargo metadata failed: {}", String::from_utf8_lossy(&out.stderr)); + let meta: serde_json::Value = serde_json::from_slice(&out.stdout).expect("parse cargo metadata json"); + let target_dir = meta["target_directory"].as_str().expect("target_directory missing"); + PathBuf::from(target_dir).join("wasm32-wasip1").join("release").join("on_tick_guest.wasm") +} + +fn ensure_guest_built() -> PathBuf { + let wasm = guest_wasm_path(); + if !wasm.exists() { + let status = std::process::Command::new(env!("CARGO")) + .args(["build", "--release", "--target", "wasm32-wasip1", "--manifest-path"]) + .arg(guest_manifest_path()) + .status() + .expect("cargo build: spawn"); + assert!(status.success(), "on_tick_guest build failed"); + assert!(wasm.exists(), "wasm still missing after build: {wasm:?}"); + } + wasm +} + +fn sha256_hex(bytes: &[u8]) -> String { + format!("{:x}", Sha256::digest(bytes)) +} + +#[tokio::test] +async fn fetch_wasm_bytes_supports_http_urls() { + let expected = Bytes::from_static(b"hello wasm over http"); + let addr = start_bytes_server(expected.clone()).await; + let url = format!("http://{addr}/plugin.wasm"); + + let fetched = tokio::task::spawn_blocking(move || fetch_wasm_bytes_sync(&url)).await.expect("join").expect("fetch"); + + assert_eq!(fetched, expected); +} + +#[tokio::test] +async fn fetch_wasm_bytes_supports_oci_image_layers() { + let expected = Bytes::from_static(b"\0asm\x01\0\0\0"); + let addr = start_oci_registry_server(expected.clone()).await; + let url = format!("oci://{addr}/plugin:v1"); + + let fetched = tokio::task::spawn_blocking(move || fetch_wasm_bytes_sync(&url)).await.expect("join").expect("fetch"); + + assert_eq!(fetched, expected); +} + +#[test] +fn wasm_module_cache_uses_module_cache_key_for_invalidation() { + let wasm = ensure_guest_built(); + let bytes = std::fs::read(&wasm).expect("read wasm"); + let sha256 = sha256_hex(&bytes); + let cache = WasmModuleCache::new(8); + + let cfg_v1 = WasmPluginShellConfig { + url: format!("file://{}", wasm.display()), + sha256: Some(sha256.clone()), + module_cache_key: Some("on-tick:v1".to_string()), + ..Default::default() + }; + let first = cache.get_or_compile(&cfg_v1).expect("compile v1"); + let cached = cache.get_or_compile(&cfg_v1).expect("compile v1 cached"); + assert!(Arc::ptr_eq(&first, &cached)); + + let cfg_v2 = WasmPluginShellConfig { + module_cache_key: Some("on-tick:v2".to_string()), + ..cfg_v1 + }; + let second = cache.get_or_compile(&cfg_v2).expect("compile v2"); + assert!(!Arc::ptr_eq(&first, &second)); +} + +#[test] +fn wasm_module_cache_rejects_sha256_mismatch() { + let wasm = ensure_guest_built(); + let cache = WasmModuleCache::new(8); + let cfg = WasmPluginShellConfig { + url: format!("file://{}", wasm.display()), + sha256: Some("sha256:0000000000000000000000000000000000000000000000000000000000000000".to_string()), + use_cache: false, + ..Default::default() + }; + + let err = cache.get_or_compile(&cfg).expect_err("expected sha mismatch"); + assert!(err.to_string().contains("sha256 mismatch"), "{err}"); +} diff --git a/crates/plugin-wasm/tests/sdk_examples.rs b/crates/plugin-wasm/tests/sdk_examples.rs new file mode 100644 index 00000000..023960c4 --- /dev/null +++ b/crates/plugin-wasm/tests/sdk_examples.rs @@ -0,0 +1,246 @@ +//! 用 `proxy-wasm-rust-sdk` 仓库 `examples/` 的 4 个范例对应行为构造 +//! [`sdk_examples_guest`] 单 wasm 多模式,逐个跑完整 [`Vm::process`] 链路。 +//! +//! 这个测试是真正的端到端:plugin configuration → on_configure → 请求进入插件 → +//! on_request_headers / on_request_body → inner.call(我们 mock 出来的 hyper 服务)→ +//! on_response_headers → 最终响应 → on_log。它直接证明: +//! +//! - SDK 标准范例的 host fn 调用面我们 host 全部正确实现; +//! - body 改写、本地响应短路、required header 拦截 这些跨阶段的协同没问题。 +//! +//! 本文件 **不** 覆盖 `auth_random`(需要 mock HTTP server 走 reqwest),那部分在 +//! [`tests/http_call.rs`] 单独测。 + +use std::convert::Infallible; +use std::path::PathBuf; +use std::sync::Arc; + +use bytes::Bytes; +use http_body_util::BodyExt; +use hyper::service::service_fn; +use hyper::Request as HyperRequest; +use spacegate_kernel::backend_service::ArcHyperService; +use spacegate_kernel::helper_layers::function::Inner; +use spacegate_kernel::{SgBody, SgRequest, SgResponse}; +use spacegate_plugin_wasm::config::WasmPluginShellConfig; +use spacegate_plugin_wasm::engine::shared_engine; +use spacegate_plugin_wasm::vm::Vm; +use wasmtime::Module; + +// ───────────────────────────────────────────────────────── +// 公共:定位/构建 sdk_examples_guest.wasm +// ───────────────────────────────────────────────────────── + +fn guest_manifest_path() -> PathBuf { + let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + p.push("tests"); + p.push("sdk_examples_guest"); + p.push("Cargo.toml"); + p +} + +fn guest_wasm_path() -> PathBuf { + let manifest = guest_manifest_path(); + let out = std::process::Command::new(env!("CARGO")) + .args(["metadata", "--no-deps", "--format-version", "1", "--manifest-path"]) + .arg(&manifest) + .output() + .expect("cargo metadata: spawn"); + assert!(out.status.success(), "cargo metadata failed: {}", String::from_utf8_lossy(&out.stderr)); + let meta: serde_json::Value = serde_json::from_slice(&out.stdout).expect("parse cargo metadata json"); + let target_dir = meta["target_directory"].as_str().expect("target_directory missing"); + PathBuf::from(target_dir).join("wasm32-wasip1").join("release").join("sdk_examples_guest.wasm") +} + +fn ensure_guest_built() -> PathBuf { + let wasm = guest_wasm_path(); + if !wasm.exists() { + eprintln!("[sdk_examples] building sdk_examples_guest …"); + let status = std::process::Command::new(env!("CARGO")) + .args(["build", "--release", "--target", "wasm32-wasip1", "--manifest-path"]) + .arg(guest_manifest_path()) + .status() + .expect("cargo build: spawn"); + assert!(status.success(), "sdk_examples_guest build failed"); + assert!(wasm.exists(), "wasm still missing after build: {wasm:?}"); + } + wasm +} + +fn load_module() -> Arc { + let path = ensure_guest_built(); + let bytes = std::fs::read(&path).expect("read wasm"); + let module = Module::new(shared_engine(), &bytes).expect("Module::new"); + Arc::new(module) +} + +// ───────────────────────────────────────────────────────── +// mock `Inner`:把请求 body 原样回显,并复制 `x-echo-*` 头 +// ───────────────────────────────────────────────────────── + +#[derive(Clone, Default)] +struct CaptureState { + /// inner.call 真正收到的请求体(guest 修改后下沉的内容) + inbound_body: Arc>>, + /// inner.call 实际是否被调到(验证 send_local_response 的短路) + invoked: Arc, +} + +fn make_inner(state: CaptureState) -> Inner { + let svc = service_fn(move |req: SgRequest| { + let state = state.clone(); + async move { + state.invoked.store(true, std::sync::atomic::Ordering::SeqCst); + let (parts, body) = req.into_parts(); + let bytes = body.collect().await.map(|c| c.to_bytes()).unwrap_or_default(); + *state.inbound_body.lock().await = Some(bytes.clone()); + let mut resp = SgResponse::new(SgBody::full(bytes)); + for (k, v) in parts.headers.iter() { + if k.as_str().starts_with("x-echo-") { + resp.headers_mut().insert(k, v.clone()); + } + } + Ok::<_, Infallible>(resp) + } + }); + Inner::new(ArcHyperService::new(svc)) +} + +fn make_cfg(spec: serde_json::Value) -> Arc { + Arc::new(WasmPluginShellConfig { + url: "file://sdk_examples_guest".into(), + plugin_config: spec, + plugin_name: "sdk-examples-test".into(), + plugin_root_id: "sdk-examples-root".into(), + plugin_vm_id: "default".into(), + ..Default::default() + }) +} + +async fn full_body(resp: SgResponse) -> (SgResponse, Bytes) { + let (parts, body) = resp.into_parts(); + let bytes = body.collect().await.map(|c| c.to_bytes()).unwrap_or_default(); + (SgResponse::from_parts(parts, SgBody::full(bytes.clone())), bytes) +} + +// ───────────────────────────────────────────────────────── +// 1. http_headers:/hello → 本地 200 + Hello/Powered-By;其余 → 走 inner,加 x-sdk-headers 响应头 +// ───────────────────────────────────────────────────────── + +#[tokio::test] +async fn sdk_example_http_headers_hello() { + let module = load_module(); + let cfg = make_cfg(serde_json::json!({"mode": "headers"})); + let mut vm = Vm::new(&module, cfg).expect("Vm::new"); + + let req = HyperRequest::builder().method("GET").uri("http://example.test/hello").header("host", "example.test").body(SgBody::empty()).expect("build req"); + let captured = CaptureState::default(); + let inner = make_inner(captured.clone()); + let resp = vm.process(req, inner).await.expect("process"); + let (resp, body) = full_body(resp).await; + + assert_eq!(resp.status(), 200); + assert_eq!(body, Bytes::from_static(b"Hello, World!\n")); + assert_eq!(resp.headers().get("hello").and_then(|v| v.to_str().ok()), Some("world")); + assert_eq!(resp.headers().get("powered-by").and_then(|v| v.to_str().ok()), Some("proxy-wasm")); + assert!( + !captured.invoked.load(std::sync::atomic::Ordering::SeqCst), + "inner.call must NOT be invoked for local response" + ); +} + +#[tokio::test] +async fn sdk_example_http_headers_passthrough() { + let module = load_module(); + let cfg = make_cfg(serde_json::json!({"mode": "headers"})); + let mut vm = Vm::new(&module, cfg).expect("Vm::new"); + + let req = HyperRequest::builder() + .method("GET") + .uri("http://example.test/world") + .header("host", "example.test") + .header("x-echo-foo", "bar") + .body(SgBody::empty()) + .expect("build req"); + let captured = CaptureState::default(); + let resp = vm.process(req, make_inner(captured.clone())).await.expect("process"); + let (resp, _body) = full_body(resp).await; + + assert_eq!(resp.status(), 200); + assert!(captured.invoked.load(std::sync::atomic::Ordering::SeqCst)); + assert_eq!( + resp.headers().get("x-sdk-headers").and_then(|v| v.to_str().ok()), + Some("seen"), + "on_response_headers should inject x-sdk-headers" + ); + // echo header 应该原路回来 + assert_eq!(resp.headers().get("x-echo-foo").and_then(|v| v.to_str().ok()), Some("bar")); +} + +// ───────────────────────────────────────────────────────── +// 2. http_body:on_request_body 把 body 反转后下沉给 inner.call +// ───────────────────────────────────────────────────────── + +#[tokio::test] +async fn sdk_example_http_body_reverses_request_body() { + let module = load_module(); + let cfg = make_cfg(serde_json::json!({"mode": "body"})); + let mut vm = Vm::new(&module, cfg).expect("Vm::new"); + + let req = HyperRequest::builder() + .method("POST") + .uri("http://example.test/reverse") + .header("host", "example.test") + .body(SgBody::full(Bytes::from_static(b"abc-123"))) + .expect("build req"); + let captured = CaptureState::default(); + let resp = vm.process(req, make_inner(captured.clone())).await.expect("process"); + let (resp, body) = full_body(resp).await; + + assert_eq!(resp.status(), 200); + // Inner 收到的应是反转后的字节,echo 回来后响应体也是它。 + assert_eq!(captured.inbound_body.lock().await.clone().expect("body captured"), Bytes::from_static(b"321-cba")); + assert_eq!(body, Bytes::from_static(b"321-cba")); +} + +// ───────────────────────────────────────────────────────── +// 3. http_config:缺失 x-token → 本地 403;带上 → 放行 +// ───────────────────────────────────────────────────────── + +#[tokio::test] +async fn sdk_example_http_config_missing_header_rejected() { + let module = load_module(); + let cfg = make_cfg(serde_json::json!({"mode": "config", "required_header": "x-token"})); + let mut vm = Vm::new(&module, cfg).expect("Vm::new"); + + let req = HyperRequest::builder().method("GET").uri("http://example.test/").header("host", "example.test").body(SgBody::empty()).expect("build req"); + let captured = CaptureState::default(); + let resp = vm.process(req, make_inner(captured.clone())).await.expect("process"); + let (resp, body) = full_body(resp).await; + + assert_eq!(resp.status(), 403); + assert_eq!(body, Bytes::from_static(b"missing required header")); + assert!(!captured.invoked.load(std::sync::atomic::Ordering::SeqCst)); +} + +#[tokio::test] +async fn sdk_example_http_config_present_header_passthrough() { + let module = load_module(); + let cfg = make_cfg(serde_json::json!({"mode": "config", "required_header": "x-token"})); + let mut vm = Vm::new(&module, cfg).expect("Vm::new"); + + let req = HyperRequest::builder() + .method("GET") + .uri("http://example.test/") + .header("host", "example.test") + .header("x-token", "abc") + .body(SgBody::full(Bytes::from_static(b"hello"))) + .expect("build req"); + let captured = CaptureState::default(); + let resp = vm.process(req, make_inner(captured.clone())).await.expect("process"); + let (resp, body) = full_body(resp).await; + + assert_eq!(resp.status(), 200); + assert_eq!(body, Bytes::from_static(b"hello")); + assert!(captured.invoked.load(std::sync::atomic::Ordering::SeqCst)); +} diff --git a/crates/plugin-wasm/tests/sdk_examples_guest/.cargo/config.toml b/crates/plugin-wasm/tests/sdk_examples_guest/.cargo/config.toml new file mode 100644 index 00000000..6b509f5b --- /dev/null +++ b/crates/plugin-wasm/tests/sdk_examples_guest/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +target = "wasm32-wasip1" diff --git a/crates/plugin-wasm/tests/sdk_examples_guest/Cargo.toml b/crates/plugin-wasm/tests/sdk_examples_guest/Cargo.toml new file mode 100644 index 00000000..727907a2 --- /dev/null +++ b/crates/plugin-wasm/tests/sdk_examples_guest/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "sdk_examples_guest" +version = "0.0.0" +edition = "2021" +publish = false +description = "Mirrors proxy-wasm-rust-sdk examples (http_headers / http_body / http_config / http_auth_random) inside one guest, selectable via plugin configuration." + +# 独立 workspace:本 crate 目标 `wasm32-wasip1`,不参与外层 host workspace。 +[workspace] + +[lib] +crate-type = ["cdylib"] + +[dependencies] +proxy-wasm = "0.2" +log = "0.4" + +[profile.release] +codegen-units = 1 +opt-level = "z" +lto = "fat" +strip = true +panic = "abort" diff --git a/crates/plugin-wasm/tests/sdk_examples_guest/src/lib.rs b/crates/plugin-wasm/tests/sdk_examples_guest/src/lib.rs new file mode 100644 index 00000000..35223020 --- /dev/null +++ b/crates/plugin-wasm/tests/sdk_examples_guest/src/lib.rs @@ -0,0 +1,207 @@ +//! 把 `proxy-wasm/proxy-wasm-rust-sdk` 仓库 `examples/` 下 4 个范例汇总进同一个 guest, +//! 由 `on_configure` 读到的 plugin configuration 选择运行模式: +//! +//! - `mode: headers` ←→ examples/http_headers (读 req/resp 头 + `/hello` 本地响应) +//! - `mode: body` ←→ examples/http_body (on_request_body 反转字节后落到 inner.call) +//! - `mode: config` ←→ examples/http_config (要求请求带某个 header,缺失则 403) +//! - `mode: auth_random` ←→ examples/http_auth_random (`proxy_http_call` 到 "auth" cluster 决定放行) +//! +//! 之所以做成单 wasm + 模式切换,是为了集成测试只需构建一次 wasm 即可覆盖所有 SDK 范例。 +//! +//! configuration 直接吃明文 YAML:第一行 `mode: `,第二行(可选)模式相关参数。 +//! 这样不引入 serde_yaml/serde_json 依赖,wasm 体积更小、装配速度更快。 + +use log::{info, warn}; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +proxy_wasm::main! {{ + proxy_wasm::set_log_level(LogLevel::Trace); + proxy_wasm::set_root_context(|_| -> Box { Box::new(SdkRoot::default()) }); +}} + +#[derive(Default, Clone)] +struct SdkConfig { + mode: Mode, + /// `config` 模式:缺失即 403 的请求头名字(默认 `x-token`)。 + required_header: String, + /// `auth_random` 模式:放行阈值。host 给出的 random byte < threshold 则视为允许。 + auth_threshold: u8, + /// `auth_random` 模式:用于 dispatch_http_call 的 cluster 名。 + auth_cluster: String, +} + +#[derive(Default, Clone, Copy, PartialEq, Eq)] +enum Mode { + #[default] + Noop, + Headers, + Body, + Config, + AuthRandom, +} + +impl Mode { + fn parse(s: &str) -> Self { + match s.trim() { + "headers" => Mode::Headers, + "body" => Mode::Body, + "config" => Mode::Config, + "auth_random" => Mode::AuthRandom, + _ => Mode::Noop, + } + } +} + +#[derive(Default)] +struct SdkRoot { + cfg: SdkConfig, +} + +impl Context for SdkRoot {} + +impl RootContext for SdkRoot { + fn on_vm_start(&mut self, _: usize) -> bool { true } + + fn on_configure(&mut self, _: usize) -> bool { + let raw = self.get_plugin_configuration().unwrap_or_default(); + let text = String::from_utf8_lossy(&raw); + let mut cfg = SdkConfig::default(); + cfg.required_header = "x-token".into(); + cfg.auth_threshold = 128; + cfg.auth_cluster = "auth".into(); + for line in text.lines() { + let Some((k, v)) = line.split_once(':') else { continue }; + // 去掉 YAML 单/双引号 + let v = v.trim().trim_matches(['"', '\''].as_ref()); + match k.trim() { + "mode" => cfg.mode = Mode::parse(v), + "required_header" => cfg.required_header = v.to_string(), + "auth_threshold" => cfg.auth_threshold = v.parse().unwrap_or(128), + "auth_cluster" => cfg.auth_cluster = v.to_string(), + _ => {} + } + } + info!("sdk_examples_guest configured: mode_set={}", cfg.mode != Mode::Noop); + self.cfg = cfg; + true + } + + fn create_http_context(&self, _context_id: u32) -> Option> { + Some(Box::new(SdkHttp { + cfg: self.cfg.clone(), + pending_token: None, + })) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } +} + +struct SdkHttp { + cfg: SdkConfig, + pending_token: Option, +} + +impl Context for SdkHttp { + fn on_http_call_response(&mut self, token_id: u32, _num_headers: usize, body_size: usize, _num_trailers: usize) { + // 只 auth_random 模式会到这里;其它模式根本没发 dispatch。 + if Some(token_id) != self.pending_token { return; } + self.pending_token = None; + let body = self + .get_http_call_response_body(0, body_size) + .unwrap_or_default(); + let allow = body.first().map(|b| *b < self.cfg.auth_threshold).unwrap_or(false); + if allow { + // 放行:把 Pause 恢复 → 让 host 继续 inner.call。 + self.resume_http_request(); + } else { + self.send_http_response(403, vec![("x-rejected-by", "auth_random")], Some(b"forbidden")); + } + } +} + +impl HttpContext for SdkHttp { + fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { + match self.cfg.mode { + Mode::Headers => { + for (name, value) in &self.get_http_request_headers() { + info!("-> {name}: {value}"); + } + match self.get_http_request_header(":path") { + Some(p) if p == "/hello" => { + self.send_http_response( + 200, + vec![("hello", "world"), ("powered-by", "proxy-wasm")], + Some(b"Hello, World!\n"), + ); + Action::Pause + } + _ => Action::Continue, + } + } + Mode::Body => Action::Continue, + Mode::Config => { + if self.get_http_request_header(&self.cfg.required_header).is_some() { + Action::Continue + } else { + self.send_http_response( + 403, + vec![("x-rejected-by", "http_config")], + Some(b"missing required header"), + ); + Action::Pause + } + } + Mode::AuthRandom => { + match self.dispatch_http_call( + &self.cfg.auth_cluster.clone(), + vec![(":method", "GET"), (":path", "/random"), (":authority", "auth")], + None, + vec![], + std::time::Duration::from_millis(500), + ) { + Ok(token) => { + self.pending_token = Some(token); + Action::Pause + } + Err(s) => { + warn!("dispatch_http_call failed: status={s:?}"); + self.send_http_response(502, vec![("x-rejected-by", "dispatch_failed")], Some(b"upstream auth unreachable")); + Action::Pause + } + } + } + Mode::Noop => Action::Continue, + } + } + + fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + if !matches!(self.cfg.mode, Mode::Body) || !end_of_stream { + return Action::Continue; + } + if body_size == 0 { return Action::Continue; } + let body = self.get_http_request_body(0, body_size).unwrap_or_default(); + let mut rev = body.clone(); + rev.reverse(); + // spec §Buffers:start=0, size=原长度 → 替换;len(value)=新长度。 + let _ = self.set_http_request_body(0, body.len(), &rev); + Action::Continue + } + + fn on_http_response_headers(&mut self, _: usize, _: bool) -> Action { + if matches!(self.cfg.mode, Mode::Headers) { + for (name, value) in &self.get_http_response_headers() { + info!("<- {name}: {value}"); + } + // 给响应额外塞一条头:SDK 那个 example 没塞,我们这里塞便于测试断言。 + let _ = self.add_http_response_header("x-sdk-headers", "seen"); + } + Action::Continue + } + + fn on_log(&mut self) { + info!("sdk_examples_guest: ctx done."); + } +} diff --git a/crates/plugin-wasm/tests/spec_compliance.rs b/crates/plugin-wasm/tests/spec_compliance.rs new file mode 100644 index 00000000..90273454 --- /dev/null +++ b/crates/plugin-wasm/tests/spec_compliance.rs @@ -0,0 +1,238 @@ +//! End-to-end spec compliance test:用 `proxy-wasm-rust-sdk` 编出来的真实 guest 插件 +//! ([`crates/plugin-wasm/tests/spec_test_guest`]) 跑一遍我们 host 注册的所有 hostcall, +//! 覆盖 proxy-wasm v0.2.1 spec 关键面: +//! +//! - Shared K/V(带 CAS) +//! - Shared queues(含 register/resolve/enqueue/dequeue 全链) +//! - Metrics(counter / gauge / record / increment) +//! - Properties(user + well-known `plugin_name`) +//! - Logging / Clocks +//! - Buffer(PluginConfiguration) +//! - HTTP header map(含 `:method` 伪头) +//! - Stream control(continue / close) +//! - effective_context / done +//! - gRPC / foreign_function 的 spec 合规返回值 +//! - send_local_response 短路写入 +//! - set_tick_period 接收 +//! +//! 运行:先 `cd crates/plugin-wasm/tests/spec_test_guest && cargo build --release` +//! 生成 wasm,再 `cargo test -p spacegate-plugin-wasm --test spec_compliance`。 +//! 测试入口会通过 cargo metadata 自动定位 wasm 路径并按需在缺失时调用 cargo 触发构建。 + +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use bytes::Bytes; +use http::HeaderMap; +use spacegate_plugin_wasm::config::WasmPluginShellConfig; +use spacegate_plugin_wasm::engine::shared_engine; +use spacegate_plugin_wasm::host_fn::register_all; +use spacegate_plugin_wasm::host_state::{ContextStage, HostState, HttpCallResult, PseudoHeaders, RequestContext}; +use spacegate_plugin_wasm::vm::register_wasi_stubs; +use wasmtime::{Instance, Linker, Module, Store, TypedFunc}; + +const HTTP_CONTEXT_ID: u32 = 2; + +// ───────────────────────────────────────────────────────── +// 定位 guest wasm;缺失则触发一次 `cargo build --release` +// ───────────────────────────────────────────────────────── + +fn guest_manifest_path() -> PathBuf { + let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + p.push("tests"); + p.push("spec_test_guest"); + p.push("Cargo.toml"); + p +} + +/// 用 `cargo metadata` 拿独立 workspace 的 `target_directory`,再拼出 `wasm32-wasip1/release/spec_test_guest.wasm`。 +fn guest_wasm_path() -> PathBuf { + let manifest = guest_manifest_path(); + let out = std::process::Command::new(env!("CARGO")) + .args(["metadata", "--no-deps", "--format-version", "1", "--manifest-path"]) + .arg(&manifest) + .output() + .expect("cargo metadata: spawn"); + assert!(out.status.success(), "cargo metadata failed: {}", String::from_utf8_lossy(&out.stderr)); + let meta: serde_json::Value = serde_json::from_slice(&out.stdout).expect("parse cargo metadata json"); + let target_dir = meta["target_directory"].as_str().expect("target_directory missing"); + PathBuf::from(target_dir).join("wasm32-wasip1").join("release").join("spec_test_guest.wasm") +} + +fn ensure_guest_built() -> PathBuf { + let wasm = guest_wasm_path(); + if !wasm.exists() { + let manifest = guest_manifest_path(); + eprintln!("[spec_compliance] guest wasm not found at {wasm:?}; running `cargo build --release` for spec_test_guest"); + let status = std::process::Command::new(env!("CARGO")) + .args(["build", "--release", "--target", "wasm32-wasip1", "--manifest-path"]) + .arg(&manifest) + .status() + .expect("cargo build: spawn"); + assert!(status.success(), "spec_test_guest build failed (exit = {status:?})"); + assert!(wasm.exists(), "spec_test_guest.wasm still missing after build: {wasm:?}"); + } + wasm +} + +// ───────────────────────────────────────────────────────── +// 测试 harness:直接搭一个不走 Vm 的 store/linker/instance +// ───────────────────────────────────────────────────────── + +struct GuestVm { + store: Store, + instance: Instance, +} + +impl GuestVm { + fn new(wasm_bytes: &[u8], cfg: WasmPluginShellConfig, configuration: Vec) -> Self { + let engine = shared_engine(); + let module = Module::new(engine, wasm_bytes).expect("Module::new"); + + let mut host = HostState::new(Arc::new(cfg)); + host.configuration = configuration; + // 预置一个 HTTP context,方便头部 / send_local_response 类场景。 + let ctx = RequestContext { + parent_id: host.root_context_id, + stage: ContextStage::RequestHeaders, + request_pseudo: PseudoHeaders { + method: "POST".into(), + path: "/spec".into(), + authority: "spec.local".into(), + scheme: "http".into(), + }, + request_headers: HeaderMap::new(), + ..Default::default() + }; + host.contexts.insert(HTTP_CONTEXT_ID, ctx); + host.effective_context = HTTP_CONTEXT_ID; + + let mut store: Store = Store::new(engine, host); + prepare_guest_budget(&mut store); + let mut linker: Linker = Linker::new(engine); + // dispatch_tx 在本测试里不会被消费——保留 rx 不让通道关闭即可。 + let (dispatch_tx, _dispatch_rx) = tokio::sync::mpsc::unbounded_channel::<(u32, HttpCallResult)>(); + register_all(&mut linker, dispatch_tx).expect("register_all"); + register_wasi_stubs(&mut linker).expect("register_wasi_stubs"); + + let instance = linker.instantiate(&mut store, &module).expect("instantiate"); + let mem = instance.get_memory(&mut store, "memory").expect("memory export"); + store.data_mut().memory = Some(mem); + if let Ok(a) = instance.get_typed_func::(&mut store, "proxy_on_memory_allocate") { + store.data_mut().alloc = Some(a); + } else if let Ok(a) = instance.get_typed_func::(&mut store, "malloc") { + store.data_mut().alloc = Some(a); + } else { + panic!("guest exports neither proxy_on_memory_allocate nor malloc"); + } + + // _initialize 优先(SDK 在 wasm32-wasip1 上默认导这个),回退 _start。 + if let Ok(init) = instance.get_typed_func::<(), ()>(&mut store, "_initialize") { + prepare_guest_budget(&mut store); + init.call(&mut store, ()).expect("_initialize"); + } else if let Ok(start) = instance.get_typed_func::<(), ()>(&mut store, "_start") { + prepare_guest_budget(&mut store); + start.call(&mut store, ()).expect("_start"); + } + + GuestVm { store, instance } + } + + fn run_test(&mut self, scenario: u32) -> u32 { + let f: TypedFunc = self.instance.get_typed_func(&mut self.store, "__run_test").expect("__run_test export"); + prepare_guest_budget(&mut self.store); + f.call(&mut self.store, scenario).expect("__run_test trap-free") + } + + fn data(&self) -> &HostState { + self.store.data() + } +} + +fn prepare_guest_budget(store: &mut Store) { + store.set_fuel(u64::MAX / 4).expect("set test fuel"); + store.set_epoch_deadline(24 * 60 * 60 * 1000); + store.epoch_deadline_trap(); +} + +// ───────────────────────────────────────────────────────── +// 唯一一个 `#[test]` —— 跑完所有 scenario;隔离 shared/queue/metric 已通过 scenario 内独立 key 实现。 +// ───────────────────────────────────────────────────────── + +#[test] +fn proxy_wasm_spec_v0_2_1_compliance() { + // 准备 wasm。 + let wasm_path = ensure_guest_built(); + let wasm_bytes = std::fs::read(&wasm_path).expect("read guest wasm"); + + // 业务侧配置:plugin_name 用于 well-known property 校验;configuration 走 buffer 通道。 + let cfg = WasmPluginShellConfig { + url: format!("file://{}", wasm_path.display()), + plugin_config: serde_json::Value::Null, + plugin_name: "spec-test-plugin".to_string(), + plugin_root_id: "spec-test-root".to_string(), + plugin_vm_id: "default".to_string(), + clusters: HashMap::new(), + ..Default::default() + }; + let configuration = b"spec-test-config".to_vec(); + + let mut vm = GuestVm::new(&wasm_bytes, cfg, configuration); + + // 依次跑场景;每个返回 0 视为通过。 + let scenarios: &[(u32, &str)] = &[ + (1, "shared_data CAS roundtrip"), + (2, "shared_queue lifecycle"), + (3, "metric counter increment-only"), + (4, "metric gauge bidirectional + record"), + (5, "user property set/get"), + (6, "well-known plugin_name property"), + (7, "get_log_level"), + (8, "get_current_time_nanoseconds"), + (9, "continue_stream(HTTP_REQUEST)"), + (10, "close_stream(DOWNSTREAM) → Unimplemented"), + (11, "set_effective_context(invalid) → BadArgument"), + (12, "grpc_call → Unimplemented"), + (13, "foreign_function → NotFound"), + (14, "request `:method` pseudo header"), + (15, "add/replace/remove header"), + (16, "get_buffer(PluginConfiguration)"), + (17, "send_local_response"), + (18, "proxy_done without awaiting → NotFound"), + (19, "log at all levels"), + (20, "set_tick_period"), + ]; + + let mut failures: Vec<(u32, &str, u32)> = Vec::new(); + for &(id, name) in scenarios { + // 每个 scenario 重置一次 effective_context(部分 scenario 会改它) + vm.store.data_mut().effective_context = HTTP_CONTEXT_ID; + let code = vm.run_test(id); + if code != 0 { + failures.push((id, name, code)); + } + } + + assert!(failures.is_empty(), "spec compliance failures: {failures:?}"); + + // 额外的 host 侧副作用断言: + let st = vm.data(); + + // (17) send_local_response 应写入 ctx.local_response + let ctx = st.contexts.get(&HTTP_CONTEXT_ID).expect("http ctx present"); + let lr = ctx.local_response.as_ref().expect("local_response written by guest"); + assert_eq!(lr.status, 418, "local_response.status"); + assert_eq!(lr.body, Bytes::from_static(b"local body"), "local_response.body"); + let x_spec = lr.headers.get("x-spec").expect("x-spec header present").to_str().unwrap_or(""); + assert_eq!(x_spec, "teapot"); + + // (20) set_tick_period 应写入 HostState.tick_period_ms + assert_eq!(st.tick_period_ms, Some(123), "tick_period_ms"); + + // (5) user_properties 应记录我们设的 key + // user_properties 的 key 是 path 用 \0 拼起来;spec_test_guest 设的是 vec!["spec","user_prop"] + let want_key: Vec = b"spec\0user_prop".to_vec(); + let v = st.user_properties.get(&want_key).expect("user_prop stored under '\\0'-joined key"); + assert_eq!(v.as_slice(), b"hello"); +} diff --git a/crates/plugin-wasm/tests/spec_test_guest/.cargo/config.toml b/crates/plugin-wasm/tests/spec_test_guest/.cargo/config.toml new file mode 100644 index 00000000..6b509f5b --- /dev/null +++ b/crates/plugin-wasm/tests/spec_test_guest/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +target = "wasm32-wasip1" diff --git a/crates/plugin-wasm/tests/spec_test_guest/Cargo.toml b/crates/plugin-wasm/tests/spec_test_guest/Cargo.toml new file mode 100644 index 00000000..a0b8e23d --- /dev/null +++ b/crates/plugin-wasm/tests/spec_test_guest/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "spec_test_guest" +version = "0.0.0" +edition = "2021" +publish = false +description = "Proxy-Wasm guest used by spacegate-plugin-wasm integration tests. Mirrors proxy-wasm-rust-sdk examples and exercises every host fn our host registers." + +# 独立 workspace:本 crate 目标 `wasm32-wasip1`,不参与外层 host workspace。 +[workspace] + +[lib] +crate-type = ["cdylib"] + +[dependencies] +proxy-wasm = "0.2" +log = "0.4" + +[profile.release] +codegen-units = 1 +opt-level = "z" +lto = "fat" +strip = true +panic = "abort" diff --git a/crates/plugin-wasm/tests/spec_test_guest/src/lib.rs b/crates/plugin-wasm/tests/spec_test_guest/src/lib.rs new file mode 100644 index 00000000..cd5a1b99 --- /dev/null +++ b/crates/plugin-wasm/tests/spec_test_guest/src/lib.rs @@ -0,0 +1,338 @@ +//! 验证 spacegate-plugin-wasm host fn 实现的 **真实 proxy-wasm guest 插件**。 +//! +//! 用法:通过 `cargo build --release` 编译到 `wasm32-wasip1`,得到 +//! `target/wasm32-wasip1/release/spec_test_guest.wasm`,再由 +//! `crates/plugin-wasm/tests/spec_compliance.rs` 加载并依次调用 +//! [`__run_test`] 来跑各场景。 +//! +//! 设计取舍:每个场景都通过 [`proxy_wasm::hostcalls`] 直接调相应 host fn。 +//! SDK 在 status 不预期时会 panic(即 wasmtime trap),这正好让我们: +//! +//! - host fn 返回正确 Status → SDK 返回 Result → guest 自行断言并返回 0 / 失败码 +//! - host fn 返回错误 Status → SDK panic → wasmtime trap → test 立刻挂掉 +//! +//! 这样测试侧只看 `__run_test` 返回值就能判定通过。 + +use std::time::Duration; + +use proxy_wasm::hostcalls; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +proxy_wasm::main! {{ + proxy_wasm::set_log_level(LogLevel::Trace); + proxy_wasm::set_root_context(|_| -> Box { Box::new(SpecRoot) }); +}} + +struct SpecRoot; + +impl Context for SpecRoot {} +impl RootContext for SpecRoot { + fn on_vm_start(&mut self, _vm_configuration_size: usize) -> bool { true } + fn on_configure(&mut self, _plugin_configuration_size: usize) -> bool { true } + fn get_type(&self) -> Option { Some(ContextType::HttpContext) } + fn create_http_context(&self, _context_id: u32) -> Option> { + Some(Box::new(SpecHttp)) + } +} + +struct SpecHttp; +impl Context for SpecHttp {} +impl HttpContext for SpecHttp {} + +// ───────────────────────────────────────────────────────── +// 直接 extern:spec 里这些 host fn SDK 没暴露,我们手动 import 用 +// ───────────────────────────────────────────────────────── + +extern "C" { + fn proxy_grpc_call( + a: *const u8, b: usize, c: *const u8, d: usize, e: *const u8, f: usize, + g: *const u8, h: usize, i: *const u8, j: usize, k: u32, l: *mut u32, + ) -> u32; + fn proxy_call_foreign_function( + a: *const u8, b: usize, c: *const u8, d: usize, e: *mut *mut u8, f: *mut usize, + ) -> u32; + fn proxy_continue_stream(stream_type: u32) -> u32; + fn proxy_close_stream(stream_type: u32) -> u32; + fn proxy_set_effective_context(ctx: u32) -> u32; + fn proxy_done() -> u32; +} + +// spec §Types: STATUS_UNIMPLEMENTED = 12, STATUS_NOT_FOUND = 1, STATUS_BAD_ARGUMENT = 2 +const STATUS_OK: u32 = 0; +const STATUS_NOT_FOUND: u32 = 1; +const STATUS_BAD_ARGUMENT: u32 = 2; +const STATUS_UNIMPLEMENTED: u32 = 12; + +const STREAM_HTTP_REQUEST: u32 = 0; +const STREAM_DOWNSTREAM: u32 = 2; + +// ───────────────────────────────────────────────────────── +// 测试入口 +// ───────────────────────────────────────────────────────── + +#[no_mangle] +pub extern "C" fn __run_test(scenario: u32) -> u32 { + match scenario { + 1 => test_shared_data(), + 2 => test_shared_queue(), + 3 => test_metric_counter(), + 4 => test_metric_gauge(), + 5 => test_user_property(), + 6 => test_well_known_plugin_name(), + 7 => test_log_level(), + 8 => test_current_time(), + 9 => test_continue_stream_http_request(), + 10 => test_close_stream_tcp_unimplemented(), + 11 => test_set_effective_context_bad_argument(), + 12 => test_grpc_unimplemented(), + 13 => test_foreign_function_not_found(), + 14 => test_request_header_pseudo_method(), + 15 => test_add_replace_remove_header(), + 16 => test_get_configuration_buffer(), + 17 => test_send_local_response(), + 18 => test_done_without_pending(), + 19 => test_log(), + 20 => test_tick_period(), + _ => 999, + } +} + +// ─── 1: shared_data CAS roundtrip ─── +fn test_shared_data() -> u32 { + let key = "spec.shared_data.k"; + hostcalls::set_shared_data(key, Some(b"v1"), None).unwrap(); + let (val, cas) = hostcalls::get_shared_data(key).unwrap(); + if val.as_deref() != Some(b"v1".as_slice()) { return 1; } + let cas = match cas { Some(c) if c > 0 => c, _ => return 2 }; + // 错误 cas → CasMismatch(SDK 把 CasMismatch 包装成 Err) + if hostcalls::set_shared_data(key, Some(b"v2"), Some(cas.wrapping_add(999))).is_ok() { + return 3; + } + // 正确 cas → Ok + hostcalls::set_shared_data(key, Some(b"v2"), Some(cas)).unwrap(); + let (val, cas2) = hostcalls::get_shared_data(key).unwrap(); + if val.as_deref() != Some(b"v2".as_slice()) { return 4; } + let cas2 = cas2.unwrap_or(0); + if cas2 <= cas { return 5; } + 0 +} + +// ─── 2: shared queue lifecycle ─── +fn test_shared_queue() -> u32 { + let qid = hostcalls::register_shared_queue("spec.q.basic").unwrap(); + if qid == 0 { return 1; } + // register 同名应返回相同 qid(spec §proxy_register_shared_queue) + if hostcalls::register_shared_queue("spec.q.basic").unwrap() != qid { return 2; } + // resolve_shared_queue:vm_id 默认是 "default",name 已存在 + match hostcalls::resolve_shared_queue("default", "spec.q.basic").unwrap() { + Some(id) if id == qid => {} + _ => return 3, + } + if hostcalls::resolve_shared_queue("default", "spec.q.nonexistent").unwrap().is_some() { return 4; } + hostcalls::enqueue_shared_queue(qid, Some(b"a")).unwrap(); + hostcalls::enqueue_shared_queue(qid, Some(b"bb")).unwrap(); + match hostcalls::dequeue_shared_queue(qid).unwrap() { + Some(v) if v == b"a".to_vec() => {} + _ => return 5, + } + match hostcalls::dequeue_shared_queue(qid).unwrap() { + Some(v) if v == b"bb".to_vec() => {} + _ => return 6, + } + // 空 → Ok(None)(SDK 把 Empty 折叠成 None) + if hostcalls::dequeue_shared_queue(qid).unwrap().is_some() { return 7; } + // 未知 qid → Err(NotFound) + if hostcalls::dequeue_shared_queue(9_999_999).is_ok() { return 8; } + if hostcalls::enqueue_shared_queue(9_999_999, Some(b"x")).is_ok() { return 9; } + 0 +} + +// ─── 3: counter only allows positive delta ─── +fn test_metric_counter() -> u32 { + let id = hostcalls::define_metric(MetricType::Counter, "spec.counter").unwrap(); + if id == 0 { return 1; } + hostcalls::increment_metric(id, 3).unwrap(); + hostcalls::increment_metric(id, 2).unwrap(); + if hostcalls::get_metric(id).unwrap() != 5 { return 2; } + // counter 不能 decrement → BadArgument + if hostcalls::increment_metric(id, -1).is_ok() { return 3; } + if hostcalls::get_metric(id).unwrap() != 5 { return 4; } + // 未知 mid → NotFound + if hostcalls::get_metric(9_999_999).is_ok() { return 5; } + 0 +} + +// ─── 4: gauge bidirectional + record ─── +fn test_metric_gauge() -> u32 { + let id = hostcalls::define_metric(MetricType::Gauge, "spec.gauge").unwrap(); + hostcalls::increment_metric(id, 10).unwrap(); + hostcalls::increment_metric(id, -3).unwrap(); + if hostcalls::get_metric(id).unwrap() != 7 { return 1; } + hostcalls::record_metric(id, 42).unwrap(); + if hostcalls::get_metric(id).unwrap() != 42 { return 2; } + 0 +} + +// ─── 5: user property set/get roundtrip ─── +fn test_user_property() -> u32 { + let path = vec!["spec", "user_prop"]; + hostcalls::set_property(path.clone(), Some(b"hello")).unwrap(); + let v = hostcalls::get_property(path.clone()).unwrap(); + if v.as_deref() != Some(b"hello".as_slice()) { return 1; } + // None / NotFound:未设置过的 path + let missing = hostcalls::get_property(vec!["spec", "absent"]).unwrap(); + if missing.is_some() { return 2; } + 0 +} + +// ─── 6: well-known property plugin_name ─── +fn test_well_known_plugin_name() -> u32 { + let v = hostcalls::get_property(vec!["plugin_name"]).unwrap(); + match v { + Some(b) if b == b"spec-test-plugin".to_vec() => 0, + Some(_) => 1, + None => 2, + } +} + +// ─── 7: log_level(host 当前 tracing 最大级别) ─── +fn test_log_level() -> u32 { + let lvl = hostcalls::get_log_level().unwrap(); + // host 默认 tracing 是 ERROR 以上;我们的实现至少返回 5(CRITICAL)或更宽 + // 只要不 panic 且能拿到值就算 OK + let _ = lvl; + 0 +} + +// ─── 8: current_time > 0 ─── +fn test_current_time() -> u32 { + let now = hostcalls::get_current_time().unwrap(); + if now < std::time::UNIX_EPOCH { return 1; } + 0 +} + +// ─── 9: continue_stream(HTTP_REQUEST) → Ok ─── +fn test_continue_stream_http_request() -> u32 { + let s = unsafe { proxy_continue_stream(STREAM_HTTP_REQUEST) }; + if s != STATUS_OK { return s; } + 0 +} + +// ─── 10: close_stream(DOWNSTREAM) → Unimplemented(TCP 我们不支持) ─── +fn test_close_stream_tcp_unimplemented() -> u32 { + let s = unsafe { proxy_close_stream(STREAM_DOWNSTREAM) }; + if s != STATUS_UNIMPLEMENTED { return 100 + s; } + 0 +} + +// ─── 11: set_effective_context 对未知 ctx → BadArgument ─── +fn test_set_effective_context_bad_argument() -> u32 { + let s = unsafe { proxy_set_effective_context(987654) }; + if s != STATUS_BAD_ARGUMENT { return 100 + s; } + 0 +} + +// ─── 12: gRPC host fn → Unimplemented ─── +fn test_grpc_unimplemented() -> u32 { + let mut tok: u32 = 0; + let s = unsafe { + proxy_grpc_call( + b"cluster".as_ptr(), 7, + b"svc".as_ptr(), 3, + b"m".as_ptr(), 1, + std::ptr::null(), 0, + std::ptr::null(), 0, + 1000, + &mut tok as *mut u32, + ) + }; + if s != STATUS_UNIMPLEMENTED { return 100 + s; } + 0 +} + +// ─── 13: foreign_function → NotFound(无注册表) ─── +fn test_foreign_function_not_found() -> u32 { + let mut data: *mut u8 = std::ptr::null_mut(); + let mut size: usize = 0; + let s = unsafe { + proxy_call_foreign_function( + b"some_fn".as_ptr(), 7, + b"args".as_ptr(), 4, + &mut data as *mut *mut u8, + &mut size as *mut usize, + ) + }; + if s != STATUS_NOT_FOUND { return 100 + s; } + 0 +} + +// ─── 14: get_http_request_header(":method") ─── +fn test_request_header_pseudo_method() -> u32 { + match hostcalls::get_map_value(MapType::HttpRequestHeaders, ":method").unwrap() { + Some(m) if m == "POST" => 0, + Some(_) => 1, + None => 2, + } +} + +// ─── 15: add / replace / remove header on HttpRequestHeaders ─── +fn test_add_replace_remove_header() -> u32 { + hostcalls::add_map_value(MapType::HttpRequestHeaders, "x-spec-add", "v1").unwrap(); + if hostcalls::get_map_value(MapType::HttpRequestHeaders, "x-spec-add").unwrap().as_deref() != Some("v1") { + return 1; + } + // SDK 用 set_map_value(map, key, Some("v2")) 触发 spec 的 replace 语义。 + hostcalls::set_map_value(MapType::HttpRequestHeaders, "x-spec-add", Some("v2")).unwrap(); + if hostcalls::get_map_value(MapType::HttpRequestHeaders, "x-spec-add").unwrap().as_deref() != Some("v2") { + return 2; + } + hostcalls::remove_map_value(MapType::HttpRequestHeaders, "x-spec-add").unwrap(); + if hostcalls::get_map_value(MapType::HttpRequestHeaders, "x-spec-add").unwrap().is_some() { + return 3; + } + 0 +} + +// ─── 16: get_buffer(PluginConfiguration) 返回配置字节 ─── +fn test_get_configuration_buffer() -> u32 { + // start=0, max_size=usize::MAX + match hostcalls::get_buffer(BufferType::PluginConfiguration, 0, usize::MAX).unwrap() { + Some(b) if b == b"spec-test-config".to_vec() => 0, + Some(_) => 1, + None => 2, + } +} + +// ─── 17: send_local_response(host 侧通过 contexts[ctx].local_response 验证) ─── +fn test_send_local_response() -> u32 { + hostcalls::send_http_response( + 418, + vec![("x-spec", "teapot")], + Some(b"local body"), + ).unwrap(); + 0 +} + +// ─── 18: proxy_done 在没有 awaiting_done 时 → NotFound ─── +fn test_done_without_pending() -> u32 { + let s = unsafe { proxy_done() }; + if s != STATUS_NOT_FOUND { return 100 + s; } + 0 +} + +// ─── 19: proxy_log 在各级别 ─── +fn test_log() -> u32 { + hostcalls::log(LogLevel::Trace, "spec trace").unwrap(); + hostcalls::log(LogLevel::Debug, "spec debug").unwrap(); + hostcalls::log(LogLevel::Info, "spec info").unwrap(); + hostcalls::log(LogLevel::Warn, "spec warn").unwrap(); + hostcalls::log(LogLevel::Error, "spec error").unwrap(); + 0 +} + +// ─── 20: set_tick_period 应 Ok ─── +fn test_tick_period() -> u32 { + hostcalls::set_tick_period(Duration::from_millis(123)).unwrap(); + 0 +} diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 16849abe..186fa532 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -33,6 +33,21 @@ pub mod plugins; pub use schemars; pub use spacegate_model; pub use spacegate_model::{plugin_meta, PluginAttributes, PluginConfig, PluginInstanceId, PluginInstanceMap, PluginInstanceName, PluginMetaData}; + +pub fn set_telemetry_field(req: &SgRequest, key: impl Into, value: impl ToString) -> Result<(), spacegate_kernel::observability::TelemetryError> { + if let Some(context) = req.extensions().get::() { + context.insert_checked(key, value)?; + } + Ok(()) +} + +pub fn set_plugin_telemetry_field(req: &SgRequest, namespace: &str, key: &str, value: impl ToString) -> Result<(), spacegate_kernel::observability::TelemetryError> { + if let Some(context) = req.extensions().get::() { + context.insert_namespaced(namespace, key, value)?; + } + Ok(()) +} + /// # Plugin Trait /// It's a easy way to define a plugin through this trait. /// You should give a unique [`code`](Plugin::CODE) for the plugin, diff --git a/crates/plugin/src/plugins/limit.rs b/crates/plugin/src/plugins/limit.rs index e5711011..c52c4422 100644 --- a/crates/plugin/src/plugins/limit.rs +++ b/crates/plugin/src/plugins/limit.rs @@ -126,7 +126,7 @@ impl Plugin for RateLimitPlugin { if result == EXCEEDED { let mut response = Response::::with_code_message(StatusCode::TOO_MANY_REQUESTS, "[SG.Filter.Limit] too many requests"); - response.extensions_mut().insert(self.report( ip)); + response.extensions_mut().insert(self.report(ip)); return Ok(response); } Ok(inner.call(req).await) diff --git a/crates/plugin/tests/test_telemetry.rs b/crates/plugin/tests/test_telemetry.rs new file mode 100644 index 00000000..1a44bdaa --- /dev/null +++ b/crates/plugin/tests/test_telemetry.rs @@ -0,0 +1,38 @@ +use spacegate_plugin::{set_plugin_telemetry_field, set_telemetry_field, SgBody}; + +fn request_with_telemetry() -> hyper::Request { + let mut req = hyper::Request::builder().body(SgBody::empty()).expect("request"); + req.extensions_mut().insert(spacegate_kernel::observability::TelemetryContext::default()); + req +} + +#[test] +fn set_telemetry_field_writes_checked_request_context() { + let req = request_with_telemetry(); + + set_telemetry_field(&req, "ai.asset_id", "deepseek-chat").expect("insert"); + set_telemetry_field(&req, "ai.total_tokens", 37).expect("insert"); + + let fields = req.extensions().get::().expect("telemetry context").snapshot(); + assert_eq!(fields.get("ai.asset_id").map(String::as_str), Some("deepseek-chat")); + assert_eq!(fields.get("ai.total_tokens").map(String::as_str), Some("37")); +} + +#[test] +fn set_plugin_telemetry_field_adds_namespace() { + let req = request_with_telemetry(); + + set_plugin_telemetry_field(&req, "mcp", "tool", "search").expect("insert"); + + let fields = req.extensions().get::().expect("telemetry context").snapshot(); + assert_eq!(fields.get("mcp.tool").map(String::as_str), Some("search")); +} + +#[test] +fn set_telemetry_field_rejects_unqualified_key() { + let req = request_with_telemetry(); + + let result = set_telemetry_field(&req, "total_tokens", 37); + + assert_eq!(result, Err(spacegate_kernel::observability::TelemetryError::MissingNamespace)); +} diff --git a/crates/shell/Cargo.toml b/crates/shell/Cargo.toml index 358d7f61..d76a67c4 100644 --- a/crates/shell/Cargo.toml +++ b/crates/shell/Cargo.toml @@ -52,8 +52,10 @@ plugin-set-version = ["spacegate-plugin/set-version"] plugin-east-west-traffic-white-list = [ "spacegate-plugin/east-west-traffic-white-list", ] +plugin-wasm = ["dep:spacegate-plugin-wasm"] [dependencies] +spacegate-plugin-wasm = { workspace = true, optional = true } spacegate-kernel = { workspace = true, features = ["reload"] } spacegate-plugin = { workspace = true, features = ["schema"] } spacegate-config = { workspace = true } @@ -62,6 +64,12 @@ spacegate-ext-axum = { workspace = true, optional = true } regex = { workspace = true } futures-util.workspace = true tracing.workspace = true +tracing-subscriber = { workspace = true, features = ["env-filter"] } +tracing-opentelemetry.workspace = true +opentelemetry.workspace = true +opentelemetry_sdk.workspace = true +opentelemetry-otlp.workspace = true +opentelemetry-appender-tracing.workspace = true tokio.workspace = true hyper.workspace = true rustls-pemfile.workspace = true @@ -70,7 +78,6 @@ tokio-util = { workspace = true, features = ["io"] } [dev-dependencies] reqwest = { workspace = true } -tracing-subscriber = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } testcontainers-modules = { workspace = true, features = ["redis"] } [package.metadata.docs.rs] diff --git a/crates/shell/src/config.rs b/crates/shell/src/config.rs index f72b8a8e..b92d39ce 100644 --- a/crates/shell/src/config.rs +++ b/crates/shell/src/config.rs @@ -29,6 +29,7 @@ where C: Retrieve + CreateListener + 'static, { let (init_config, listener) = config.create_listener().await?; + crate::observability::init(&init_config.observability); #[cfg(feature = "ext-axum")] let listener = { use crate::ext_features::axum::{shell_routers, App}; diff --git a/crates/shell/src/lib.rs b/crates/shell/src/lib.rs index b60f6442..36c25626 100644 --- a/crates/shell/src/lib.rs +++ b/crates/shell/src/lib.rs @@ -49,6 +49,8 @@ use tracing::{info, instrument}; pub mod config; /// http extensions pub mod extension; +/// OpenTelemetry initialization. +pub mod observability; /// Spacegate service creation pub mod server; @@ -118,6 +120,9 @@ where { info!("Spacegate Meta Info: {:?}", Meta::new()); info!("Starting gateway..."); + // 启用 `plugin-wasm` 时注册 `CODE = "wasm"`;注册放在 shell 而非 `spacegate-plugin`,避免与 `plugin-wasm` crate 循环依赖。 + #[cfg(feature = "plugin-wasm")] + spacegate_plugin_wasm::register(spacegate_plugin::PluginRepository::global()); config::startup_with_shutdown_signal(config, ctrl_c_cancel_token()).await } diff --git a/crates/shell/src/observability.rs b/crates/shell/src/observability.rs new file mode 100644 index 00000000..8ef285fd --- /dev/null +++ b/crates/shell/src/observability.rs @@ -0,0 +1,224 @@ +use std::sync::OnceLock; +use std::time::Duration; + +use opentelemetry::global; +use opentelemetry::trace::TracerProvider as _; +use opentelemetry_appender_tracing::layer::OpenTelemetryTracingBridge; +use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::{ + logs::SdkLoggerProvider, + metrics::{PeriodicReader, SdkMeterProvider}, + trace::{Sampler, SdkTracerProvider}, + Resource, +}; +use spacegate_config::model::{ObservabilityConfig, OtlpProtocol}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; + +static OTEL_GUARD: OnceLock = OnceLock::new(); + +#[derive(Debug)] +pub struct ObservabilityGuard { + tracer_provider: Option, + meter_provider: Option, + logger_provider: Option, +} + +impl ObservabilityGuard { + pub fn shutdown(&self) { + if let Some(provider) = &self.tracer_provider { + if let Err(err) = provider.shutdown() { + eprintln!("failed to shutdown otel tracer provider: {err}"); + } + } + if let Some(provider) = &self.meter_provider { + if let Err(err) = provider.shutdown() { + eprintln!("failed to shutdown otel meter provider: {err}"); + } + } + if let Some(provider) = &self.logger_provider { + if let Err(err) = provider.shutdown() { + eprintln!("failed to shutdown otel logger provider: {err}"); + } + } + } +} + +impl Drop for ObservabilityGuard { + fn drop(&mut self) { + self.shutdown(); + } +} + +pub fn init(config: &ObservabilityConfig) { + let _ = OTEL_GUARD.get_or_init(|| match build_guard(config) { + Ok(guard) => guard, + Err(err) => { + eprintln!("failed to initialize OpenTelemetry, falling back to stdout tracing: {err}"); + init_stdout_only(); + ObservabilityGuard { + tracer_provider: None, + meter_provider: None, + logger_provider: None, + } + } + }); +} + +fn build_guard(config: &ObservabilityConfig) -> Result> { + let env_filter = EnvFilter::from_default_env(); + let fmt_layer = tracing_subscriber::fmt::layer(); + if !config.enabled { + tracing_subscriber::registry().with(env_filter).with(fmt_layer).try_init()?; + return Ok(ObservabilityGuard { + tracer_provider: None, + meter_provider: None, + logger_provider: None, + }); + } + + let resource = Resource::builder().with_service_name(config.service_name.clone()).build(); + let mut guard = ObservabilityGuard { + tracer_provider: None, + meter_provider: None, + logger_provider: None, + }; + + let trace_layer = if config.traces.enabled { + match build_span_exporter(config) { + Ok(exporter) => { + let provider = SdkTracerProvider::builder() + .with_resource(resource.clone()) + .with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(config.traces.sample_ratio)))) + .with_batch_exporter(exporter) + .build(); + let tracer = provider.tracer("spacegate"); + global::set_tracer_provider(provider.clone()); + guard.tracer_provider = Some(provider); + Some(tracing_opentelemetry::layer().with_tracer(tracer).boxed()) + } + Err(err) => { + eprintln!("failed to initialize OpenTelemetry traces, disabling traces: {err}"); + None + } + } + } else { + None + }; + + if config.metrics.enabled { + match build_metric_exporter(config) { + Ok(exporter) => { + let reader = PeriodicReader::builder(exporter).with_interval(metric_export_interval(config)).build(); + let provider = SdkMeterProvider::builder().with_resource(resource.clone()).with_reader(reader).build(); + global::set_meter_provider(provider.clone()); + guard.meter_provider = Some(provider); + } + Err(err) => { + eprintln!("failed to initialize OpenTelemetry metrics, disabling metrics: {err}"); + } + } + } + + let log_layer = if config.logs.enabled { + match build_log_exporter(config) { + Ok(exporter) => { + let provider = SdkLoggerProvider::builder().with_resource(resource).with_batch_exporter(exporter).build(); + let level_filter = log_level_filter(config); + let layer = OpenTelemetryTracingBridge::new(&provider).with_filter(level_filter).boxed(); + guard.logger_provider = Some(provider); + Some(layer) + } + Err(err) => { + eprintln!("failed to initialize OpenTelemetry logs, disabling logs: {err}"); + None + } + } + } else { + None + }; + + tracing_subscriber::registry().with(env_filter).with(fmt_layer).with(trace_layer).with(log_layer).try_init()?; + Ok(guard) +} + +fn init_stdout_only() { + let _ = tracing_subscriber::fmt().with_env_filter(tracing_subscriber::EnvFilter::from_default_env()).try_init(); +} + +fn otlp_protocol(config: &ObservabilityConfig) -> opentelemetry_otlp::Protocol { + match config.protocol { + OtlpProtocol::Grpc => opentelemetry_otlp::Protocol::Grpc, + OtlpProtocol::Http => opentelemetry_otlp::Protocol::HttpBinary, + } +} + +fn metric_export_interval(config: &ObservabilityConfig) -> Duration { + Duration::from_millis(config.metrics.export_interval_ms) +} + +fn log_level_filter(config: &ObservabilityConfig) -> tracing_subscriber::filter::LevelFilter { + config.logs.level.parse::().unwrap_or(tracing_subscriber::filter::LevelFilter::INFO) +} + +fn build_span_exporter(config: &ObservabilityConfig) -> Result { + let timeout = Duration::from_secs(5); + match config.protocol { + OtlpProtocol::Grpc => opentelemetry_otlp::SpanExporter::builder().with_tonic().with_endpoint(config.otlp_endpoint.clone()).with_timeout(timeout).build(), + OtlpProtocol::Http => { + opentelemetry_otlp::SpanExporter::builder().with_http().with_endpoint(config.otlp_endpoint.clone()).with_protocol(otlp_protocol(config)).with_timeout(timeout).build() + } + } +} + +#[cfg(test)] +mod tests { + use spacegate_config::model::{LogConfig, MetricConfig}; + + use super::*; + + #[test] + fn metric_export_interval_uses_configured_millis() { + let config = ObservabilityConfig { + metrics: MetricConfig { + enabled: true, + export_interval_ms: 15_000, + }, + ..Default::default() + }; + + assert_eq!(metric_export_interval(&config), Duration::from_secs(15)); + } + + #[test] + fn invalid_log_level_falls_back_to_info() { + let config = ObservabilityConfig { + logs: LogConfig { + enabled: true, + level: "not-a-level".to_string(), + }, + ..Default::default() + }; + + assert_eq!(log_level_filter(&config), tracing_subscriber::filter::LevelFilter::INFO); + } +} + +fn build_metric_exporter(config: &ObservabilityConfig) -> Result { + let timeout = Duration::from_secs(5); + match config.protocol { + OtlpProtocol::Grpc => opentelemetry_otlp::MetricExporter::builder().with_tonic().with_endpoint(config.otlp_endpoint.clone()).with_timeout(timeout).build(), + OtlpProtocol::Http => { + opentelemetry_otlp::MetricExporter::builder().with_http().with_endpoint(config.otlp_endpoint.clone()).with_protocol(otlp_protocol(config)).with_timeout(timeout).build() + } + } +} + +fn build_log_exporter(config: &ObservabilityConfig) -> Result { + let timeout = Duration::from_secs(5); + match config.protocol { + OtlpProtocol::Grpc => opentelemetry_otlp::LogExporter::builder().with_tonic().with_endpoint(config.otlp_endpoint.clone()).with_timeout(timeout).build(), + OtlpProtocol::Http => { + opentelemetry_otlp::LogExporter::builder().with_http().with_endpoint(config.otlp_endpoint.clone()).with_protocol(otlp_protocol(config)).with_timeout(timeout).build() + } + } +} diff --git a/crates/shell/src/server.rs b/crates/shell/src/server.rs index 899e21ea..8628a4de 100644 --- a/crates/shell/src/server.rs +++ b/crates/shell/src/server.rs @@ -293,14 +293,14 @@ impl RunningSgGateway { tls_server_cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()]; tls_server_cfg.ignore_client_order = true; tls_server_cfg.enable_secret_extraction = true; - listen.add_service(service.clone().https(tls_server_cfg)) + listen.add_service(service.clone().https_with_gateway_name(tls_server_cfg, gateway_name.clone())) } else { error!("[SG.Server] Can not found a valid Tls private key"); } }; } } else { - listen.add_service(service.clone().http()); + listen.add_service(service.clone().http_with_gateway_name(gateway_name.clone())); } listens.push(listen) } diff --git a/deploy/README.md b/deploy/README.md new file mode 100644 index 00000000..cdb73ff5 --- /dev/null +++ b/deploy/README.md @@ -0,0 +1,504 @@ +# AI Gateway 队列限流 — 编译与部署指南 + +本文档说明如何编译 `ai-gateway-queue` Wasm 插件,并在 **本地开发 / Docker / Kubernetes** 等环境中部署,以及如何将 Wasm 发布为 **OCI 制品**。 + +相关文档: + +- 插件行为与请求头:[`plugins/wasm/ai-gateway-queue/README.md`](../plugins/wasm/ai-gateway-queue/README.md) +- **管理界面配置指南**:[`docs/ai-gateway-queue-admin-ui-guide.md`](../docs/ai-gateway-queue-admin-ui-guide.md) +- 测试用例规格:[`docs/ai-gateway-queue-test-spec.md`](../docs/ai-gateway-queue-test-spec.md) +- K8s manifest 目录:[`deploy/k8s/ai-gateway/`](k8s/ai-gateway/) + +--- + +## 1. 架构概览 + +```text +Client + → SpaceGate(ai-gateway-queue Wasm 插件) + → ai-gateway-service(限流 / 入队 / wait / worker / 回调) + → Redis 7+ + → 上游 LLM Service +``` + +| 组件 | 作用 | +|------|------| +| **ai-gateway-queue**(Wasm) | 解析 Policy / Tenant,调用后端限流,配额内转发上游,超额 429/202/wait | +| **ai-gateway-service** | 令牌桶、Redis Stream 队列、Worker、回调、指标 | +| **SpaceGate** | 加载 Wasm,路由到上游 | +| **Redis 7+** | 限流状态、队列、结果缓存 | + +三种策略(`X-RateLimit-Policy`)均 **先过令牌桶**;配额内三种策略都直通上游;超额时: + +- `abandon` → 429 +- `queue` → 202 + 回调/轮询 +- `wait` → 阻塞等待上游响应或 504 + +--- + +## 2. 编译 Wasm 插件 + +### 2.1 前置条件 + +- Rust 工具链(与 `spacegate` workspace 一致) +- 目标三元组 `wasm32-wasip1` + +```bash +rustup target add wasm32-wasip1 +``` + +### 2.2 Release 构建(部署用) + +在 **`spacegate` 仓库根目录**执行: + +```bash +cd spacegate + +cargo build --release \ + --target wasm32-wasip1 \ + --manifest-path plugins/wasm/Cargo.toml \ + -p spacegate_plugin_ai_gateway_queue +``` + +产物路径: + +```text +plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm +``` + +### 2.3 Debug 构建(开发调试用) + +```bash +cargo build \ + --target wasm32-wasip1 \ + --manifest-path plugins/wasm/Cargo.toml \ + -p spacegate_plugin_ai_gateway_queue +``` + +Debug 产物在 `plugins/wasm/target/wasm32-wasip1/debug/` 下,体积更大、未优化,**不要用于生产**。 + +### 2.4 校验产物 + +```bash +WASM=plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm +file "$WASM" # 应为 WebAssembly +ls -lh "$WASM" +shasum -a 256 "$WASM" +``` + +### 2.5 插件配置要点 + +Wasm 宿主侧需要完整 shell 配置(参考 [`.docker/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json`](../../.docker/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json)(工作区根目录)或 K8s `SgFilter`): + +| 字段 | 说明 | +|------|------| +| `url` | Wasm 来源:`file://`、`http(s)://` 或 `oci://` | +| `plugin_config.service_cluster` | 固定 cluster 名,如 `ai-gateway-service` | +| `clusters.ai-gateway-service` | 后端 base URL,如 `http://ai-gateway-service:18080` | +| `plugin_config.require_policy` | 是否强制 `X-RateLimit-Policy` | + +**注意:** 插件不要在 Gateway 与 HTTPRoute **重复挂载**,否则会执行两次限流(双倍扣 token)。 + +--- + +## 3. 编译 ai-gateway-service(后端) + +后端为普通 Rust 二进制,与 Wasm 分开构建。 + +### 3.1 本地运行 + +```bash +cd spacegate + +cargo build --release -p ai-gateway-service + +REDIS_URL=redis://127.0.0.1/ \ +AI_UPSTREAM_BASE_URL=http://127.0.0.1:9000 \ +AI_REQUIRE_HTTPS_CALLBACK=false \ +./target/release/ai-gateway-service \ + --port 18080 \ + --host 127.0.0.1 +``` + +配置模板:[`binary/ai-gateway-service/config/ai-gateway-service.example.toml`](../binary/ai-gateway-service/config/ai-gateway-service.example.toml) + +### 3.2 构建 Linux 容器镜像(K8s / Docker) + +```bash +cd spacegate/deploy/k8s/ai-gateway +./build-images.sh +# 默认镜像名 ai-gateway/service:dev +``` + +Dockerfile:[`deploy/k8s/ai-gateway/docker/Dockerfile.ai-gateway-service`](k8s/ai-gateway/docker/Dockerfile.ai-gateway-service) + +导入本地集群(示例 k3d): + +```bash +k3d image import ai-gateway/service:dev -c +``` + +--- + +## 4. 本地开发部署(Cargo + 文件配置) + +适合改代码、跑集成测试。 + +### 4.1 依赖服务 + +| 服务 | 端口 | 说明 | +|------|------|------| +| Redis 7+ | 6379 | 必须 | +| Mock 上游 | 9000 | 任意 HTTP 服务 | +| ai-gateway-service | 18080 | 队列后端 | +| SpaceGate | 9993 | 加载 Wasm + 路由 | + +### 4.2 SpaceGate 文件配置 + +参考 [`resource/ai-gateway-demo/`](../resource/ai-gateway-demo/) 模板,复制到 **工作区根目录** `.docker/ai-gateway-demo/`(与 `spacegate` 仓库同级,非 spacegate 子目录): + +```text +ai-gateway-dev/.docker/ai-gateway-demo/ + config.json + gateway/ai-demo/ + plugin/wasm.ai-gateway-queue.json # 仅 JSON + plugins/spacegate_plugin_ai_gateway_queue.wasm +``` + +`resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json` 内含本机绝对路径,**不要直接用于 Docker**;请使用 `.docker` 下已改为 `file:///etc/spacegate/plugins/...` 的版本。 + +`wasm.ai-gateway-queue.json` 中 `clusters` 示例: + +```json +"clusters": { + "ai-gateway-service": "http://127.0.0.1:18080" +} +``` + +### 4.3 启动 SpaceGate(示例) + +```bash +cd spacegate +cargo run -p spacegate -- -c file:resource/ai-gateway-demo +# Docker 使用工作区根目录 .docker/ai-gateway-demo(挂载到 /etc/spacegate) +``` + +**避免** 本地 debug SpaceGate 与 Docker 容器 **同时占用 `:9993`**。 + +### 4.4 冒烟测试 + +```bash +# 经网关(插件生效) +curl -i http://127.0.0.1:9993/v1/chat/completions \ + -H 'X-RateLimit-Policy: abandon' \ + -H 'X-Tenant-Id: demo' \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"hello"}' + +# 直连后端 +curl http://127.0.0.1:18080/healthz +``` + +### 4.5 自动化测试 + +```bash +cd spacegate + +# 单元测试 +cargo test -p ai-gateway-service --lib + +# 集成测试(需 Redis) +./binary/ai-gateway-service/scripts/run-integration-tests.sh + +# Wasm 策略逻辑 +./binary/ai-gateway-service/scripts/run-wasm-policy-tests.sh +``` + +--- + +## 5. Docker Compose 部署 + +> 若工作区根目录的 `docker-compose.yml` 已删除,可从 Git 历史恢复,或参照本节手工起容器。 + +典型栈: + +| 容器 | 端口 | 镜像 | +|------|------|------| +| ai-gateway-redis | 6379 | redis:7 | +| ai-gateway-service | 18080 | ai-gateway/service:dev | +| ai-gateway-spacegate | 9993 | spacegate + Wasm 挂载 | +| ai-gateway-web | 9080 | 管理前端 | +| ai-gateway-mock-upstream | 9000 | mock LLM | + +要点: + +- 配置目录挂载:**工作区根目录** `.docker/ai-gateway-demo/` → 容器内 `/etc/spacegate` +- **admin-server 卷须可写**(勿 `:ro`),否则管理界面保存插件报 `Read-only file system` +- Wasm 放在 `plugin/`(JSON)与 `plugins/`(`.wasm` 二进制),**勿**把 `.wasm` 放进 `plugin/`(会被当 JSON 解析导致 SpaceGate 启动失败) +- macOS 上 **不能** `docker cp` 本机编译的 Mach-O 二进制进 Linux 容器,需在 Linux 环境构建镜像 + +管理界面 `:9080` 依赖 admin-server 能读到 `/etc/spacegate` 配置;若报 `No such file or directory`,检查 volume 挂载是否存在。 + +--- + +## 6. Kubernetes 部署 + +Manifest 位于 [`deploy/k8s/ai-gateway/`](k8s/ai-gateway/)。 + +### 6.1 前置:安装 SpaceGate 基础组件 + +```bash +# Gateway API CRD(见 docs/k8s/installation.md) +kubectl apply -f https://github.com/kubernetes-sigs/gateway-api/releases/download/v0.6.2/standard-install.yaml + +kubectl apply -f resource/kube-manifests/namespace.yaml +kubectl apply -f resource/kube-manifests/gatewayclass.yaml +kubectl apply -f resource/kube-manifests/spacegate-gateway.yaml +kubectl apply -f resource/kube-manifests/higress-wasmplugin-crd.yaml # 若使用 WasmPlugin +``` + +SpaceGate DaemonSet 使用 `CONFIG=k8s:spacegate`,监听同 namespace 下的 Gateway / HTTPRoute / SgFilter / WasmPlugin。 + +### 6.2 一键部署 AI Gateway 栈 + +```bash +# 1. 构建并导入 ai-gateway-service 镜像 +cd deploy/k8s/ai-gateway +./build-images.sh +k3d image import ai-gateway/service:dev -c # 按需 + +# 2. 编译 Wasm + apply +./apply.sh + +# 3. 验证 +./verify.sh +``` + +`apply.sh` 会: + +1. 编译 `spacegate_plugin_ai_gateway_queue.wasm` +2. 写入 `files/` 供 Kustomize 生成 ConfigMap +3. `kubectl apply -k .` 部署 Redis、mock-upstream、wasm-server、ai-gateway-service、Gateway、HTTPRoute、SgFilter + +### 6.3 资源说明 + +| 资源 | 说明 | +|------|------| +| `ai-gateway-redis` | Redis 7 | +| `ai-gateway-service` | 队列/限流后端 Service `:18080` | +| `ai-gateway-wasm` | Nginx 通过 HTTP 分发 `.wasm`(免改 SpaceGate DaemonSet) | +| `ai-gateway` Gateway | 监听 `:9993` | +| `ai-api` HTTPRoute | `/v1/*` → mock-upstream | +| `SgFilter ai-gateway-queue` | Wasm 插件 + `clusters` 映射(**推荐**) | + +### 6.4 Wasm 插件在 K8s 上的两种挂载方式 + +#### 方式 A:SgFilter(推荐) + +完整 shell spec 含 `clusters`,见 [`sgfilter-ai-gateway-queue.yaml`](k8s/ai-gateway/sgfilter-ai-gateway-queue.yaml): + +```yaml +config: + url: http://ai-gateway-wasm/spacegate_plugin_ai_gateway_queue.wasm + clusters: + ai-gateway-service: http://ai-gateway-service:18080 + plugin_config: + service_cluster: ai-gateway-service + require_policy: true + # ... +``` + +#### 方式 B:Higress WasmPlugin + +[`wasmplugin-ai-gateway-queue.yaml`](k8s/ai-gateway/wasmplugin-ai-gateway-queue.yaml) 中 `defaultConfig` **不会**自动写入顶层 `clusters`,生产环境需配合 SgFilter 或扩展 CRD 转换逻辑。 + +私有 OCI 仓库需配置 `imagePullSecret`。 + +### 6.5 网关入口测试 + +SpaceGate 使用 `hostNetwork` 时,在节点上访问: + +```bash +curl -i http://:9993/v1/chat/completions \ + -H 'X-RateLimit-Policy: abandon' \ + -H 'X-Tenant-Id: demo' \ + -d '{"prompt":"hi"}' +``` + +### 6.6 生产替换清单 + +| 开发默认 | 生产建议 | +|----------|----------| +| mock-upstream | 真实 LLM Service | +| `AI_REQUIRE_HTTPS_CALLBACK=false` | `true`,回调 URL 必须 HTTPS | +| HTTP Wasm 分发 | OCI 制品 + `oci://` URL | +| 单副本 Redis | 托管 Redis / Sentinel / Cluster | +| 无对象存储 | 配置 S3/MinIO(大 body offload) | + +--- + +## 7. 制作 OCI 制品 + +SpaceGate 支持从 OCI 仓库拉取 Wasm,URL 形式: + +```text +oci:///: +docker://... # 等价 +image://... # 等价 +oci+http://... # 本地非 TLS registry +``` + +接受的 layer 媒体类型: + +- `application/vnd.module.wasm.content.layer.v1+wasm`(推荐) +- `application/vnd.wasm.content.layer.v1+wasm` +- `application/wasm` + +### 7.1 安装 ORAS + +```bash +brew install oras +# 或从 https://github.com/oras-project/oras/releases 下载 +``` + +### 7.2 编译并计算 digest + +```bash +cd spacegate + +cargo build --release \ + --target wasm32-wasip1 \ + --manifest-path plugins/wasm/Cargo.toml \ + -p spacegate_plugin_ai_gateway_queue + +WASM=plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm +shasum -a 256 "$WASM" +``` + +### 7.3 推送到仓库 + +```bash +# 登录(按仓库类型选择) +oras login ghcr.io -u YOUR_USER +# oras login registry.cn-hangzhou.aliyuncs.com +# oras login your-harbor.example.com + +REGISTRY=ghcr.io/your-org +TAG=v1.0.0 + +oras push "${REGISTRY}/ai-gateway-queue:${TAG}" \ + --artifact-type application/vnd.module.wasm.content.layer.v1+wasm \ + "${WASM}:application/wasm" +``` + +推送成功后配置: + +```yaml +url: oci://ghcr.io/your-org/ai-gateway-queue:v1.0.0 +sha256: sha256:<上一步 shasum 输出> # 可选,建议生产开启 +``` + +在 SgFilter / WasmPlugin 中替换 `url` 即可;私有仓库配合 `imagePullSecret`。 + +### 7.4 本地 Registry 测试 + +```bash +docker run -d -p 5000:5000 --name registry registry:2 + +oras push localhost:5000/ai-gateway-queue:v1 \ + --artifact-type application/vnd.module.wasm.content.layer.v1+wasm \ + "${WASM}:application/wasm" +``` + +SpaceGate 配置(本地/insecure): + +```text +oci+http://localhost:5000/ai-gateway-queue:v1 +``` + +### 7.5 OCI 注意事项 + +| 项 | 说明 | +|----|------| +| Docker Hub | 通常 **不支持** Wasm OCI Artifact,请用 GHCR / Harbor / ACR / ECR 等 | +| 与容器镜像区别 | OCI Artifact 是单层 Wasm 文件,不是 `docker build` 的应用镜像 | +| ai-gateway-service 镜像 | 仍用 [`build-images.sh`](k8s/ai-gateway/build-images.sh) 单独构建 | +| 版本更新 | 改 tag 重新 push;或在配置中更新 `sha256` / `module_cache_key` 触发重新拉取 | + +### 7.6 一键推送脚本(可选) + +```bash +#!/usr/bin/env bash +set -euo pipefail +REGISTRY="${REGISTRY:?set REGISTRY e.g. ghcr.io/your-org}" +TAG="${TAG:-v1.0.0}" +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +WASM="$ROOT/plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm" + +cd "$ROOT" +cargo build --release --target wasm32-wasip1 \ + --manifest-path plugins/wasm/Cargo.toml \ + -p spacegate_plugin_ai_gateway_queue + +DIGEST=$(shasum -a 256 "$WASM" | awk '{print $1}') + +oras push "${REGISTRY}/ai-gateway-queue:${TAG}" \ + --artifact-type application/vnd.module.wasm.content.layer.v1+wasm \ + "${WASM}:application/wasm" + +echo "url: oci://${REGISTRY}/ai-gateway-queue:${TAG}" +echo "sha256: sha256:${DIGEST}" +``` + +保存为 [`deploy/push-wasm-oci.sh`](push-wasm-oci.sh)(脚本内 `ROOT` 指向 `spacegate` 仓库根目录)后: + +```bash +REGISTRY=ghcr.io/your-org TAG=v1.0.0 ./deploy/push-wasm-oci.sh +``` + +--- + +## 8. 各环境对照表 + +| 环境 | Wasm 分发 | 后端地址配置 | 入口 | +|------|-----------|--------------|------| +| 本地 Cargo | `file://.../plugins/*.wasm` | `127.0.0.1:18080` | `:9993` SpaceGate | +| Docker | volume 挂载 `plugins/` | `http://ai-gateway-service:18080` | `:9993` / `:9080` 管理端 | +| K8s(HTTP) | `http://ai-gateway-wasm/...` | `http://ai-gateway-service:18080` | Gateway `:9993` | +| K8s / 生产(OCI) | `oci://registry/...:tag` | K8s Service DNS | Gateway `:9993` | + +--- + +## 9. 常见问题 + +**Q: 第一次请求就 429?** +A: 检查插件是否在 Gateway 与 Route **重复挂载**;或测试租户 burst 过小。Admin 设置:`PUT /v1/admin/tenant-rate-limits`。 + +**Q: `:9080` 报 `No such file or directory`?** +A: admin-server 读不到 `/etc/spacegate` 配置,恢复 **工作区根目录** `.docker/ai-gateway-demo` 挂载。 + +**Q: SpaceGate 启动报 JSON parse error?** +A: `plugin/` 目录下有 `.wasm` 文件,应移到 `plugins/` 子目录。 + +**Q: macOS 二进制拷进 Linux 容器失败?** +A: 在 Linux 环境 `docker build` 或使用已构建的 `ai-gateway/service:dev` 镜像。 + +**Q: WasmPlugin 无法连 ai-gateway-service?** +A: Higress WasmPlugin 的 `defaultConfig` 不含 `clusters`,请用 **SgFilter** 或改用 OCI + 完整 spec。 + +--- + +## 10. 目录索引 + +```text +spacegate/ +├── plugins/wasm/ai-gateway-queue/ # Wasm 插件源码 +├── binary/ai-gateway-service/ # 队列/限流后端 +├── resource/ai-gateway-demo/ # 文件模式配置模板 +├── deploy/ +│ ├── README.md # 本文档 +│ └── k8s/ai-gateway/ # K8s manifest + apply.sh +└── docs/ + ├── ai-gateway-queue-test-spec.md # 测试用例 + └── ai-gateway-queue-design-gap-fixlist.md +``` diff --git a/deploy/k8s/ai-gateway/admin-ui.yaml b/deploy/k8s/ai-gateway/admin-ui.yaml new file mode 100644 index 00000000..07e3cc5a --- /dev/null +++ b/deploy/k8s/ai-gateway/admin-ui.yaml @@ -0,0 +1,165 @@ +# SpaceGate Admin UI(K8s 模式) +# admin-server 读写 K8s 中的 Gateway / HTTPRoute / SgFilter 等 +apiVersion: v1 +kind: ServiceAccount +metadata: + name: spacegate-admin + namespace: spacegate + labels: + app.kubernetes.io/part-of: ai-gateway +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: spacegate-admin + labels: + app.kubernetes.io/part-of: ai-gateway +rules: + - apiGroups: [""] + resources: [services, pods, secrets, configmaps] + verbs: [get, list, watch, create, update, patch, delete] + - apiGroups: [apps] + resources: [daemonsets] + verbs: [get, list, watch] + - apiGroups: [gateway.networking.k8s.io] + resources: [gatewayclasses, gateways, httproutes, httproutes/status, gateways/status, gatewayclasses/status] + verbs: [get, list, create, update, watch, delete] + - apiGroups: [spacegate.idealworld.group] + resources: [sgfilters, httpspaceroutes, httpspaceroutes/status] + verbs: [get, create, update, patch, list, watch, delete] + - apiGroups: [extensions.higress.io] + resources: [wasmplugins, wasmplugins/status] + verbs: [get, create, update, patch, list, watch, delete] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: spacegate-admin + labels: + app.kubernetes.io/part-of: ai-gateway +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: spacegate-admin +subjects: + - kind: ServiceAccount + name: spacegate-admin + namespace: spacegate +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: spacegate-admin + namespace: spacegate + labels: + app.kubernetes.io/name: spacegate-admin + app.kubernetes.io/part-of: ai-gateway +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: spacegate-admin + template: + metadata: + labels: + app.kubernetes.io/name: spacegate-admin + app.kubernetes.io/part-of: ai-gateway + spec: + serviceAccountName: spacegate-admin + containers: + - name: admin-server + image: ai-gateway/admin-server:dev + imagePullPolicy: IfNotPresent + args: + - -c + - k8s:spacegate + - -p + - "19992" + - -H + - 0.0.0.0 + env: + - name: CONFIG + value: k8s:spacegate + - name: RUST_LOG + value: info + ports: + - containerPort: 19992 + name: http + readinessProbe: + tcpSocket: + port: http + initialDelaySeconds: 3 + periodSeconds: 5 + resources: + requests: + cpu: 50m + memory: 64Mi +--- +apiVersion: v1 +kind: Service +metadata: + name: spacegate-admin + namespace: spacegate + labels: + app.kubernetes.io/name: spacegate-admin +spec: + selector: + app.kubernetes.io/name: spacegate-admin + ports: + - name: http + port: 19992 + targetPort: http +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ai-gateway-web + namespace: spacegate + labels: + app.kubernetes.io/name: ai-gateway-web + app.kubernetes.io/part-of: ai-gateway +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: ai-gateway-web + template: + metadata: + labels: + app.kubernetes.io/name: ai-gateway-web + app.kubernetes.io/part-of: ai-gateway + spec: + containers: + - name: web + # 与 build-images.sh 输出一致;本地构建后 kubectl set image 更新 + image: ai-gateway/web:k8s-spa + imagePullPolicy: Never + ports: + - containerPort: 9080 + name: http + readinessProbe: + httpGet: + path: / + port: http + initialDelaySeconds: 2 + periodSeconds: 5 + resources: + requests: + cpu: 10m + memory: 32Mi +--- +apiVersion: v1 +kind: Service +metadata: + name: ai-gateway-web + namespace: spacegate + labels: + app.kubernetes.io/name: ai-gateway-web +spec: + type: LoadBalancer + selector: + app.kubernetes.io/name: ai-gateway-web + ports: + - name: http + port: 9080 + targetPort: http diff --git a/deploy/k8s/ai-gateway/ai-gateway-service.yaml b/deploy/k8s/ai-gateway/ai-gateway-service.yaml new file mode 100644 index 00000000..7efa8e98 --- /dev/null +++ b/deploy/k8s/ai-gateway/ai-gateway-service.yaml @@ -0,0 +1,71 @@ +# ai-gateway-service:限流 / 入队 / Worker / 回调 +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ai-gateway-service + labels: + app.kubernetes.io/name: ai-gateway-service + app.kubernetes.io/part-of: ai-gateway +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: ai-gateway-service + template: + metadata: + labels: + app.kubernetes.io/name: ai-gateway-service + spec: + containers: + - name: ai-gateway-service + # 构建:见 deploy/k8s/ai-gateway/build-images.sh 或使用本地镜像 ai-gateway/service:dev + image: ai-gateway/service:dev + imagePullPolicy: IfNotPresent + ports: + - containerPort: 18080 + name: http + env: + - name: REDIS_URL + value: redis://ai-gateway-redis:6379/ + - name: AI_UPSTREAM_BASE_URL + value: http://ai-gateway-mock-upstream:9000 + - name: AI_REQUIRE_HTTPS_CALLBACK + value: "false" + - name: AI_GATEWAY_SERVICE_HOST + value: 0.0.0.0 + - name: AI_GATEWAY_SERVICE_PORT + value: "18080" + - name: RUST_LOG + value: info + readinessProbe: + httpGet: + path: /healthz + port: http + initialDelaySeconds: 3 + periodSeconds: 5 + livenessProbe: + httpGet: + path: /healthz + port: http + initialDelaySeconds: 10 + periodSeconds: 10 + resources: + requests: + cpu: 100m + memory: 64Mi + limits: + memory: 512Mi +--- +apiVersion: v1 +kind: Service +metadata: + name: ai-gateway-service + labels: + app.kubernetes.io/name: ai-gateway-service +spec: + selector: + app.kubernetes.io/name: ai-gateway-service + ports: + - name: http + port: 18080 + targetPort: http diff --git a/deploy/k8s/ai-gateway/apply-infra.sh b/deploy/k8s/ai-gateway/apply-infra.sh new file mode 100755 index 00000000..a418a0ae --- /dev/null +++ b/deploy/k8s/ai-gateway/apply-infra.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash +# 编译 Wasm + 部署 AI Gateway K8s 基础设施(不含默认 HTTPRoute ai-api / SgFilter) +set -euo pipefail +DIR="$(cd "$(dirname "$0")" && pwd)" +ROOT="$(cd "$DIR/../../.." && pwd)" +KUSTOMIZE_FILE="$DIR/kustomization-infra.yaml" +WASM_SRC="$ROOT/plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm" +WASM_DST="$DIR/files/spacegate_plugin_ai_gateway_queue.wasm" + +echo "==> 检查 SpaceGate 前置(namespace / GatewayClass / DaemonSet)" +if ! kubectl get namespace spacegate >/dev/null 2>&1; then + echo "ERROR: namespace 'spacegate' 不存在。请先执行:" >&2 + echo " ./scripts/deploy.sh k8s install-prereq" >&2 + echo " 或 kubectl apply -f $ROOT/resource/kube-manifests/" >&2 + exit 1 +fi + +echo "==> 移除默认 Demo 路由(若存在)" +kubectl delete -f "$DIR/httproute-ai.yaml" -n spacegate --ignore-not-found +kubectl delete -f "$DIR/sgfilter-ai-gateway-queue.yaml" -n spacegate --ignore-not-found + +echo "==> 编译 ai-gateway-queue Wasm" +cd "$ROOT" +rustup target add wasm32-wasip1 2>/dev/null || true +cargo build --release --target wasm32-wasip1 \ + --manifest-path plugins/wasm/Cargo.toml \ + -p spacegate_plugin_ai_gateway_queue + +mkdir -p "$DIR/files" +cp "$WASM_SRC" "$WASM_DST" + +echo "==> 应用 Kustomize(infra-only,无 ai-api HTTPRoute)" +KUST_BACKUP="$DIR/kustomization.yaml.full.bak" +cp "$DIR/kustomization.yaml" "$KUST_BACKUP" +cp "$DIR/kustomization-infra.yaml" "$DIR/kustomization.yaml" +kubectl apply -k "$DIR" +mv "$KUST_BACKUP" "$DIR/kustomization.yaml" + +echo "==> 确保 SpaceGate DaemonSet 使用 K8s 模式本地镜像" +SG_IMAGE="${SPACEGATE_K8S_IMAGE:-ai-gateway/spacegate:k8s}" +if kubectl get daemonset spacegate -n spacegate >/dev/null 2>&1; then + kubectl set image daemonset/spacegate spacegate="$SG_IMAGE" -n spacegate + kubectl rollout status daemonset/spacegate -n spacegate --timeout=180s +fi + +echo "==> 等待 AI Gateway Pod Ready" +kubectl wait --for=condition=ready pod \ + -l 'app.kubernetes.io/name in (ai-gateway-redis,ai-gateway-service,ai-gateway-wasm,ai-gateway-mock-upstream)' \ + -n spacegate \ + --timeout=180s + +echo "" +echo "部署完成(无默认 ai-api 路由)。" +echo " 验证: $DIR/verify-infra.sh" +echo "" +echo "后续:在管理界面或 kubectl 自行创建 HTTPRoute,并挂载 SgFilter / Wasm 插件。" +echo " Gateway 入口: ai-gateway(:9993,SpaceGate hostNetwork)" diff --git a/deploy/k8s/ai-gateway/apply.sh b/deploy/k8s/ai-gateway/apply.sh new file mode 100755 index 00000000..78d04ee7 --- /dev/null +++ b/deploy/k8s/ai-gateway/apply.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# 编译 Wasm + 打包 ConfigMap + 部署 AI Gateway K8s 栈 +set -euo pipefail +DIR="$(cd "$(dirname "$0")" && pwd)" +ROOT="$(cd "$DIR/../../.." && pwd)" +WASM_SRC="$ROOT/plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm" +WASM_DST="$DIR/files/spacegate_plugin_ai_gateway_queue.wasm" + +echo "==> 检查 SpaceGate 前置(namespace / GatewayClass / DaemonSet)" +if ! kubectl get namespace spacegate >/dev/null 2>&1; then + echo "ERROR: namespace 'spacegate' 不存在。请先安装 SpaceGate:" >&2 + echo " kubectl apply -f $ROOT/resource/kube-manifests/namespace.yaml" >&2 + echo " kubectl apply -f $ROOT/resource/kube-manifests/gatewayclass.yaml" >&2 + echo " kubectl apply -f $ROOT/resource/kube-manifests/spacegate-gateway.yaml" >&2 + exit 1 +fi + +echo "==> 编译 ai-gateway-queue Wasm" +cd "$ROOT" +rustup target add wasm32-wasip1 2>/dev/null || true +cargo build --release --target wasm32-wasip1 \ + --manifest-path plugins/wasm/Cargo.toml \ + -p spacegate_plugin_ai_gateway_queue + +mkdir -p "$DIR/files" +cp "$WASM_SRC" "$WASM_DST" + +echo "==> 应用 Kustomize" +kubectl apply -k "$DIR" + +echo "==> 等待 Pod Ready" +kubectl wait --for=condition=ready pod \ + -l app.kubernetes.io/part-of=ai-gateway \ + -n spacegate \ + --timeout=180s + +echo "" +echo "部署完成。验证:" +echo " $DIR/verify.sh" +echo "" +echo "网关入口(SpaceGate hostNetwork 监听 9993):" +echo " curl -i http://:9993/v1/chat/completions \\" +echo " -H 'X-RateLimit-Policy: abandon' -H 'X-Tenant-Id: demo' -d '{\"prompt\":\"hi\"}'" diff --git a/deploy/k8s/ai-gateway/build-images.sh b/deploy/k8s/ai-gateway/build-images.sh new file mode 100755 index 00000000..e933f87d --- /dev/null +++ b/deploy/k8s/ai-gateway/build-images.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +# 构建 K8s 所需镜像:ai-gateway-service + SpaceGate(k8s 模式) +set -euo pipefail +DIR="$(cd "$(dirname "$0")" && pwd)" +SG_ROOT="$(cd "$DIR/../../.." && pwd)" +# ai-gateway-dev 工作区根(spacegate 的父目录) +WORKSPACE_ROOT="$(cd "$SG_ROOT/.." && pwd)" + +SERVICE_IMAGE="${AI_GATEWAY_SERVICE_IMAGE:-ai-gateway/service:dev}" +SG_IMAGE="${SPACEGATE_K8S_IMAGE:-ai-gateway/spacegate:k8s}" + +echo "==> 构建 ai-gateway-service" +docker build -f "$DIR/docker/Dockerfile.ai-gateway-service" \ + --build-arg SPACEGATE_ROOT="$SG_ROOT" \ + -t "$SERVICE_IMAGE" \ + "$SG_ROOT" + +echo "==> 构建 SpaceGate(wasm + axum + k8s)" +docker build -f "$WORKSPACE_ROOT/docker/Dockerfile.spacegate-k8s" \ + -t "$SG_IMAGE" \ + "$SG_ROOT" + +WEB_IMAGE="${AI_GATEWAY_WEB_IMAGE:-ai-gateway/web:k8s-spa}" +echo "==> 构建管理 UI(spacegate-admin-fe SPA + nginx)" +if [[ ! -f "$WORKSPACE_ROOT/spacegate-admin-fe/dist/index.html" ]]; then + echo " 缺少 dist/index.html,尝试构建前端(需已 npm install)" + (cd "$WORKSPACE_ROOT/spacegate-admin-fe" && VITE_AI_GATEWAY_BASE_URL=/ai-gateway npm run build) || { + echo " 前端构建失败,请手动: cd spacegate-admin-fe && VITE_AI_GATEWAY_BASE_URL=/ai-gateway npm run build" + exit 1 + } +fi +docker build -f "$WORKSPACE_ROOT/docker/Dockerfile.web.k8s" \ + -t "$WEB_IMAGE" \ + "$WORKSPACE_ROOT" + +echo "Done." +echo " ai-gateway-service: $SERVICE_IMAGE" +echo " spacegate (k8s): $SG_IMAGE" +echo " admin web (k8s): $WEB_IMAGE" +echo "" +echo "更新管理 UI Deployment(本地 Docker Desktop 需 Never 拉取策略):" +echo " kubectl set image deployment/ai-gateway-web web=$WEB_IMAGE -n spacegate" +echo " kubectl rollout status deployment/ai-gateway-web -n spacegate" +echo "" +echo "更新 DaemonSet 镜像(若已安装 SpaceGate):" +echo " kubectl set image daemonset/spacegate spacegate=$SG_IMAGE -n spacegate" diff --git a/deploy/k8s/ai-gateway/docker/Dockerfile.ai-gateway-service b/deploy/k8s/ai-gateway/docker/Dockerfile.ai-gateway-service new file mode 100644 index 00000000..ec52d09e --- /dev/null +++ b/deploy/k8s/ai-gateway/docker/Dockerfile.ai-gateway-service @@ -0,0 +1,17 @@ +# ai-gateway-service K8s 镜像(多阶段:Rust 编译 + 最小运行层) +ARG RUST_IMAGE=rust:1-bookworm +ARG RUNTIME_IMAGE=debian:bookworm-slim + +FROM ${RUST_IMAGE} AS builder +ARG SPACEGATE_ROOT=/src +WORKDIR /src +COPY . . +RUN cargo build --release -p ai-gateway-service + +FROM ${RUNTIME_IMAGE} +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates tini \ + && rm -rf /var/lib/apt/lists/* +COPY --from=builder /src/target/release/ai-gateway-service /usr/local/bin/ai-gateway-service +EXPOSE 18080 +ENTRYPOINT ["/usr/bin/tini", "--"] +CMD ["/usr/local/bin/ai-gateway-service"] diff --git a/deploy/k8s/ai-gateway/files/.gitkeep b/deploy/k8s/ai-gateway/files/.gitkeep new file mode 100644 index 00000000..79d6aec9 --- /dev/null +++ b/deploy/k8s/ai-gateway/files/.gitkeep @@ -0,0 +1,2 @@ +# 占位:apply.sh 会把编译好的 wasm 复制到此目录 +# 勿手动删除本目录 diff --git a/deploy/k8s/ai-gateway/gateway-ai.yaml b/deploy/k8s/ai-gateway/gateway-ai.yaml new file mode 100644 index 00000000..08cd30d9 --- /dev/null +++ b/deploy/k8s/ai-gateway/gateway-ai.yaml @@ -0,0 +1,17 @@ +# AI 流量入口 Gateway(需已安装 GatewayClass: spacegate) +apiVersion: gateway.networking.k8s.io/v1beta1 +kind: Gateway +metadata: + name: ai-gateway + labels: + app.kubernetes.io/name: ai-gateway + app.kubernetes.io/part-of: ai-gateway +spec: + gatewayClassName: spacegate + listeners: + - name: http + port: 9993 + protocol: HTTP + allowedRoutes: + namespaces: + from: Same diff --git a/deploy/k8s/ai-gateway/httproute-ai.yaml b/deploy/k8s/ai-gateway/httproute-ai.yaml new file mode 100644 index 00000000..9fbf1f23 --- /dev/null +++ b/deploy/k8s/ai-gateway/httproute-ai.yaml @@ -0,0 +1,20 @@ +# /v1/* 路由到 mock 上游;Wasm 插件通过 SgFilter 挂载在本 Route 上(仅挂一次,避免双倍限流) +apiVersion: gateway.networking.k8s.io/v1beta1 +kind: HTTPRoute +metadata: + name: ai-api + labels: + app.kubernetes.io/name: ai-api + app.kubernetes.io/part-of: ai-gateway +spec: + parentRefs: + - name: ai-gateway + namespace: spacegate + rules: + - matches: + - path: + type: PathPrefix + value: /v1/ + backendRefs: + - name: ai-gateway-mock-upstream + port: 9000 diff --git a/deploy/k8s/ai-gateway/kustomization-infra.yaml b/deploy/k8s/ai-gateway/kustomization-infra.yaml new file mode 100644 index 00000000..fdbbcabc --- /dev/null +++ b/deploy/k8s/ai-gateway/kustomization-infra.yaml @@ -0,0 +1,22 @@ +# 基础设施栈:不含默认 HTTPRoute ai-api / SgFilter +# 用法:kubectl apply -k . --kustomize-file kustomization-infra.yaml(见 apply-infra.sh) +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +namespace: spacegate + +resources: + - redis.yaml + - mock-upstream.yaml + - wasm-server.yaml + - ai-gateway-service.yaml + - gateway-ai.yaml + - spacegate-rbac-cluster.yaml + - admin-ui.yaml + +configMapGenerator: + - name: ai-gateway-queue-wasm + files: + - files/spacegate_plugin_ai_gateway_queue.wasm + options: + disableNameSuffixHash: true diff --git a/deploy/k8s/ai-gateway/kustomization.yaml b/deploy/k8s/ai-gateway/kustomization.yaml new file mode 100644 index 00000000..a6624099 --- /dev/null +++ b/deploy/k8s/ai-gateway/kustomization.yaml @@ -0,0 +1,25 @@ +# AI Gateway 队列限流插件 — K8s 一键部署(Kustomize) +# 前置:spacegate 命名空间、GatewayClass、SpaceGate DaemonSet 已安装 +# 用法:./apply.sh +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +namespace: spacegate + +resources: + - redis.yaml + - mock-upstream.yaml + - wasm-server.yaml + - ai-gateway-service.yaml + - gateway-ai.yaml + - httproute-ai.yaml + - sgfilter-ai-gateway-queue.yaml + # 可选:Higress WasmPlugin(不含 clusters,生产建议用 SgFilter) + # - wasmplugin-ai-gateway-queue.yaml + +configMapGenerator: + - name: ai-gateway-queue-wasm + files: + - files/spacegate_plugin_ai_gateway_queue.wasm + options: + disableNameSuffixHash: true diff --git a/deploy/k8s/ai-gateway/mock-upstream.yaml b/deploy/k8s/ai-gateway/mock-upstream.yaml new file mode 100644 index 00000000..f974836f --- /dev/null +++ b/deploy/k8s/ai-gateway/mock-upstream.yaml @@ -0,0 +1,49 @@ +# 模拟上游 LLM(生产环境替换为真实模型 Service) +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ai-gateway-mock-upstream + labels: + app.kubernetes.io/name: ai-gateway-mock-upstream + app.kubernetes.io/part-of: ai-gateway +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: ai-gateway-mock-upstream + template: + metadata: + labels: + app.kubernetes.io/name: ai-gateway-mock-upstream + spec: + containers: + - name: echo + image: hashicorp/http-echo:1.0.0 + args: + - -text={"choices":[{"message":{"content":"ok"}}]} + - -listen=:9000 + ports: + - containerPort: 9000 + name: http + readinessProbe: + tcpSocket: + port: http + periodSeconds: 5 + resources: + requests: + cpu: 10m + memory: 16Mi +--- +apiVersion: v1 +kind: Service +metadata: + name: ai-gateway-mock-upstream + labels: + app.kubernetes.io/name: ai-gateway-mock-upstream +spec: + selector: + app.kubernetes.io/name: ai-gateway-mock-upstream + ports: + - name: http + port: 9000 + targetPort: http diff --git a/deploy/k8s/ai-gateway/redis.yaml b/deploy/k8s/ai-gateway/redis.yaml new file mode 100644 index 00000000..d0babe4d --- /dev/null +++ b/deploy/k8s/ai-gateway/redis.yaml @@ -0,0 +1,54 @@ +# Redis 7+(队列 / 限流 / 结果存储) +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ai-gateway-redis + labels: + app.kubernetes.io/name: ai-gateway-redis + app.kubernetes.io/part-of: ai-gateway +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: ai-gateway-redis + template: + metadata: + labels: + app.kubernetes.io/name: ai-gateway-redis + spec: + containers: + - name: redis + image: redis:7-alpine + ports: + - containerPort: 6379 + name: redis + readinessProbe: + exec: + command: ["redis-cli", "ping"] + initialDelaySeconds: 3 + periodSeconds: 5 + livenessProbe: + exec: + command: ["redis-cli", "ping"] + initialDelaySeconds: 10 + periodSeconds: 10 + resources: + requests: + cpu: 50m + memory: 64Mi + limits: + memory: 256Mi +--- +apiVersion: v1 +kind: Service +metadata: + name: ai-gateway-redis + labels: + app.kubernetes.io/name: ai-gateway-redis +spec: + selector: + app.kubernetes.io/name: ai-gateway-redis + ports: + - name: redis + port: 6379 + targetPort: redis diff --git a/deploy/k8s/ai-gateway/sgfilter-ai-gateway-queue.yaml b/deploy/k8s/ai-gateway/sgfilter-ai-gateway-queue.yaml new file mode 100644 index 00000000..0a934017 --- /dev/null +++ b/deploy/k8s/ai-gateway/sgfilter-ai-gateway-queue.yaml @@ -0,0 +1,48 @@ +# ai-gateway-queue Wasm 插件(SgFilter 含完整 clusters 映射,推荐 K8s 用法) +apiVersion: spacegate.idealworld.group/v1 +kind: SgFilter +metadata: + name: ai-gateway-queue + labels: + app.kubernetes.io/name: ai-gateway-queue + app.kubernetes.io/part-of: ai-gateway +spec: + targetRefs: + - kind: httproute + name: ai-api + namespace: spacegate + filters: + - code: wasm + name: ai-gateway-queue + enable: true + config: + # Harbor OCI 制品(本地 push 脚本见 open-source/harbor/push-ai-gateway-queue.sh) + url: oci+http://host.docker.internal:9081/ai-gateway/ai-gateway-queue:v1.0.0 + sha256: sha256:8e2b1d3271b7e7c44b01e96e1844e0a231896fbe21b23ba07a01f71a58b9e697 + oci_auth: + registry: host.docker.internal:9081 + username: admin + password: Harbor12345 + fail_strategy: fail_close + validate_on_create: false + plugin_name: ai-gateway-queue + plugin_root_id: ai-gateway-queue-root + plugin_vm_id: ai-gateway-queue-vm + vm_pool_size: 4 + wait_vm_pool_size: 4 + limits: + max_memory_pages: 64 + fuel_per_call: 20000000 + epoch_timeout_millis: 50 + max_body_bytes: 33554432 + max_pending_calls: 1 + plugin_config: + service_cluster: ai-gateway-service + service_authority: ai-gateway-service + rate_limit_path: /v1/ratelimit/check + enqueue_path: /v1/queue/enqueue + wait_path: /v1/queue/enqueue-and-wait + service_timeout_ms: 65000 + require_policy: true + clusters: + ai-gateway-service: http://ai-gateway-service:18080 diff --git a/deploy/k8s/ai-gateway/spacegate-rbac-cluster.yaml b/deploy/k8s/ai-gateway/spacegate-rbac-cluster.yaml new file mode 100644 index 00000000..5b753f06 --- /dev/null +++ b/deploy/k8s/ai-gateway/spacegate-rbac-cluster.yaml @@ -0,0 +1,28 @@ +# 补充 SpaceGate ServiceAccount 对 SgFilter / Gateway API 的集群级 list 权限 +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: spacegate-k8s-config +rules: + - apiGroups: ["gateway.networking.k8s.io"] + resources: ["gateways", "httproutes", "gateways/status", "httproutes/status"] + verbs: ["get", "list", "watch", "update"] + - apiGroups: ["spacegate.idealworld.group"] + resources: ["sgfilters", "httpspaceroutes"] + verbs: ["get", "list", "watch"] + - apiGroups: ["extensions.higress.io"] + resources: ["wasmplugins"] + verbs: ["get", "list", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: spacegate-k8s-config +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: spacegate-k8s-config +subjects: + - kind: ServiceAccount + name: spacegate + namespace: spacegate diff --git a/deploy/k8s/ai-gateway/verify-infra.sh b/deploy/k8s/ai-gateway/verify-infra.sh new file mode 100755 index 00000000..443f2ccf --- /dev/null +++ b/deploy/k8s/ai-gateway/verify-infra.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash +# 基础设施部署验证(不要求 HTTPRoute / 网关流量) +set -euo pipefail +NS=spacegate +pass=0 +fail=0 + +check() { + local name="$1" expect="$2" got="$3" + if [[ "$got" == "$expect" ]]; then + echo "✅ $name" + pass=$((pass + 1)) + else + echo "❌ $name 期望=$expect 实际=$got" + fail=$((fail + 1)) + fi +} + +echo "==> Pod 状态" +kubectl get pods -n "$NS" -l 'app.kubernetes.io/name in (ai-gateway-redis,ai-gateway-service,ai-gateway-wasm,ai-gateway-mock-upstream)' + +echo "==> 不应存在默认 HTTPRoute ai-api" +if kubectl get httproute ai-api -n "$NS" >/dev/null 2>&1; then + echo "❌ HTTPRoute ai-api 仍存在(应已删除)" + fail=$((fail + 1)) +else + echo "✅ 无 HTTPRoute ai-api" + pass=$((pass + 1)) +fi + +echo "==> 不应存在默认 SgFilter ai-gateway-queue" +if kubectl get sgfilter ai-gateway-queue -n "$NS" >/dev/null 2>&1; then + echo "❌ SgFilter ai-gateway-queue 仍存在" + fail=$((fail + 1)) +else + echo "✅ 无 SgFilter ai-gateway-queue" + pass=$((pass + 1)) +fi + +echo "==> Gateway ai-gateway 存在" +kubectl get gateway ai-gateway -n "$NS" >/dev/null && check "Gateway ai-gateway" "ok" "ok" || fail=$((fail + 1)) + +echo "==> SpaceGate DaemonSet 运行中" +if kubectl get pods -n "$NS" -l app=spacegate -o jsonpath='{.items[0].status.phase}' 2>/dev/null | grep -q Running; then + echo "✅ spacegate DaemonSet Running" + pass=$((pass + 1)) +else + echo "❌ spacegate DaemonSet 未 Running" + fail=$((fail + 1)) +fi + +echo "==> ai-gateway-service 健康(集群内 curl)" +if kubectl run curl-health-$RANDOM --rm -i --restart=Never -n "$NS" \ + --image=curlimages/curl:8.5.0 --quiet -- \ + curl -sf http://ai-gateway-service:18080/healthz >/dev/null 2>&1; then + echo "✅ ai-gateway-service /healthz" + pass=$((pass + 1)) +else + echo "❌ ai-gateway-service /healthz" + fail=$((fail + 1)) +fi + +echo "==> Wasm HTTP 分发" +if kubectl exec -n "$NS" deploy/ai-gateway-wasm -- wget -qO- http://127.0.0.1/spacegate_plugin_ai_gateway_queue.wasm >/dev/null 2>&1; then + echo "✅ ai-gateway-wasm 可下载 .wasm" + pass=$((pass + 1)) +else + echo "❌ ai-gateway-wasm" + fail=$((fail + 1)) +fi + +echo "=== $pass 通过, $fail 失败 ===" +exit "$fail" diff --git a/deploy/k8s/ai-gateway/verify.sh b/deploy/k8s/ai-gateway/verify.sh new file mode 100755 index 00000000..5256d512 --- /dev/null +++ b/deploy/k8s/ai-gateway/verify.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +# K8s 部署后冒烟验证 +set -euo pipefail +DIR="$(cd "$(dirname "$0")" && pwd)" +NS=spacegate +GW="http://127.0.0.1:9993/v1/chat/completions" +PF="" + +pass=0 +fail=0 +check() { + local name="$1" expect="$2" got="$3" + if [[ "$got" == "$expect" ]]; then + echo "✅ $name ($got)" + pass=$((pass + 1)) + else + echo "❌ $name 期望=$expect 实际=$got" + fail=$((fail + 1)) + fi +} + +echo "==> 后端 health" +curl -sf "http://127.0.0.1:18080/healthz" >/dev/null 2>&1 \ + && echo "✅ ai-gateway-service /healthz(需 port-forward 或 hostNetwork 可达)" \ + || kubectl exec -n "$NS" deploy/ai-gateway-service -- wget -qO- http://127.0.0.1:18080/healthz >/dev/null \ + && echo "✅ ai-gateway-service /healthz(集群内)" \ + || { echo "⚠️ 跳过直连 health(请 kubectl port-forward svc/ai-gateway-service 18080:18080)"; } + +T="k8s-verify-$(date +%s)" +curl -sf -X PUT "http://127.0.0.1:18080/v1/admin/tenant-rate-limits" \ + -H 'Content-Type: application/json' \ + -d "{\"tenant\":\"$T\",\"rps\":5,\"burst\":5}" >/dev/null 2>&1 \ + || { kubectl port-forward -n "$NS" svc/ai-gateway-service 18080:18080 >/tmp/pf-18080.log 2>&1 & PF=$!; sleep 2; } +curl -sf -X PUT "http://127.0.0.1:18080/v1/admin/tenant-rate-limits" \ + -H 'Content-Type: application/json' \ + -d "{\"tenant\":\"$T\",\"rps\":5,\"burst\":5}" >/dev/null || true + +echo "==> 网关插件 (tenant=$T)" +check "缺 Policy" 400 "$(curl -s -o /dev/null -w '%{http_code}' -X POST "$GW" -H "X-Tenant-Id: $T" -H 'Content-Type: application/json' -d '{}')" +check "abandon 配额内" 200 "$(curl -s -o /dev/null -w '%{http_code}' -X POST "$GW" -H 'X-RateLimit-Policy: abandon' -H "X-Tenant-Id: $T" -H 'Content-Type: application/json' -d '{"p":1}')" + +for i in $(seq 1 10); do + curl -s -o /dev/null -X POST "$GW" -H 'X-RateLimit-Policy: abandon' -H "X-Tenant-Id: $T" -H 'Content-Type: application/json' -d "{\"p\":$i}" || true +done +check "abandon 超额" 429 "$(curl -s -o /dev/null -w '%{http_code}' -X POST "$GW" -H 'X-RateLimit-Policy: abandon' -H "X-Tenant-Id: $T" -H 'Content-Type: application/json' -d '{"p":99}')" + +kill "$PF" 2>/dev/null || true +echo "=== $pass 通过, $fail 失败 ===" +exit "$fail" diff --git a/deploy/k8s/ai-gateway/wasm-server.yaml b/deploy/k8s/ai-gateway/wasm-server.yaml new file mode 100644 index 00000000..9fbd8f5b --- /dev/null +++ b/deploy/k8s/ai-gateway/wasm-server.yaml @@ -0,0 +1,57 @@ +# 集群内 HTTP 分发 Wasm 二进制(SpaceGate 通过 http:// 拉取,无需改 DaemonSet 挂载) +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ai-gateway-wasm + labels: + app.kubernetes.io/name: ai-gateway-wasm + app.kubernetes.io/part-of: ai-gateway +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: ai-gateway-wasm + template: + metadata: + labels: + app.kubernetes.io/name: ai-gateway-wasm + spec: + containers: + - name: nginx + image: nginx:1.27-alpine + ports: + - containerPort: 80 + name: http + volumeMounts: + - name: wasm + mountPath: /usr/share/nginx/html + readOnly: true + readinessProbe: + httpGet: + path: /spacegate_plugin_ai_gateway_queue.wasm + port: http + initialDelaySeconds: 2 + periodSeconds: 5 + resources: + requests: + cpu: 10m + memory: 32Mi + volumes: + - name: wasm + configMap: + name: ai-gateway-queue-wasm + defaultMode: 0444 +--- +apiVersion: v1 +kind: Service +metadata: + name: ai-gateway-wasm + labels: + app.kubernetes.io/name: ai-gateway-wasm +spec: + selector: + app.kubernetes.io/name: ai-gateway-wasm + ports: + - name: http + port: 80 + targetPort: http diff --git a/deploy/k8s/ai-gateway/wasmplugin-ai-gateway-queue.yaml b/deploy/k8s/ai-gateway/wasmplugin-ai-gateway-queue.yaml new file mode 100644 index 00000000..20e10f19 --- /dev/null +++ b/deploy/k8s/ai-gateway/wasmplugin-ai-gateway-queue.yaml @@ -0,0 +1,27 @@ +# 可选:Higress WasmPlugin 方式(defaultConfig 不含 clusters,需改用 SgFilter 或扩展 CRD 转换) +# 默认注释掉,见 kustomization.yaml +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-gateway-queue + labels: + app.kubernetes.io/name: ai-gateway-queue + app.kubernetes.io/part-of: ai-gateway +spec: + url: http://ai-gateway-wasm/spacegate_plugin_ai_gateway_queue.wasm + pluginName: ai-gateway-queue + failStrategy: FAIL_CLOSE + phase: AUTHZ + priority: 100 + defaultConfigDisable: true + matchRules: + - ingress: + - ai-api + config: + service_cluster: ai-gateway-service + service_authority: ai-gateway-service + rate_limit_path: /v1/ratelimit/check + enqueue_path: /v1/queue/enqueue + wait_path: /v1/queue/enqueue-and-wait + service_timeout_ms: 65000 + require_policy: true diff --git a/deploy/push-wasm-oci.sh b/deploy/push-wasm-oci.sh new file mode 100755 index 00000000..e17614b1 --- /dev/null +++ b/deploy/push-wasm-oci.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# 编译 ai-gateway-queue Wasm 并推送到 OCI 仓库(需 oras + 仓库登录) +set -euo pipefail + +REGISTRY="${REGISTRY:?请设置 REGISTRY,例如 ghcr.io/your-org}" +TAG="${TAG:-v1.0.0}" +IMAGE="${IMAGE:-ai-gateway-queue}" +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +WASM="$ROOT/plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm" + +if ! command -v oras >/dev/null 2>&1; then + echo "ERROR: 未找到 oras。安装: brew install oras" >&2 + exit 1 +fi + +echo "==> 编译 Wasm" +cd "$ROOT" +rustup target add wasm32-wasip1 2>/dev/null || true +cargo build --release \ + --target wasm32-wasip1 \ + --manifest-path plugins/wasm/Cargo.toml \ + -p spacegate_plugin_ai_gateway_queue + +DIGEST=$(shasum -a 256 "$WASM" | awk '{print $1}') +REF="${REGISTRY}/${IMAGE}:${TAG}" + +echo "==> 推送到 ${REF}" +oras push "$REF" \ + --artifact-type application/vnd.module.wasm.content.layer.v1+wasm \ + "${WASM}:application/wasm" + +echo "" +echo "推送完成。在 SgFilter / WasmPlugin 中使用:" +echo " url: oci://${REF}" +echo " sha256: sha256:${DIGEST}" diff --git a/docs/CODE_REVIEW.md b/docs/CODE_REVIEW.md new file mode 100644 index 00000000..7483d97a --- /dev/null +++ b/docs/CODE_REVIEW.md @@ -0,0 +1,222 @@ +# Spacegate 代码审核报告 + +> 文档生成日期:2026-05-12 +> 范围:整个 Spacegate workspace(`spacegate-kernel` / `spacegate-plugin` / `spacegate-model` / `spacegate-config` / `spacegate-shell` / extensions / binaries / SDK) + +本文档对仓库整体架构、各 crate 功能点、亮点及风险进行归纳,便于评审与后续修复跟踪。 + +--- + +## 一、项目总览 + +Spacegate 是基于 Rust 与 hyper 的 **库优先(library-first)** API 网关,强调云原生(Kubernetes Gateway API)与插件扩展。`Cargo.toml` 中 workspace 成员包含: + +- **二进制**:`binary/spacegate`、`binary/admin-server` +- **核心库**:`crates/kernel`、`crates/plugin`、`crates/model`、`crates/config`、`crates/shell` +- **扩展**:`crates/extension/axum`、`crates/extension/redis` +- **示例**:`examples/sayhello`、`examples/socks5-proxy`、`examples/mitm-proxy` 等 + +分层关系(自下而上): + +| 层 | crate | 职责 | +|----|--------|------| +| 数据模型 | `spacegate-model` | 网关/路由/插件/匹配规则 DTO,可选 ts-rs 导出 | +| 配置后端 | `spacegate-config` | 文件 / K8s / Redis / 内存:CRUD + 事件监听 | +| 核心运行时 | `spacegate-kernel` | TCP 监听、HTTPS、路由匹配、Backend、helper layer | +| 扩展库 | `spacegate-ext-axum`、`spacegate-ext-redis` | 全局 axum 服务、Redis 客户端仓库 | +| 插件系统 | `spacegate-plugin` | Plugin trait、动态库、内置插件、挂载点 | +| 集成入口 | `spacegate-shell` | 配置与内核映射、热更新、生命周期 | +| 二进制 | `spacegate`、`spacegate-admin-server` | 网关进程、管理后台 | +| SDK | `sdk/admin-client` | TypeScript,对接 admin-server | + +--- + +## 二、`spacegate-kernel` + +### 2.1 主要功能 + +1. **TCP 监听与协议嗅探**:`SgListen` 通过 `peek` 后由 `TcpService::sniff` 选择 HTTP/HTTPS/SOCKS5 等。 +2. **HTTP/1.1、HTTP/2、WebSocket、HTTPS**:`Http`/`Https` 实现 `TcpService`;`HyperServiceAdapter` 将请求转为 `SgBody` 并注入 `PeerAddr`、`EnterTime`、`Reflect`。 +3. **网关装配**:`http_gateway::Gateway`(builder)含网关级插件链、`HttpRoute` 表、`Reloader` 支持热更新路由。 +4. **主机名匹配**:`HostnameTree`(`match_hostname.rs`),支持 IPv4/IPv6、通配域名、优先级排序。 +5. **路由匹配**:`HttpRouteMatch` 支持 path(Exact/Prefix/Regex)、headers、query、method;多重 match 在单条规则内为 AND;`Vec` 层为 OR。 +6. **后端**:`http_backend_service`(`x-forwarded-for`、WebSocket 升级与双向拷贝)、`static_file_service`、全局 `ClientRepo` 与可插拔 `HttpClient`。 +7. **辅助层**:`TimeoutLayer`、`ReloadLayer`(`ShardedLock`)、`Balancer`(`IpHash` / 加权随机)、`MapRequest`/`MapFuture`、`RouterService`。 +8. **扩展与工具**:`Reflect`、`Defer`、`OriginalIpAddr`、`MatchedSgRouter`、`Authorization`、`SgBody`(dump 后可克隆)。 + +### 2.2 亮点 + +- `BoxLayer` 与 tower 组合良好,网关/路由/规则/后端多级挂载清晰。 +- `Reloader` + `OnceLock` 读多写少场景友好。 +- `HostnameTree` 设计文档与测试较完整。 + +### 2.3 风险与问题 + +- **TLS 客户端默认跳过服务端证书校验**:`ClientRepo::default` 使用 `get_rustls_config_dangerous`;`SgParameters::ignore_tls_verification` 在代码中未见实质接线。**生产环境风险高**,建议默认走系统根证书,仅显式配置才关闭校验。 +- **静态文件路径规范化**:`canonicalize` 失败时回退到 `dir`,可能削弱「必须在目录下」的语义,建议失败即 404。 +- **`HttpBackendService` 使用 `unwrap_unchecked`**:可改为显式处理以更清晰。 +- **`create_http_router` 中 hostname 索引**:当 `route.hostnames` 非空时,新建节点误落到 `"*"` 的逻辑需核对是否为 bug(应绑定具体 hostname)。 +- **`SgBody::clone` 未 dump 会 panic**:插件作者易踩坑,需在文档中突出。 +- **方法匹配注释与 `Vec` 的 OR 语义**:注释若写「仅当指定 method」易与实现不一致。 + +--- + +## 三、`spacegate-plugin` + +### 3.1 主要功能 + +1. **Plugin trait**:`CODE`、`call`、`create`,可选 `MONO`、`schema_opt`、元数据。 +2. **`PluginRepository`**:全局注册表、实例 CRUD、`register_dylib`、快照与挂载追踪。 +3. **`PluginInstance`**:`ArcSwap` 热替换函数、生命周期钩子、`DropTracer` 防悬挂挂载索引。 +4. **挂载点**:网关 / 路由 / 规则 / 后端四级 `MountPointIndex`。 +5. **内置插件(按 feature)**:如 `static-resource`、`limit`(Redis Lua)、`header-modifier`、`redirect`、`rewrite`、`set-version`、`set-scheme`、`maintenance`、`inject`、`east-west-traffic-white-list`,以及 Redis 系列(`redis-count`、`redis-limit`、`redis-time-range`、`redis-dynamic-route`)。 + +### 3.2 亮点 + +- `PluginError` 统一错误响应与 `X-Plugin-Error` 头。 +- 部分 Redis 插件含 testcontainers 集成测试。 + +### 3.3 风险与问题 + +- **`redirect` 插件**:解析 URL 后未真正返回 3xx 或未改写请求,接近 no-op,需补全实现。 +- **遗留/禁用模块**:`breaker.rs` 空文件;`decompression`/`status`/`retry` 等与旧 API 耦合且未在 `register_prelude` 启用,建议清理或重写。 +- **`SystemTime::now().duration_since(UNIX_EPOCH).expect(...)`**:时间异常时可能 panic。 +- **仓库锁**:`RwLock` + 多层 `expect`;钩子里若再次操作仓库可能死锁,需在文档约束。 +- **`reflect` 扩展缺失会 panic**:非标准入口构造的请求需注意。 +- **`FromBackend::unsafe new` 用法**:可与实际调用路径再核对是否必须 unsafe。 + +--- + +## 四、`spacegate-model` + +### 4.1 主要功能 + +- `SgGateway`、`SgHttpRoute`、`SgBackendRef`、`BackendHost`、`PluginInstanceId/Name`、`PluginConfig`、`PluginInstanceMap`。 +- 可选 `typegen`(ts-rs)供前端/SDK。 +- K8s 相关扩展(CRD 等)。 + +### 4.2 风险与问题 + +- **`PluginInstanceName` 的 `Display` 与 `FromStr` 不一致**:Mono 显示为 `m` 而解析期望 `g`,可能影响依赖字符串往返的配置/通道。 +- **`PluginInstanceMap` 反序列化**:错误路径使用 `eprintln!`,建议改为 `tracing`。 + +--- + +## 五、`spacegate-config` + +### 5.1 主要功能 + +- Trait:`Create`、`Retrieve`、`Update`、`Delete`;`CreateListener` + `Listen`;`ConfigType` / `ConfigEventType`。 +- **实现**:`Memory`(静态)、`Fs`(目录布局 + Unix SIGHUP / Windows notify)、`K8s`(多资源 watch + SIGHUP 全局重载)、`Redis`(hash + pubsub)。 +- **Discovery**:实例列表与可选后端发现(如 fs 下读 `/var/www`)。 + +### 5.2 风险与问题 + +- **K8s 路由事件**:`process_http_spaceroute_event` 中 Applied 与 Delete 的事件类型是否应区分 Update/Delete,需与 shell 中「全量拉路由」行为对照,避免语义混淆。 +- **监听任务中 `send(...).expect`**:通道关闭会导致 panic。 +- **`Fs::modify_cached` 全目录删建**:中断可能丢配置,宜加备份或原子写。 +- **`redis/listen.rs` 中未使用的 `CHANGE_CACHE`**:死代码。 +- **`RedisListener::CONFIG_LISTENER_NAME` 误写为 `"file"`**:应为 `"redis"` 以免日志误导。 + +--- + +## 六、`spacegate-shell` + +### 6.1 主要功能 + +- `startup_file` / `startup_k8s` / `startup_redis` / `startup_static` → 统一 `startup`。 +- `RunningSgGateway`:`global_init`、`global_reset`、`global_update`(`Reloader` 热更路由)。 +- 配置到内核:`collect_http_route`、`global_batch_mount_plugin`、K8s Service 扩展注入。 +- 启用 `ext-axum` 时:健康检查、`/control/push_event`、静态页等。 + +### 6.2 风险与问题 + +- **Route 类事件**:handler 对 Create/Update/Delete 一律 `retrieve_config_item_all_routes` 后整体更新,语义依赖「全量正确」;与 K8s 事件类型需一起审视。 +- **插件初始化失败**:当前多为日志后继续,可按策略支持 fail-fast。 +- **全局 `Mutex` 中毒**:`expect("poisoned lock")` 后难以恢复。 +- **TLS `enable_secret_extraction`**:生产宜可配置关闭。 + +--- + +## 七、扩展库 + +### `spacegate-ext-redis` + +- `RedisClient::get_conn` / `From<&str>` 使用 `unwrap`/`expect`,配置错误易 panic;建议提供 `try_*` API。 + +### `spacegate-ext-axum` + +- `GlobalAxumServer` 关停路径存在 `expect`;`InternalError` 里有 `unwrap` 组 Response。 + +--- + +## 八、`binary/spacegate` + +- Clap 参数:`file:`/`k8s:`/`redis:`/`static:`;可选动态库目录扫描加载。 +- 缺 feature 时 dylib 仅 `eprintln`,建议统一 tracing。 + +--- + +## 九、`binary/admin-server` + +### 主要功能 + +- `/config/*`、`/plugin/*`、`/auth/login`、`/discovery/*`。 +- JWT + 可选 SK 摘要;`X-Client-Version` / `X-Server-Version` 乐观并发。 +- 发现:实例健康、插件列表/schema 缓存、向网关 `push_event` 触发重载。 + +### 风险与问题 + +- **空文件**:`mw/instance_select.rs` 等遗留。 +- **跨平台**:`clap` 等处 `unix` 专有 import/默认值未守卫时 Windows 编译可能失败。 +- **健康检查与 `sync_attr_cache` 缓存**:若使用 `Instant::elapsed() >= Duration::ZERO` 判断是否过期,逻辑恒为「已过期」,缓存失效——应改为与 `Instant::now()` 比较。 +- **依赖版本**:如 `tower-http` 与 workspace 不一致可能导致重复编译。 +- **未配置鉴权时中间件放行**:部署文档需强调必须配置密钥。 + +--- + +## 十、`sdk/admin-client` + +- Axios 封装,与 admin-server API 对齐;版本冲突与 401 自定义异常。 +- 注意全局 client 与 `clientVersion` 刷新页丢失导致的首次 409。 + +--- + +## 十一、示例 + +- **sayhello**:动态库插件最小示例。 +- **socks5-proxy**:`TcpService` + 端口多协议嗅探。 +- **mitm-proxy**:CONNECT + 动态证书 MITM 演示。 + +--- + +## 十二、横向问题汇总 + +| 优先级 | 问题 | +|--------|------| +| 高 | TLS 默认信任任意后端证书;`ignore_tls_verification` 未接线 | +| 高 | `redirect` 插件未真正重定向 | +| 高 | `GatewayRouter` hostname 索引与通配 `*` 的逻辑需复核 | +| 高 | K8s 监听中路由事件类型与 shell 全量更新语义 | +| 高 | `PluginInstanceName` Display/FromStr 不一致 | +| 高 | admin-server 健康/attr 缓存时间判断错误 | +| 中 | 大量 `unwrap`/`expect`;Redis/Axum 扩展 panic 路径 | +| 中 | Windows 编译(unix-only 模块) | +| 中 | 死代码与误填常量(如 Redis listener 名称) | +| 低 | tracing 替代 eprintln;依赖版本对齐 | + +--- + +## 十三、建议修复顺序(供迭代跟踪) + +1. **安全与正确性**:TLS 默认策略、`redirect`、hostname 路由索引、K8s 事件语义、PluginInstanceName 往返、admin-server 缓存判断。 +2. **健壮性**:减少 expect;Redis `try_get`;插件钩子使用规范文档。 +3. **可维护性**:清理 breaker/status 等废弃路径;统一 tower-http 版本;修正 `CONFIG_LISTENER_NAME`。 + +--- + +## 十四、修订历史 + +| 日期 | 说明 | +|------|------| +| 2026-05-12 | 初版:基于全仓库结构与关键源码路径的审核汇总 | diff --git a/docs/ai-gateway-queue-admin-ui-guide.md b/docs/ai-gateway-queue-admin-ui-guide.md new file mode 100644 index 00000000..64f0ffe9 --- /dev/null +++ b/docs/ai-gateway-queue-admin-ui-guide.md @@ -0,0 +1,524 @@ +# AI 网关排队限流插件 — 管理界面配置指南 + +本文说明如何在 **SpaceGate Admin 管理界面** 中配置 **AI 请求队列网关**(`ai-gateway-queue` Wasm 插件),包括插件实例创建、网关/路由挂载、租户配额,以及客户端请求头约定。 + +相关文档: + +- 插件行为与 API:[`plugins/wasm/ai-gateway-queue/README.md`](../plugins/wasm/ai-gateway-queue/README.md) +- 编译与 K8s 部署:[`deploy/README.md`](../deploy/README.md) +- 测试用例:[`ai-gateway-queue-test-spec.md`](ai-gateway-queue-test-spec.md) + +--- + +## 1. 配置全景 + +管理界面上的配置分 **三层**,需按顺序完成: + +```text +┌─────────────────────────────────────────────────────────────┐ +│ ① 插件实例(插件页 → AI 请求队列网关) │ +│ 写入 plugin/wasm.ai-gateway-queue.json │ +│ 含 Wasm URL、后端地址、plugin_config 等 │ +└───────────────────────────┬─────────────────────────────────┘ + │ +┌───────────────────────────▼─────────────────────────────────┐ +│ ② 挂载引用(网关页 或 路由页 → 插件列表) │ +│ 仅引用 { code: wasm, name: ai-gateway-queue } │ +│ ⚠ 只选一层挂载,勿 Gateway + Route 重复 │ +└───────────────────────────┬─────────────────────────────────┘ + │ +┌───────────────────────────▼─────────────────────────────────┐ +│ ③ 租户配额(ai-gateway-service Admin API) │ +│ UI「队列配额」Tab 当前为占位;可用 API 或 curl 配置 │ +└─────────────────────────────────────────────────────────────┘ +``` + +| 层级 | 管理界面入口 | 落盘 / 存储 | +|------|-------------|-------------| +| 插件实例 | **插件** → Tab **AI** → **AI 请求队列网关** | `plugin/wasm.ai-gateway-queue.json` | +| 挂载引用 | **网关** 或 **路由** → 插件列表 | `gateway/{name}/config.json` 或 `route/{route}.json` | +| 租户配额 | API(见 §6) | Redis | + +--- + +## 2. 前置条件 + +### 2.1 依赖服务 + +| 服务 | 默认端口 | 说明 | +|------|---------|------| +| **spacegate-admin-server** | 9992(开发)/ 9080(Docker 管理端) | 读写 SpaceGate 配置 | +| **spacegate-admin-fe** | 4000 | Vue 管理界面 | +| **SpaceGate 网关** | 9993 | 加载 Wasm 并转发流量 | +| **ai-gateway-service** | 18080 | 限流 / 入队 / Worker 后端 | +| **Redis 7+** | 6379 | 令牌桶与队列 | +| **上游 LLM** | 9000(示例) | HTTPRoute 后端 | + +### 2.2 启动管理界面(本地开发) + +```bash +# 终端 1:Admin 后端(文件配置模式示例) +cd spacegate +cargo run -p spacegate-admin-server -- -c file:.docker/ai-gateway-demo + +# 终端 2:Admin 前端 +cd spacegate-admin-fe +npm install +npm run dev +# 浏览器打开 http://localhost:4000 +``` + +Docker 环境可直接访问 **`http://localhost:9080`**(`ai-gateway-web` 容器)。 + +### 2.3 配置 ai-gateway-service 地址(重要) + +插件 Drawer 中的 **Schema 表单**、**文档 Tab** 以及未来的 **租户配额** 均通过 `ai-gateway-service` 的 Admin API 拉取: + +```bash +# spacegate-admin-fe/.env.local(或构建时环境变量) +VITE_AI_GATEWAY_BASE_URL=http://127.0.0.1:18080 +``` + +| 是否配置 | 效果 | +|---------|------| +| **已配置** | Schema / Readme 正常加载;租户配额 API 可用 | +| **未配置** | 请求打到前端 `:4000`,Schema 加载失败,表单为空 | + +SpaceGate 配置 API(保存插件、网关、路由)走 `/api` 代理到 admin-server,**与上述变量无关**。 + +本地 Vite 代理(`vite.config.ts`): + +```text +/api/* → http://localhost:9992/* +``` + +--- + +## 3. 界面导航 + +### 3.1 选择网关 + +顶部 **SelectGateway** 下拉框选择目标网关(如 `ai-demo`)。后续菜单跳转会自动带上 `?gatewayName=ai-demo`。 + +### 3.2 左侧菜单 + +| 菜单 | 路径 | 与本插件相关用途 | +|------|------|-----------------| +| **网关** | `/gateway` | 网关级插件挂载、监听器 | +| **路由** | `/route` | 路由规则、后端、规则级插件 | +| **插件** | `/plugins` | **创建 / 编辑 AI 请求队列网关实例** | +| **实例** | `/instance` | SpaceGate 进程在线状态(与插件配置无关) | + +--- + +## 4. 分步配置 + +### 步骤 1:创建插件实例 + +1. 进入 **插件** 页 +2. 切换到 Tab **「AI」** +3. 找到卡片 **「AI 请求队列网关」** +4. 点击 **「配置」**,打开 **AI 请求队列网关** Drawer +5. 填写 **基础接入** 与 **基础配置**(见 §5) +6. 点击 **保存** + +保存后 admin-server 写入: + +```text +plugin/wasm.ai-gateway-queue.json +``` + +首次保存调用 `POST /config/plugin`;再次编辑调用 `PUT /config/plugin`。 + +卡片上会显示 **「已部署」** 标签。 + +### 步骤 2:挂载到网关或路由 + +插件实例创建后,还需在 **网关** 或 **路由** 中引用,流量才会经过 Wasm。 + +#### 方式 A:网关级挂载(推荐) + +1. 进入 **网关** 页 +2. 编辑目标网关(如 `ai-demo`) +3. 找到 **插件** 字段(PluginListForm) +4. 点击 **添加插件** +5. 选择: + - **Code**:`wasm` + - **Kind**:`named` + - **Name**:`ai-gateway-queue` +6. 保存网关配置 + +等价 JSON 片段: + +```json +{ + "plugins": [ + { + "code": "wasm", + "kind": "named", + "name": "ai-gateway-queue" + } + ] +} +``` + +#### 方式 B:路由级挂载 + +1. 进入 **路由** 页 +2. 编辑目标路由(如 `ai`)下的某条 **规则** +3. 在规则 **插件** 列表中添加同样的引用 +4. 配置 **后端** 指向 LLM 上游 +5. 保存 + +等价 JSON 片段(规则内): + +```json +{ + "matches": [{ "path": { "kind": "Prefix", "value": "/v1/" } }], + "plugins": [ + { "code": "wasm", "kind": "named", "name": "ai-gateway-queue" } + ], + "backends": [{ "host": { "kind": "Host", "host": "127.0.0.1" }, "port": 9000, "weight": 1 }] +} +``` + +> **⚠ 切勿重复挂载** +> 若 Gateway 与 Route **同时** 引用 `ai-gateway-queue`,每个请求会执行 **两次** 插件逻辑,导致 **双倍扣 token / 双倍限流**。 +> 生产环境请 **只选一层**;`resource/ai-gateway-demo` 示例为演示方便两处都挂了,本地验证时注意这一点。 + +### 步骤 3:配置路由与后端 + +在 **路由** 页确保: + +- 路径匹配 AI API(如 `/v1/` Prefix) +- **后端** 指向真实 LLM 服务地址与端口 +- 优先级(priority)按需设置 + +### 步骤 4:配置租户配额(可选) + +按租户 / 模型 / 路径 / 策略设置差异化令牌桶,见 **§6**。当前 Drawer 内 **「队列配额」Tab 为占位**,需通过 API 配置。 + +### 步骤 5:验证 + +```bash +# 经网关(插件生效) +curl -i http://127.0.0.1:9993/v1/chat/completions \ + -H 'X-RateLimit-Policy: abandon' \ + -H 'X-Tenant-Id: demo' \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"hello"}' + +# 直连后端健康检查 +curl http://127.0.0.1:18080/healthz +``` + +期望:配额内 `200`;缺 Policy 且 `require=true` 时 `400`;超额 abandon `429`。 + +--- + +## 5. Drawer 字段说明 + +打开 **插件 → AI → AI 请求队列网关 → 配置** 后,Drawer 含四个 Tab。 + +### 5.1 Tab「基础配置」 + +#### 基础接入(Wasm 宿主层 → `spec` 顶层) + +| 界面字段 | 配置键 | 默认值 | 说明 | +|---------|--------|--------|------| +| Wasm URL | `url` | 空 | Wasm 制品地址。支持 `file://`、`http(s)://`、`oci://` | +| 插件名称 | `plugin_name` | `ai-gateway-queue` | 建议保持不变 | +| 失败策略 | `fail_strategy` | `fail_close` | `fail_close`:插件异常时拒绝请求;`fail_open`:放行 | +| 队列后端地址 | `clusters["ai-gateway-service"]` | `http://127.0.0.1:18080` | ai-gateway-service 的 HTTP 地址 | +| 普通 VM 池大小 | `vm_pool_size` | `4` | 处理 abandon / queue 短请求的 Wasm 实例数,≥1 | +| Wait VM 池大小 | `wait_vm_pool_size` | `4` | wait 长连接专用池;不用 wait 可设 `0` | + +**Wasm URL 示例:** + +| 环境 | 示例值 | +|------|--------| +| 本地 Cargo | `file:///path/to/spacegate/plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm` | +| Docker 挂载 | `file:///etc/spacegate/plugins/spacegate_plugin_ai_gateway_queue.wasm` | +| K8s HTTP 分发 | `http://ai-gateway-wasm/spacegate_plugin_ai_gateway_queue.wasm` | +| OCI 制品 | `oci://ghcr.io/your-org/ai-gateway-queue:v1.0.0` | + +**界面未暴露、保存时会保留的字段**(来自已有配置文件): + +- `validate_on_create`、`plugin_root_id`、`plugin_vm_id` +- `limits`:`max_memory_pages`、`fuel_per_call`、`epoch_timeout_millis`、`max_body_bytes`、`max_pending_calls` + +#### 基础配置 Schema 表单(→ `spec.plugin_config`) + +表单字段由 `ai-gateway-service` 动态提供:`GET /v1/admin/plugins/ai-gateway-queue/schema`。 + +##### service — 队列后端接入 + +| 字段 | 默认 | 说明 | +|------|------|------| +| `cluster` | `ai-gateway-service` | SpaceGate cluster 名,须与 `clusters` 键一致 | +| `authority` | `ai-gateway-service` | HTTP 调用的 `:authority` | +| `timeout_ms` | `65000` | 调用后端超时;使用 wait 模式建议 ≥60000 | + +##### paths — 后端 API 路径 + +| 字段 | 默认 | +|------|------| +| `rate_limit` | `/v1/ratelimit/check` | +| `enqueue` | `/v1/queue/enqueue` | +| `wait` | `/v1/queue/enqueue-and-wait` | + +一般保持默认即可,除非后端改了路由前缀。 + +##### headers — 客户端请求头映射 + +| 字段 | 默认 HTTP 头 | 用途 | +|------|-------------|------| +| `policy` | `X-RateLimit-Policy` | 队列策略 | +| `tenant` | `X-Tenant-Id` | 租户标识 | +| `model` | `X-Model` | 模型名(优先级路由) | +| `priority` | `X-Queue-Priority` | 显式优先级 | + +HTTP 头名大小写不敏感;配置中通常写小写。 + +##### policies — 策略校验 + +| 字段 | 默认 | 说明 | +|------|------|------| +| `require` | `true` | 为 `true` 时,缺少 Policy 头 → **400** | +| `default` | 空 | `require=false` 时使用的默认策略:`abandon` / `queue` / `wait` | + +##### priority — 多优先级队列 + +| 字段 | 默认 | 说明 | +|------|------|------| +| `enabled` | `true` | 关闭后所有请求走 `default` 优先级 | +| `default` | `normal` | `high` / `normal` / `low` | +| `high_models` / `low_models` | `[]` | 模型名精确匹配 | +| `high_tenants` / `low_tenants` | `[]` | 租户 ID 列表 | + +> **扁平 vs 嵌套格式** +> 部分示例文件(如 `resource/ai-gateway-demo`)使用扁平键(`service_cluster`、`require_policy`)。 +> 管理界面 SchemaForm 使用 **嵌套 JSON**。Wasm 运行时两种格式均兼容;若从文件导入后表单显示异常,可在 Drawer 中重新保存一次以统一格式。 + +### 5.2 Tab「队列配额」 + +当前版本显示占位说明:**租户差异化限流 UI 尚未接入 Drawer**。 + +V1 行为说明(与界面提示一致): + +- **全局限流**在 `ai-gateway-service` 配置,非 Drawer 字段 +- 环境变量:`AI_RATE_LIMIT_RPS`、`AI_RATE_LIMIT_BURST`、`AI_RATE_LIMIT_COST` +- 或 TOML `[rate_limit]` 段 + +租户级配额请使用 **§6 API**。 + +### 5.3 Tab「文档」 + +从 `GET /v1/admin/plugins/ai-gateway-queue/readme` 拉取插件 README Markdown,便于在界面内查阅行为说明。 + +### 5.4 Tab「队列观测」 + +V1 预留,后续接入队列长度、消费速率、回调失败等指标。 + +--- + +## 6. 租户配额配置(Admin API) + +`TenantRateLimitTable` 组件已实现完整 CRUD,但尚未挂接到 Drawer「队列配额」Tab。可通过 HTTP API 或 curl 配置。 + +### 6.1 创建 / 更新配额 + +```bash +curl -X PUT http://127.0.0.1:18080/v1/admin/tenant-rate-limits \ + -H 'Content-Type: application/json' \ + -d '{ + "tenant": "demo", + "model": "", + "path": "", + "policy": "", + "rps": 10, + "burst": 20, + "cost": 1 + }' +``` + +### 6.2 字段说明 + +| 字段 | 必填 | 默认 | 说明 | +|------|------|------|------| +| `tenant` | 是 | — | 租户 ID,与 `X-Tenant-Id` 对应 | +| `model` | 否 | 空=通配 | 如 `gpt-4o` | +| `path` | 否 | 空=通配 | 如 `/v1/chat/completions` | +| `policy` | 否 | 空=通配 | `abandon` / `queue` / `wait` | +| `rps` | 是 | — | 每秒令牌恢复速率,>0 | +| `burst` | 是 | — | 突发容量(令牌桶大小),>0 | +| `cost` | 是 | 1 | 单次请求消耗令牌数,>0 | +| `ttl_secs` | 否 | 永久 | 临时配额过期秒数 | + +**匹配优先级**:维度越具体越优先(带 `model+path+policy` 的规则优先于仅 `tenant` 的规则)。 + +Redis key 预览格式: + +```text +ai:tenant:ratelimit:{tenant}[:model:...][:path:...][:policy:...] +``` + +### 6.3 查询与删除 + +```bash +# 列表(可按 tenant 过滤) +curl 'http://127.0.0.1:18080/v1/admin/tenant-rate-limits?tenant=demo' + +# 删除(body 与创建时维度一致) +curl -X DELETE http://127.0.0.1:18080/v1/admin/tenant-rate-limits \ + -H 'Content-Type: application/json' \ + -d '{"tenant":"demo","rps":10,"burst":20,"cost":1}' +``` + +--- + +## 7. 客户端请求头 + +配置完成后,调用方经网关 `:9993` 发送请求时需携带: + +| 请求头 | 必填 | 取值 | 说明 | +|--------|------|------|------| +| `X-RateLimit-Policy` | 当 `require=true` | `abandon` / `queue` / `wait` | 必须小写 | +| `X-Tenant-Id` | 建议 | 任意字符串 | 租户隔离与配额匹配 | +| `X-Callback-URL` | queue 超额时 | HTTPS URL | 异步回调地址 | +| `X-Model` | 否 | 模型名 | 影响优先级路由 | +| `X-Queue-Priority` | 否 | `high` / `normal` / `low` | 显式优先级 | + +**三种策略行为(均需先过令牌桶):** + +| 策略 | 配额内 | 超额 | +|------|--------|------| +| `abandon` | 直通上游 200 | 429,不入队 | +| `queue` | 直通上游 200 | 202 + job_id,回调/轮询取结果 | +| `wait` | 直通上游 200 | 阻塞等待结果,超时 504 | + +示例: + +```bash +# abandon — 超额返回 429 +curl -i http://127.0.0.1:9993/v1/chat/completions \ + -H 'X-RateLimit-Policy: abandon' \ + -H 'X-Tenant-Id: demo' \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"hi"}' + +# queue — 超额返回 202 +curl -i http://127.0.0.1:9993/v1/chat/completions \ + -H 'X-RateLimit-Policy: queue' \ + -H 'X-Tenant-Id: demo' \ + -H 'X-Callback-URL: https://example.com/callback' \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"hi"}' +``` + +--- + +## 8. 配置与存储映射 + +```text +管理界面操作 API 存储位置 +───────────────────────────────────────────────────────────────────────── +保存 AI 队列 Drawer POST/PUT /config/plugin plugin/wasm.ai-gateway-queue.json +保存网关 PUT /config/item/{gw}/gateway gateway/{gw}/config.json +保存路由规则 PUT .../route/item/{route} gateway/{gw}/route/{route}.json +租户配额 PUT PUT /v1/admin/tenant-rate-limits Redis +读取 Schema GET /v1/admin/plugins/.../schema (运行时生成) +``` + +文件命名规则:`{code}.{name}.json` → `wasm.ai-gateway-queue.json`。 + +--- + +## 9. 常见问题 + +### Q1:Schema 表单空白或加载失败? + +检查 `VITE_AI_GATEWAY_BASE_URL` 是否指向运行中的 `ai-gateway-service`(默认 `http://127.0.0.1:18080`),并确认 `/healthz` 可访问。 + +### Q2:保存插件成功但请求未限流? + +1. 是否在 **网关或路由** 中添加了插件引用? +2. 是否 **重复挂载** 导致行为异常? +3. SpaceGate 是否已加载最新配置(文件模式通常自动热更)? + +### Q3:第一次请求就 429? + +- 检查 Gateway + Route **双重挂载** +- 检查租户 `burst` 是否过小 +- 用 Admin API 调高配额或新建租户规则 + +### Q4:缺 Policy 返回 400? + +`plugin_config.policies.require=true`(默认)。客户端必须带 `X-RateLimit-Policy`,或在 Drawer 中关闭 require 并设置 default。 + +### Q5:保存插件配置报 `Read-only file system (os error 30)`? + +**原因:** Docker 队列模式下 `admin-server` 配置卷被挂成 **只读(`:ro`)**,无法写入 `plugin/wasm.ai-gateway-queue.json`。 + +**修复:** + +1. `docker-compose.queue.yml` 中 admin-server 使用 **整目录可写** 挂载(勿 `:ro`): + +```yaml +admin-server: + volumes: + - ./.docker/ai-gateway-demo:/etc/spacegate +``` + +2. Wasm 二进制挂到配置目录外,URL 用 `file:///opt/wasm/...`(见 `docker-compose.queue.yml` 注释)。 + +3. 重建 admin-server 镜像(含插件增量写入修复)后重启容器: + +```bash +cd spacegate +docker build -f resource/docker/Dockerfile.admin-server -t ai-gateway/admin-server:dev . +docker rm -f ai-gateway-admin-server && docker run -d --name ai-gateway-admin-server \ + --network container:ai-gateway-spacegate --restart unless-stopped \ + -e CONFIG=file:/etc/spacegate -e RUST_LOG=info \ + -v $(pwd)/../.docker/ai-gateway-demo:/etc/spacegate \ + ai-gateway/admin-server:dev -c file:/etc/spacegate -p 19992 -H 0.0.0.0 +``` + +**临时绕过:** 直接编辑宿主机 `.docker/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json`,无需走 UI 保存。 + +### Q6:`:9080` 管理端报 No such file or directory? + +admin-server 读不到 `/etc/spacegate` 配置。Docker 环境检查 **工作区根目录** `.docker/ai-gateway-demo` 是否正确挂载到容器内 `/etc/spacegate`。 + +### Q7:K8s 环境能用这套 UI 吗? + +可以管理 SpaceGate 配置(若 admin-server 连到同一配置源)。K8s 下 Wasm 常通过 **SgFilter** 内联 spec + HTTP/OCI 分发 Wasm,详见 [`deploy/README.md`](../deploy/README.md) §6。Higress **WasmPlugin** CR 的 `defaultConfig` **不含** `clusters`,生产建议用 **SgFilter**。 + +--- + +## 10. 推荐配置流程( checklist ) + +```text +□ Redis、ai-gateway-service、上游 LLM 已启动 +□ 编译 Wasm 并确认 url 可访问 +□ 设置 VITE_AI_GATEWAY_BASE_URL +□ 插件页 → AI → 配置 AI 请求队列网关 → 保存 +□ 网关或路由(二选一)添加 wasm / ai-gateway-queue 引用 +□ 路由后端指向 LLM 服务 +□ (可选)PUT /v1/admin/tenant-rate-limits 配置租户配额 +□ curl 冒烟:400(无 Policy)/ 200(配额内)/ 429(超额 abandon) +``` + +--- + +## 11. 相关源码索引 + +| 文件 | 说明 | +|------|------| +| `spacegate-admin-fe/components/config/src/components/PluginPanel.vue` | AI Tab 与 Drawer 入口 | +| `spacegate-admin-fe/components/config/src/components/AiGatewayQueueDrawer.vue` | 主配置 Drawer | +| `spacegate-admin-fe/components/config/src/components/TenantRateLimitTable.vue` | 租户配额表格(待接入 Tab) | +| `spacegate-admin-fe/components/config/src/api/aiGateway.ts` | ai-gateway-service Admin API 客户端 | +| `binary/ai-gateway-service/src/app/admin.rs` | Schema / Readme 端点 | +| `binary/ai-gateway-service/src/app/types.rs` | `AiGatewayQueuePluginConfig` 结构 | +| `resource/ai-gateway-demo/` | 文件模式配置模板 | diff --git a/docs/ai-gateway-queue-design-gap-fixlist.md b/docs/ai-gateway-queue-design-gap-fixlist.md new file mode 100644 index 00000000..2ac77740 --- /dev/null +++ b/docs/ai-gateway-queue-design-gap-fixlist.md @@ -0,0 +1,644 @@ +# AI Gateway Queue — 设计与代码差距修复清单 + +> 对照文档:`ai-gateway-queue-design.md`(桌面版) +> 审计范围:`spacegate/plugins/wasm/ai-gateway-queue` + `spacegate/binary/ai-gateway-service` +> 生成日期:2026-05-23 + +本文档将设计文档与当前实现的差异整理为**可执行的修复项**,按优先级排序。每项包含:差距说明、建议改法、涉及文件、验收标准、依赖关系。 + +--- + +## 优先级说明 + +| 级别 | 含义 | 建议节奏 | +|------|------|----------| +| **P0** | 核心语义错误,影响限流/队列正确性 | 立即修复 | +| **P1** | 设计明确要求的能力缺失或明显性能/可靠性缺口 | 下一迭代 | +| **P2** | 行为/格式与设计有差异,但不阻断主流程 | 按需排期 | +| **P3** | 文档、默认值、观测增强 | 低优先级 | + +--- + +## P0 — 核心语义 + +### GAP-001:queue/wait 入队前未做令牌桶准入判定 + +**设计期望** + +- 概述:`Gateway → [Rate Limiter] → Redis Stream → Worker → LLM` +- 限流策略表:三种模式在「触发限流时」分别 429 / 202 / 阻塞等待 +- abandon 示例:未触发限流 → 正常 LLM 响应;触发限流 → 429 + +**当前行为** + +- `abandon`:Wasm 调 `/v1/ratelimit/check`,通过则直通上游 ✅ +- `queue` / `wait`:Wasm **直接**调 `/v1/queue/enqueue` 或 `/v1/queue/enqueue-and-wait`,**不做限流判断**,所有请求全量入队 ❌ + +**建议改法** + +1. **方案 A(推荐,改 Wasm 插件)** + - `queue` / `wait` 在入队前先调 `/v1/ratelimit/check`(与 abandon 共用同一接口) + - `allowed: true` → `resume_http_request()` 直通上游(配额内直通) + - `allowed: false` → + - `queue` → 调 enqueue,返回 202 + - `wait` → 调 enqueue-and-wait,阻塞等待 + +2. **方案 B(改 service 入队接口)** + - 在 `enqueue_job()` 开头内联令牌桶逻辑;`allowed: true` 时同步调 upstream 并返回 200(wait)或直接 proxy(需 Gateway 配合,改动面更大) + +**涉及文件** + +- `spacegate/plugins/wasm/ai-gateway-queue/src/lib.rs`(主改) +- `spacegate/plugins/wasm/ai-gateway-queue/README.md` +- `spacegate/binary/ai-gateway-service/src/app/handlers.rs`(若采用方案 B) +- `spacegate/binary/ai-gateway-service/src/app/queue.rs`(若采用方案 B) +- 集成测试 / e2e 脚本 + +**验收标准** + +- [ ] 租户配额内 + `queue` 策略 → **200**,响应来自上游,**不入队** +- [ ] 租户配额内 + `wait` 策略 → **200**,同步上游响应,**不入队** +- [ ] 租户超额 + `queue` → **202** + `X-Job-Id` +- [ ] 租户超额 + `wait` → 入队等待或 **504** 超时 +- [ ] `abandon` 行为保持不变 +- [ ] `rate_limited_total{policy,tenant}` 在 queue/wait 超额入队时也有计数 + +**依赖**:无(应最先做) + +**备注**:需与设计方确认 queue 示例中「无论是否触发限流都 202」是否作废;若保留该语义,则 queue 模式不做 GAP-001 直通,仅 wait/abandon 对齐。 + +--- + +### GAP-002:queue 模式语义与设计文档内部矛盾需定稿 + +**设计矛盾点** + +- 策略对比表:「限流时入队」 +- queue 示例:**「立即返回(无论是否触发限流)」** → 202 + +**当前行为** + +- 与 queue 示例一致:所有 queue 请求都 202 入队 + +**建议改法** + +- **产品/架构定稿二选一**,写入设计文档 v2: + - **模式 Q1(异步优先)**:queue 永远异步入队,不做直通(维持现状) + - **模式 Q2(配额内直通)**:配额内直通,超额才 202(需 GAP-001) + +**涉及文件** + +- 设计文档(外部) +- `spacegate/plugins/wasm/ai-gateway-queue/README.md` +- 前端配置手册 / Admin 文案 + +**验收标准** + +- [ ] 设计文档消除内部矛盾 +- [ ] README、前端说明与定稿一致 +- [ ] 测试用例覆盖定稿语义 + +**依赖**:阻塞 GAP-001 的实现细节 + +--- + +### GAP-003:多租户配额叠加无全局容量保护 + +**设计期望** + +- 设计强调租户隔离,未写全局上限;但生产上多租户「各自配额内」叠加仍可能打满上游 + +**当前行为** + +- 仅 per-tenant 令牌桶;无 cluster 级总 RPS / 总并发 Semaphore +- `abandon` 直通不受 `worker_concurrency` 约束 + +**建议改法** + +1. 增加 **全局令牌桶**(Redis key 如 `ai:global:ratelimit:tokens`),在 `/v1/ratelimit/check` 中 **先扣全局、再扣租户** +2. 增加 **upstream 并发 Semaphore**(`AI_UPSTREAM_MAX_INFLIGHT`),abandon 直通与 Worker 共享 +3. `/metrics` 暴露 `global_rate_limited_total`、`upstream_inflight` + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/handlers.rs` +- `spacegate/binary/ai-gateway-service/src/app/types.rs`(Lua 或新函数) +- `spacegate/binary/ai-gateway-service/src/app/config.rs` +- `spacegate/binary/ai-gateway-service/config/ai-gateway-service.example.toml` +- `spacegate/binary/ai-gateway-service/src/app/queue.rs`(Worker 侧 acquire permit) + +**验收标准** + +- [ ] 100 个租户各在配额内,全局上限触发后后续请求按策略 429/入队/等待 +- [ ] 指标可观测全局拒绝次数 +- [ ] 配置可独立调整全局 RPS 与 upstream inflight + +**依赖**:建议在 GAP-001 之后 + +--- + +## P1 — 性能与可靠性 + +### GAP-004:S3 multipart 上传与 XADD 顺序执行,非设计所述并发 + +**设计期望** + +> 入队(S3 卸载):S3 PutObject 与 XADD 并发执行,瓶颈在 S3 + +**当前行为** + +- `store_body()` 完整完成后才 `XADD` + +**建议改法** + +- 小 refactor:`store_body` 返回 `(BodyLocation, future)` 或在超阈值时: + 1. 先 `XADD` 占位 entry(status=uploading)或 + 2. 并行:`tokio::join!(multipart_upload, prepare_metadata)`,最后 XADD +- 最小改动:XADD 只写 ref/metadata,body 上传异步完成后更新 entry 或 Worker 按 ref 拉取(Worker 已支持 ref) + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/queue.rs` +- `spacegate/binary/ai-gateway-service/src/app/object_store.rs` + +**验收标准** + +- [ ] 大 body 场景 enqueue P99 不因「上传完成 + XADD 串行」线性叠加 +- [ ] 上传失败时 entry 不处于不可消费状态(abort + DLQ 或重试) + +**依赖**:无 + +--- + +### GAP-005:wait 模式每请求新建 SubscriberClient,未实现连接复用 + +**设计期望** + +> 1000 个 wait 并发共享同一物理连接(fred 多路复用订阅) + +**当前行为** + +- 每次 `enqueue_and_wait` 调用 `build_subscriber_client()` 新建连接 + +**建议改法** + +- 在 `AppState` 中维护 **共享 SubscriberClient 池** 或单例 multiplexer +- 按 `result:{job_id}` channel 注册/oneshot 等待,避免 per-request 连接 +- 注意:fred API 下订阅与命令连接分离的要求仍满足 + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/handlers.rs` +- `spacegate/binary/ai-gateway-service/src/app/runtime.rs`(AppState 初始化) +- `spacegate/binary/ai-gateway-service/src/app/util.rs` +- 新增 `wait_subscriber.rs`(可选) + +**验收标准** + +- [ ] 100 并发 wait 时 Redis 连接数不随请求线性增长 +- [ ] 竞态保险(subscribe 后 get result)仍正确 +- [ ] 超时后 subscriber 无泄漏 + +**依赖**:无 + +--- + +### GAP-006:Worker XREADGROUP 读 5 条但串行处理 + +**设计期望** + +> 每次 XREADGROUP 取 5 条,**批量并发处理** + +**当前行为** + +- `read_worker_stream` 循环内逐条 `process_stream_entry`(串行) + +**建议改法** + +- 对同一 batch 用 `FuturesUnordered` / `tokio::spawn` 并发处理 +- 仍受 `worker_concurrency` 或独立 `worker_inflight` Semaphore 约束 + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/queue.rs` + +**验收标准** + +- [ ] 队列积压时 Worker 吞吐随 concurrency 提升 +- [ ] job lease 机制下无重复执行 +- [ ] upstream inflight(GAP-003)不被突破 + +**依赖**:建议与 GAP-003 一并设计 + +--- + +### GAP-007:未配置 object_store 时大 body 仍 inline 进 Redis + +**设计期望** + +- 超 128KB 应 offload 到 S3;Redis entry 只存 ref + +**当前行为** + +- 仅当 `object_store.endpoint` 配置存在时才 multipart;否则 >128KB 仍 base64 写入 Stream + +**建议改法** + +- 启动时:若 `inline_threshold` 较小但未配 object_store,**warn 或 fail_fast**(生产配置) +- 或:超阈值且无 S3 时拒绝入队并返回 **413 Payload Too Large** + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/object_store.rs` +- `spacegate/binary/ai-gateway-service/src/app/config.rs`(校验) +- `spacegate/binary/ai-gateway-service/config/ai-gateway-service.example.toml` + +**验收标准** + +- [ ] 生产配置下 >128KB 请求不会把大 payload 塞进 Redis +- [ ] 本地无 MinIO 时行为明确(拒绝或强制配 endpoint) + +**依赖**:无 + +--- + +## P2 — 协议与行为对齐 + +### GAP-008:`rate_limited_total{policy,tenant}` 仅 abandon 路径计数 + +**设计期望** + +- 监控:各策略触发限流次数 + +**当前行为** + +- 仅在 `check_rate_limit` handler 内 increment;queue/wait 超额入队不计数 + +**建议改法** + +- GAP-001 完成后,在「超额转 queue/wait 分支」同样 `inc_labeled` +- 或抽取 `record_rate_limited(policy, tenant)` 共用 + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/handlers.rs` +- `spacegate/plugins/wasm/ai-gateway-queue/src/lib.rs`(若 Wasm 侧判定) + +**验收标准** + +- [ ] queue/wait 因配额拒绝而入队时,`rate_limited_total{policy="queue",tenant="..."}` 递增 + +**依赖**:GAP-001 + +--- + +### GAP-009:回调 JSON 与设计示例字段不完全一致 + +**设计期望** + +```json +{ + "job_id": "...", + "status": "completed", + "result": { ...LLM 响应... }, + "completed_at": "2024-01-01T12:00:01Z" +} +``` + +**当前行为** + +- 额外字段:`http_status`、`headers`、`body_base64`、`completed_at_ms`、`error` + +**建议改法** + +- **方案 A**:文档化当前 schema 为正式 API(推荐,向后兼容) +- **方案 B**:增加 `callback_format=v1|v2` 或 Accept 头切换精简格式 + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/callback.rs` +- `spacegate/binary/ai-gateway-service/README.md` + +**验收标准** + +- [ ] API 文档与实现一致 +- [ ] 若有 v1 精简格式,集成测试覆盖 + +**依赖**:无 + +--- + +### GAP-010:job_id 格式与设计示例不一致 + +**设计期望** + +- 示例:`01J8XYZABC`(类 ULID) + +**当前行为** + +- `{timestamp_hex}{counter_hex}` + +**建议改法** + +- 改用 ULID / UUID v7;或保留现状并更新设计文档 + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/util.rs`(`new_job_id`) + +**验收标准** + +- [ ] job_id 全局唯一、可排序(若用 ULID) +- [ ] 旧 job 查询不受影响(无需迁移) + +**依赖**:无 + +--- + +### GAP-011:`X-RateLimit-Policy` 可通过配置绕过 + +**设计期望** + +- 请求头表格:Policy **必填** + +**当前行为** + +- `require_policy=false` 且无 default 时 Wasm `Action::Continue` 完全 bypass + +**建议改法** + +- 生产 preset:`require_policy=true` 且文档标注勿关闭 +- 或移除 bypass 路径,仅允许 `default_policy` fallback + +**涉及文件** + +- `spacegate/plugins/wasm/ai-gateway-queue/src/lib.rs` +- Admin 前端默认值 / 校验 + +**验收标准** + +- [ ] 生产配置无法意外 bypass 插件 +- [ ] 缺少 policy 一律 400 + +**依赖**:无 + +--- + +### GAP-012:HTTPS 回调要求可关闭 + +**设计期望** + +- `X-Callback-URL` 需 HTTPS + +**当前行为** + +- `require_https_callback` 默认 true,可 env 关闭 + +**建议改法** + +- 生产 profile 强制 HTTPS;dev profile 允许 HTTP +- 配置校验:非 dev 且 `require_https=false` 启动 warning/error + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/config.rs` +- `spacegate/binary/ai-gateway-service/src/app/queue.rs`(`validate_callback_url`) + +**验收标准** + +- [ ] 生产启动检查通过 +- [ ] 本地 `AI_REQUIRE_HTTPS_CALLBACK=false` 仍可用 + +**依赖**:无 + +--- + +### GAP-013:令牌桶粒度设计写「仅 Tenant」,实现为 tenant+model+path + +**设计期望** + +- 限流粒度按 `X-Tenant-Id` 隔离 + +**当前行为** + +- Redis key:`ai:ratelimit:{tenant}:{model}:{path}` + Admin 多维规则 + +**建议改法** + +- **推荐**:更新设计文档 v2,声明更细粒度为 intentional enhancement +- 若需严格 tenant-only:增加配置 `rate_limit_granularity=tenant|tenant_model_path` + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/handlers.rs` +- 设计文档 + +**验收标准** + +- [ ] 文档与实现一致 +- [ ] 可选配置切换粒度(若做) + +**依赖**:无 + +--- + +## P3 — 默认配置、观测与文档 + +### GAP-014:优先级 Stream 默认关闭 + +**设计期望** + +- 扩展:多 Stream 优先级(high/low) + +**当前行为** + +- `enable_priority_streams` 默认 `false` + +**建议改法** + +- 生产 example toml 设为 `true` +- 或 Wasm `plugin_config.priority.enabled` 与 service 配置联动文档化 + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/config/ai-gateway-service.example.toml` +- `spacegate/plugins/wasm/ai-gateway-queue/README.md` + +**验收标准** + +- [ ] 启用后 high/low stream 有深度指标 +- [ ] Worker 按权重消费 + +**依赖**:无 + +--- + +### GAP-015:监控指标命名与设计略有差异 + +**设计期望** + +- `enqueue_latency_ms{policy,size_bucket}` 等 + +**当前行为** + +- Prometheus 文本 + `_bucket{le=...}` histogram 风格;部分为 counter + +**建议改法** + +- 导出与设计对齐的 gauge/histogram(OpenMetrics) +- 或更新设计文档指标名 + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/handlers.rs`(`/metrics`) +- `spacegate/binary/ai-gateway-service/src/app/metrics.rs` + +**验收标准** + +- [ ] Grafana 面板可按设计指标名查询 +- [ ] `queue_depth > 1000`、`pel_size > 100` 告警规则可配置 + +**依赖**:无 + +--- + +### GAP-016:Redis 版本未校验 + +**设计期望** + +- Redis 7+(Stream、Pub/Sub) + +**当前行为** + +- 运行时未检查版本 + +**建议改法** + +- 启动时 `INFO server` 检查 major >= 7,否则 warn/error + +**涉及文件** + +- `spacegate/binary/ai-gateway-service/src/app/runtime.rs` + +**验收标准** + +- [ ] Redis 6 启动给出明确错误信息 + +**依赖**:无 + +--- + +### GAP-017:Wasm 层 README / 插件文档与实现对齐 + +**当前缺口** + +- README 仍描述「超额入队」,未说明 queue/wait 当前全量入队 +- 未说明 abandon 与 queue/wait 限流路径差异 + +**建议改法** + +- GAP-001 / GAP-002 定稿后一次性更新: + - `spacegate/plugins/wasm/ai-gateway-queue/README.md` + - Admin 内嵌 readme API 同源 + - `spacegate/docs/` 前端配置手册(若有) + +**验收标准** + +- [ ] 文档描述与代码行为一致 +- [ ] curl 示例可 copy 运行通过 + +**依赖**:GAP-001、GAP-002 + +--- + +## 建议实施顺序(Roadmap) + +```text +Phase 0 — 定稿(1-2 天) + GAP-002 queue 模式语义定稿 + GAP-013 限流粒度文档对齐 + +Phase 1 — 核心正确性(1-2 周) + GAP-001 queue/wait 入队前令牌桶(或确认 Q1 不做) + GAP-008 限流指标补全 + GAP-003 全局容量保护 + GAP-011 生产禁止 bypass policy + +Phase 2 — 性能与可靠性(1-2 周) + GAP-005 wait Subscriber 连接复用 + GAP-006 Worker 批量并发 + GAP-004 S3 + XADD 并发 + GAP-007 大 body 无 S3 保护 + +Phase 3 — 对齐与 polish(按需) + GAP-009 ~ GAP-017 +``` + +--- + +## 测试清单(每项修复必跑) + +| 场景 | 命令/用例 | +|------|-----------| +| abandon 配额内 | Policy=abandon,RPS 内 → 200 来自 upstream | +| abandon 超额 | → 429 + Retry-After | +| queue 配额内 | 定稿 Q2:200 直通;Q1:202 | +| queue 超额 | → 202 + callback | +| wait 配额内 | 定稿 Q2:200 同步;否则入队等待 | +| wait 超额/超时 | → 504 + poll_url,job 仍完成 | +| 大 body offload | >128KB + MinIO → Redis 仅 ref | +| Worker 崩溃 | kill worker → XAUTOCLAIM 重认领 | +| 回调失败 | 不可达 URL → retry stream → DLQ | +| 多租户叠加 | 触发 GAP-003 全局上限 | + +现有脚本参考: + +- `spacegate/binary/ai-gateway-service` 下 unit tests +- `tests/queue-object-store-e2e.sh`(若有) + +--- + +## 变更影响矩阵 + +| GAP | Wasm 插件 | ai-gateway-service | Admin 前端 | 破坏性 | +|-----|-----------|-------------------|------------|--------| +| 001 | ✅ | 可选 | 文案 | **高**(queue 从全 202 变为配额内 200) | +| 002 | — | — | 文案 | 产品决策 | +| 003 | — | ✅ | 可选配额 UI | 中 | +| 004 | — | ✅ | — | 低 | +| 005 | — | ✅ | — | 低 | +| 006 | — | ✅ | — | 低 | +| 007 | — | ✅ | — | 中(大 body 可能从能入队变 413) | +| 008 | 可选 | ✅ | — | 低 | +| 011 | ✅ | — | ✅ | 中 | +| 014 | 配置 | ✅ | ✅ | 低 | + +--- + +## 开放问题(实施前需确认) + +1. **queue 模式**:永远 202(Q1)还是配额内直通(Q2)? +2. **wait 模式**:配额内是否应同步直通上游,还是始终走队列(便于统一 observability)? +3. **全局容量**:是否需要独立配置项暴露给 Admin,还是仅 ops 环境变量? +4. **GAP-001 方案 A vs B**:限流判定放在 Wasm 还是 service 入队接口内? +5. **job_id 是否改为 ULID**:有无外部系统已依赖当前 hex 格式? + +--- + +## 修订记录 + +| 日期 | 说明 | +|------|------| +| 2026-05-23 | 初版:基于设计文档 vs 代码审计生成 | +| 2026-05-24 | **DOC-01/02 定稿**:遵循概述 `Gateway → [Rate Limiter] → …` 与策略表「限流时」语义。三种策略均先过令牌桶;**配额内直通上游**(`resume_http_request`);超额时 abandon→429、queue→202 入队、wait→入队并阻塞等待。queue 示例「无论是否限流都 202」以策略表为准作废。 | +| 2026-05-24 | **全量差距项实施完成**:G-01~G-20、A/Q/W 分项及 DOC 定稿均已落地;`cargo test -p ai-gateway-service` 14/14 通过。详见各模块 commit 与 README 更新。 | + +--- + +## DOC-01 / DOC-02 定稿结论(2026-05-24) + +| 策略 | 配额内(allowed) | 超额(rate limited) | +|------|------------------|---------------------| +| abandon | 直通上游 | 429 | +| queue | 直通上游 | 202 异步入队 | +| wait | 直通上游 | 入队 + Pub/Sub 阻塞等待 | diff --git a/docs/ai-gateway-queue-test-spec.md b/docs/ai-gateway-queue-test-spec.md new file mode 100644 index 00000000..1d41d1ad --- /dev/null +++ b/docs/ai-gateway-queue-test-spec.md @@ -0,0 +1,370 @@ +# AI Gateway 队列 — 测试用例规格 + +**设计文档:** [`/Users/sh.zhang/Documents/ai-gateway-queue-design.md`](/Users/sh.zhang/Documents/ai-gateway-queue-design.md) +**语义基准:** DOC-01/02 定稿(配额内三种策略均直通上游;超额时 abandon→429 / queue→202 / wait→阻塞或 504) + +--- + +## Traceability 矩阵 + +| 设计文档章节 | 用例 ID | 自动化 | +|-------------|---------|--------| +| §限流策略 / 请求头 | TC-HDR-* | Rust IT / Hurl / GW E2E | +| §abandon 示例 | TC-AB-* | Rust IT / Hurl / GW E2E | +| §queue 示例 / 时序 | TC-Q-* | Rust IT / Hurl | +| §wait 示例 / 时序 | TC-W-* | Rust IT / Hurl | +| §核心组件 §1 限流器 | TC-RL-* | Rust IT / Hurl | +| §核心组件 §2 Body | TC-BODY-* | Rust IT / MinIO E2E | +| §核心组件 §3 Stream | TC-Q-* / TC-WK-* | Rust IT | +| §核心组件 §4 Pub/Sub | TC-W-* | Rust IT | +| §性能设计 | TC-BODY-05/07, TC-W-06 | Rust IT | +| §可靠性 | TC-WK-* | Rust IT | +| §监控指标 | TC-MET-* | Rust IT / Hurl | +| §部署 | TC-DEP-* | Shell | +| Wasm 网关层 | TC-GW-* | GW E2E(可选) | + +**图例:** Rust IT = `cargo test --test integration`;Hurl = `tests/hurl/*.hurl`;GW E2E = `scripts/run-gateway-e2e.sh` + +--- + +## 1. 请求头与策略(TC-HDR) + +### TC-HDR-01 缺 Policy 且无 default + +- **设计映射:** §限流策略 — `X-RateLimit-Policy` 必填 +- **前置:** Wasm `default_policy=null`;Service 直接调用入队接口 +- **步骤:** POST `/v1/queue/enqueue`,不带 `x-ratelimit-policy` +- **期望:** 400;Service 侧 bad request(若直接打 service 则 policy 可选但 Wasm 层 400) + +### TC-HDR-02 Policy 非法值 + +- **步骤:** `x-ratelimit-policy: invalid` +- **期望:** Wasm 400 `missing_or_invalid_rate_limit_policy` + +### TC-HDR-03 缺 X-Tenant-Id + +- **步骤:** 任意策略,不带 tenant +- **期望:** Wasm 400 `missing_x_tenant_id`;Service `/v1/ratelimit/check` 400 + +### TC-HDR-04 queue 缺 X-Callback-URL + +- **步骤:** POST `/v1/queue/enqueue`,policy=queue,无 callback +- **期望:** 400 `missing required header x-callback-url` + +### TC-HDR-05 queue 回调非 HTTPS(生产配置) + +- **前置:** `require_https_callback=true` +- **步骤:** `x-callback-url: http://example.com/cb` +- **期望:** 400 `x-callback-url must use https` + +### TC-HDR-06 wait 默认 timeout 60s + +- **前置:** 上游/mock 延迟 >60s;`wait_timeout_secs=60` +- **步骤:** wait 入队并等待 +- **期望:** 504;JSON 含 `error=timeout`、`waited_ms`≈60000 + +### TC-HDR-07 wait 自定义 X-Request-Timeout + +- **步骤:** `x-request-timeout: 2`(测试配置缩短) +- **期望:** ~2s 后 504 + +--- + +## 2. 限流器(TC-RL) + +### TC-RL-01 租户隔离 + +- **设计映射:** §核心组件 §1 — 限流粒度按 X-Tenant-Id +- **前置:** RPS=1, burst=1 +- **步骤:** tenant-A 连续 2 次 check;tenant-B 1 次 check +- **期望:** A 第二次 `allowed=false`;B 第一次 `allowed=true` + +### TC-RL-02 配额内 allowed + +- **步骤:** 首次 check +- **期望:** `{ "allowed": true, "retry_after_ms": 0 }` + +### TC-RL-03 超额与指标 + +- **步骤:** 耗尽 burst 后再 check +- **期望:** `allowed=false`,`retry_after_ms>0`;`/metrics` 中 `rate_limited_total{policy,tenant}` +1 + +### TC-RL-04 burst 超发后拒绝 + +- **前置:** burst=2 +- **步骤:** 连续 3 次 check(同 tenant) +- **期望:** 前 2 次 allowed,第 3 次 denied + +### TC-RL-05 Admin 租户规则覆盖 + +- **步骤:** PUT `/v1/admin/tenant-rate-limits` 设置 tenant 低 RPS;再 check +- **期望:** 新 RPS 生效(更快触发 denied) + +### TC-RL-06 规则 lookup 优先级 + +- **步骤:** 写入 tenant 全局规则 + tenant+model 更严格规则;带 model header check +- **期望:** 使用更具体规则 + +### TC-RL-07 Redis key tenant-only + +- **步骤:** check 后 Redis KEYS `ai:ratelimit:*` +- **期望:** 仅 `ai:ratelimit:{tenant}:tokens` 与 `:ts`;不含 model/path + +--- + +## 3. abandon(TC-AB) + +### TC-AB-01 配额内直通 + +- **设计映射:** §abandon — 未触发限流时正常返回 LLM 响应 +- **步骤:** Wasm policy=abandon,配额内 +- **期望:** 200,body 来自 upstream(非 202/429) + +### TC-AB-02 超额 429 + +- **步骤:** 触发限流 +- **期望:** 429;`Retry-After`;`{"error":"rate_limited","retry_after_ms":N}` + +### TC-AB-03 不调用 enqueue + +- **步骤:** 配额内 abandon;监控 service 日志/无 enqueue 指标增长 +- **期望:** 无 `/v1/queue/enqueue` 调用 + +--- + +## 4. queue(TC-Q) + +### TC-Q-01 配额内直通(定稿) + +- **步骤:** Wasm policy=queue,配额内 +- **期望:** 200 上游响应(**非**设计文档 queue 示例「永远 202」) + +### TC-Q-02 超额 202 入队 + +- **步骤:** 超额 queue +- **期望:** 202;Header `X-Job-Id`;JSON `poll_url=/jobs/{id}/status` + +### TC-Q-03 202 JSON 字段 + +- **期望:** `job_id` 为 ULID 格式;`status=queued`;`poll_url` 正确 + +### TC-Q-04 Worker 回调 + +- **前置:** mock callback server +- **步骤:** 超额入队 → worker 完成 +- **期望:** POST 回调;Header `X-Gateway-Job-Id` + +### TC-Q-05 回调 JSON 四字段 + +- **期望:** `{ job_id, status, result, completed_at }` 仅此四字段(result 为 LLM JSON) + +### TC-Q-06 Stream entry 字段 + +- **步骤:** XREAD 或 XRANGE 读 stream +- **期望:** job_id, body/ref, size, policy, callback_url, headers, created_at 等齐全 + +### TC-Q-07 dev HTTP 回调 + +- **前置:** `require_https_callback=false` +- **步骤:** `http://` callback URL 入队 +- **期望:** 202 + +--- + +## 5. wait(TC-W) + +### TC-W-01 配额内直通 + +- **期望:** 200 上游响应,无入队等待 + +### TC-W-02 超额成功 + +- **步骤:** 超额 wait,worker 正常 +- **期望:** 200;`X-Job-Id`;`X-Queue-Wait-Ms`;LLM body + +### TC-W-03 竞态保险 + +- **前置:** worker 即时完成(0 延迟 upstream) +- **步骤:** enqueue-and-wait +- **期望:** 200(subscribe 前 result 已写入) + +### TC-W-04 超时 504 + +- **前置:** upstream 延迟 > timeout +- **期望:** 504;`error/timeout/job_id/waited_ms/message` + +### TC-W-05 504 后 poll + +- **步骤:** 504 后等待 worker 完成;GET `/jobs/{id}/status` +- **期望:** 200 原始 LLM 响应体 + +### TC-W-06 Pub/Sub 连接复用 smoke + +- **步骤:** 并发 N 个 wait(N=10 smoke) +- **期望:** 全部完成;Redis 连接数无 N 倍 subscriber 连接 + +--- + +## 6. Body 处理(TC-BODY) + +### TC-BODY-01 inline ≤128KB + +- **期望:** storage=inline;Redis entry 含 base64 body + +### TC-BODY-02 S3 卸载 >128KB + +- **前置:** 配置 object_store_endpoint +- **期望:** storage=object;entry 仅 ref;`object_offload_total`+1 + +### TC-BODY-03 无 S3 大 body + +- **期望:** 413 Payload Too Large + +### TC-BODY-04 超 MAX_BODY_BYTES + +- **期望:** 413 + +### TC-BODY-05 S3 与 XADD 并发 + +- **期望:** 入队在合理时间内完成(相对串行基线) + +### TC-BODY-06 multipart 失败 Abort + +- **前置:** mock S3 返回 500 +- **期望:** 入队失败;无成功 XADD + +### TC-BODY-07 body Semaphore + +- **前置:** body_read_concurrency=2(测试配置) +- **步骤:** 3 个并发大 body 入队 +- **期望:** 第三个延迟开始(可选 smoke) + +--- + +## 7. Worker / 可靠性(TC-WK) + +### TC-WK-01 批量并发消费 + +- **步骤:** 一次 XADD 5 条;观察 worker 处理 +- **期望:** 5 条均完成(并发处理) + +### TC-WK-02 XAUTOCLAIM + +- **前置:** reclaim_interval_secs=2(测试);模拟 PEL 未 ACK +- **期望:** 重认领后重新处理 + +### TC-WK-03 回调失败 → retry stream + +- **前置:** callback URL 不可达 +- **期望:** callback_retry_stream 有 entry + +### TC-WK-04 回调 DLQ + +- **前置:** 超过 max retry +- **期望:** callback_dlq_stream 有 entry + +### TC-WK-05 job DLQ + +- **前置:** max_delivery_attempts=1;反复失败 +- **期望:** job_dlq_stream + +### TC-WK-06 result TTL 120s + +- **前置:** result_ttl_secs=2(测试) +- **步骤:** 完成后等待 TTL;poll +- **期望:** 404 not_found + +### TC-WK-07 优先级 Stream + +- **前置:** enable_priority_streams=true;high/normal 均有积压 +- **期望:** high 优先被消费完 + +--- + +## 8. 监控与部署(TC-MET / TC-DEP) + +### TC-MET-01 metrics 基础 + +- **步骤:** GET `/metrics` +- **期望:** 200;含 `queue_depth`、`pel_size` + +### TC-MET-02 rate_limited 标签 + +- **期望:** `rate_limited_total{policy="...",tenant="..."}` 行存在 + +### TC-MET-03 enqueue_latency 分桶 + +- **期望:** `enqueue_latency_ms_bucket{policy,size_bucket,le=...}` 存在 + +### TC-DEP-01 Redis 6 拒绝 + +- **步骤:** 对 Redis 6 启动 service +- **期望:** 启动失败,明确错误信息 + +### TC-DEP-02 Redis 7+ 通过 + +- **期望:** 正常启动 + +--- + +## 9. Wasm 网关层(TC-GW,可选) + +### TC-GW-01 abandon 超额 + +- **期望:** 429 + +### TC-GW-02 queue 超额 + +- **期望:** 202 + +### TC-GW-03 wait 超额 + +- **期望:** 200 或 504(视 upstream 延迟) + +### TC-GW-04 service 不可达 + +- **期望:** 502 + +--- + +## 运行命令 + +```bash +# 单元测试(无需 Redis) +cd spacegate && cargo test -p ai-gateway-service + +# 集成测试(需 Redis 7+) +./spacegate/binary/ai-gateway-service/scripts/run-integration-tests.sh + +# Hurl 黑盒 +./spacegate/binary/ai-gateway-service/scripts/run-hurl-tests.sh + +# MinIO E2E +./spacegate/binary/ai-gateway-service/scripts/queue-object-store-e2e.sh + +# Wasm 策略逻辑(host) +cd spacegate/plugins/wasm/ai-gateway-queue && cargo test --lib +``` + +--- + +## 修订记录 + +| 日期 | 说明 | +|------|------| +| 2026-05-24 | 初版:55 条 TC-* 用例 + traceability | +| 2026-05-24 | 落地 Rust 集成测试 22 项、Hurl 5 文件、脚本 4 个、Wasm policy host 测试 3 项 | + +## 已实现自动化映射 + +| 用例 ID | Rust IT | Hurl | 脚本 | +|---------|---------|------|------| +| TC-HDR-03~05 | body_store / enqueue_queue | queue | | +| TC-RL-01~07 | ratelimit / admin | ratelimit / admin | | +| TC-Q-02~07 | enqueue_queue | queue | | +| TC-W-02~05 | enqueue_wait | wait | | +| TC-BODY-01/03 | body_store | | | +| TC-WK-01/03 | worker_reliability | | | +| TC-MET-01/02, TC-DEP-02 | metrics | metrics | | +| TC-BODY-02 | | | queue-object-store-e2e.sh | +| TC-GW / TC-HDR-02 | policy host tests | | run-gateway-e2e.sh (stub) | diff --git a/docs/k8s/gateway-api-compatibility.md b/docs/k8s/gateway-api-compatibility.md index 0985831b..617e7f66 100644 --- a/docs/k8s/gateway-api-compatibility.md +++ b/docs/k8s/gateway-api-compatibility.md @@ -143,3 +143,30 @@ Fields: - kind - `Gateway` `HTTPRoute` - namespace (option) - name + +### Higress-compatible WasmPlugin + +Spacegate can read Higress-style `extensions.higress.io/v1alpha1` `WasmPlugin` resources and translate them into the internal `code = "wasm"` plugin runtime configuration. + +Supported fields: + +- `spec.url` - local path, `file://`, `http://`, `https://`, or OCI wasm image URL (`oci://`, `docker://`, `image://`). +- `spec.sha256` - optional wasm byte digest, plain hex or `sha256:`. +- `spec.pluginName` - exposed to the proxy-wasm guest as `plugin_name`. +- `spec.defaultConfig` - converted to a Spacegate wasm plugin instance and mounted at Gateway level. +- `spec.defaultConfigDisable` - disables the generated Gateway-level default plugin instance. +- `spec.matchRules[].ingress` - generates per-rule wasm plugin instances and mounts them on matching Spacegate routes. +- `spec.matchRules[].domain` - generates per-rule wasm plugin instances and mounts them on routes whose hostnames match. +- `spec.matchRules[].service` - generates per-rule wasm plugin instances and mounts them on matching backends. +- `spec.matchRules[].config/configDisable` - configures or disables each generated rule-level plugin instance. +- `spec.failStrategy` - accepts `FAIL_OPEN`/`FAIL_CLOSE` and maps to Spacegate `fail_open`/`fail_close`. +- `spec.phase` - participates in plugin ordering (`AUTHN` before `AUTHZ` before unspecified before `STATS`). +- `spec.priority` - used inside the same phase; higher priority plugins are mounted earlier. +- `spec.imagePullPolicy` - `Always` disables the Spacegate module cache for that plugin instance. +- `spec.imagePullSecret` - for OCI URLs, Spacegate reads Docker config (`.dockerconfigjson`/`.dockercfg`) or basic-auth (`username`/`password`) Kubernetes Secrets and passes the registry credentials to the wasm runtime. +- `status` - Spacegate writes `observedGeneration`, `phase`, `digest`, and `message` during K8s watch reconciliation. + +Current limitations: + +- OCI layer selection supports wasm media types (`application/vnd.module.wasm.content.layer.v1+wasm`, `application/vnd.wasm.content.layer.v1+wasm`, `application/wasm`) and falls back to a single-layer artifact. +- `phase` currently maps to ordering only, not to separate Spacegate execution pipelines. diff --git a/docs/otlp/otel-three-signals-guide.md b/docs/otlp/otel-three-signals-guide.md new file mode 100644 index 00000000..7b369eb2 --- /dev/null +++ b/docs/otlp/otel-three-signals-guide.md @@ -0,0 +1,445 @@ +# SpaceGate OTEL 三信号说明 + +本文说明 SpaceGate 当前接入 OpenTelemetry 后,`logs`、`traces`、`metrics` 三类数据分别如何上报、数据结构大致是什么样、以及分别适合哪些审计和监控需求。 + +## 1. 当前链路 + +当前本地验证链路: + +```text +SpaceGate + -> OTLP gRPC + -> OpenTelemetry Collector + -> ClickHouse +``` + +本地配置位置: + +```text +/tmp/spacegate-otel/config/config.json +/tmp/spacegate-otel/otel-collector.yaml +``` + +SpaceGate OTLP endpoint: + +```json +"otlp_endpoint": "http://127.0.0.1:4317", +"protocol": "grpc" +``` + +Collector 接收: + +```yaml +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:4317 + http: + endpoint: 0.0.0.0:4318 +``` + +Collector 写入 ClickHouse: + +```yaml +exporters: + clickhouse: + endpoint: tcp://spacegate-clickhouse:9000?dial_timeout=10s + database: otel + create_schema: true +``` + +## 2. Logs:审计明细 + +### 上报方式 + +SpaceGate 使用 Rust `tracing::info!` 生成结构化日志事件,再通过 OpenTelemetry logs exporter 走 OTLP 推送给 Collector。 + +当前每个请求完成时,网关会生成一条 `info` 级别的 access log: + +```text +event = "http_access" +``` + +插件可以把业务审计字段写入请求级 `TelemetryContext`,请求结束时统一进入 access log 的 `telemetry` 字段。 + +### 数据结构 + +ClickHouse 表: + +```text +otel_logs +``` + +常见字段形态: + +```text +Timestamp +TraceId +SpanId +SeverityText +Body +LogAttributes +ResourceAttributes +``` + +其中 `LogAttributes` 是 key/value map。当前 SpaceGate access log 关键字段: + +```text +LogAttributes['event'] = 'http_access' +LogAttributes['gateway'] = 'local' +LogAttributes['method'] = 'GET' +LogAttributes['path'] = '/' +LogAttributes['host'] = '127.0.0.1:9000' +LogAttributes['authority'] = '127.0.0.1:9000' +LogAttributes['client_ip'] = '127.0.0.1' +LogAttributes['x_forwarded_for'] +LogAttributes['user_agent'] = 'curl/8.7.1' +LogAttributes['downstream_remote_address'] = '127.0.0.1:xxxxx' +LogAttributes['route_name'] = 'local-test' +LogAttributes['upstream_host'] = '127.0.0.1' +LogAttributes['trace_id'] = '4bf92f3577b34da6a3ce929d0e0e4736' +LogAttributes['status_code'] = '200' +LogAttributes['request_id'] = '...' +LogAttributes['peer_addr'] = '127.0.0.1:xxxxx' +LogAttributes['duration_ms'] = '...' +LogAttributes['bytes_received'] +LogAttributes['bytes_sent'] +LogAttributes['request_body_size'] +LogAttributes['response_body_size'] +LogAttributes['telemetry'] = '{"ai.asset_id":"...","ai.total_tokens":"37"}' +``` + +`client_ip` 优先取 `X-Forwarded-For` 的第一个 IP,缺失时退回 TCP peer IP。MAC 地址不是 HTTP 请求语义的一部分,网关在代理、NAT、K8s 场景下无法可靠获得;如果审计确实需要,只能由上游可信组件或插件以业务字段写入 `telemetry`。 + +插件写入的审计字段在 `telemetry` JSON 里。例如: + +```json +{ + "ai.asset_id": "deepseek-chat", + "ai.asset_type": "model", + "ai.prompt_tokens": "24", + "ai.completion_tokens": "13", + "ai.total_tokens": "37", + "mcp.server": "search-service", + "mcp.tool": "web_search", + "auth.app_id": "demo-app" +} +``` + +查询示例: + +```sql +SELECT + Timestamp, + LogAttributes['request_id'] AS request_id, + LogAttributes['path'] AS path, + LogAttributes['status_code'] AS status_code, + JSONExtractString(LogAttributes['telemetry'], 'ai.asset_id') AS asset_id, + toUInt64OrZero(JSONExtractString(LogAttributes['telemetry'], 'ai.total_tokens')) AS total_tokens, + JSONExtractString(LogAttributes['telemetry'], 'mcp.tool') AS mcp_tool +FROM otel_logs +WHERE LogAttributes['event'] = 'http_access' +ORDER BY Timestamp DESC +LIMIT 20; +``` + +### 适合的需求 + +Logs 适合做**审计明细**: + +- 每次接口调用记录 +- 请求状态码、耗时、request_id +- 应用、API Key 摘要、租户信息 +- 大模型 asset_id、token 用量 +- MCP server/tool 调用信息 +- 错误码、失败原因 +- 审计中心按请求维度查询和导出 + +### 不适合的需求 + +Logs 不适合作为高频实时监控聚合的唯一数据源。虽然可以统计,但大量 JSON 提取和明细扫描成本较高。高频监控建议用 metrics。 + +## 3. Traces:调用链路 + +### 上报方式 + +SpaceGate 在请求入口创建 HTTP server span。插件内部使用 `tracing` 打出的事件可以挂到当前 span 上。OpenTelemetry traces exporter 通过 OTLP 推送给 Collector。 + +### 数据结构 + +ClickHouse 表: + +```text +otel_traces +``` + +常见字段形态: + +```text +Timestamp +TraceId +SpanId +ParentSpanId +SpanName +ServiceName +Duration +StatusCode +SpanAttributes +ResourceAttributes +Events +``` + +当前请求 span 示例字段: + +```text +SpanName = 'http.server.request' +SpanAttributes['http.method'] = 'GET' +SpanAttributes['http.path'] = '/' +SpanAttributes['http.host'] = '127.0.0.1:9000' +SpanAttributes['http.protocol'] = 'HTTP/1.1' +SpanAttributes['http.status_code'] = '200' +SpanAttributes['request_id'] = '...' +SpanAttributes['peer_addr'] = '127.0.0.1:xxxxx' +SpanAttributes['duration_ms'] = '...' +``` + +查询示例: + +```sql +SELECT + Timestamp, + TraceId, + SpanId, + ParentSpanId, + SpanName, + Duration, + StatusCode, + SpanAttributes['http.status_code'] AS http_status_code, + SpanAttributes['request_id'] AS request_id +FROM otel_traces +ORDER BY Timestamp DESC +LIMIT 20; +``` + +### 适合的需求 + +Traces 适合做**链路诊断**: + +- 一次请求经过了哪些内部阶段 +- 哪个插件或后端调用耗时高 +- 请求失败时定位失败发生在哪一段 +- 根据 `TraceId` 把 logs 和 spans 串起来 +- 抽样分析慢请求和异常请求 + +### 不适合的需求 + +Traces 不适合作为完整审计账本。生产环境通常会采样,例如 1%、0.1% 或 parent-based sampling。审计要求完整性时,应以 logs 为准。 + +## 4. Metrics:聚合监控 + +### 上报方式 + +SpaceGate 使用 OpenTelemetry metrics SDK 定期导出指标。当前本地配置里有: + +```json +"metrics": { + "enabled": true, + "export_interval_ms": 5000 +} +``` + +这表示每 5 秒导出一次当前指标数据。即使没有新请求,累计型指标也可能周期性写入 ClickHouse,所以 metrics 表行数会持续增加。 + +### 数据结构 + +ClickHouse 表通常包括: + +```text +otel_metrics_sum +otel_metrics_histogram +otel_metrics_gauge +otel_metrics_summary +otel_metrics_exp_histogram +``` + +当前 SpaceGate 请求级指标包括: + +```text +http.server.requests +http.server.errors +http.server.errors.4xx +http.server.errors.5xx +http.server.active_requests +http.server.request.duration +http.server.request.body.size +http.server.response.body.size +``` + +指标属性使用低基数字段: + +```text +gateway +http.request.method +http.response.status_code +network.protocol.name +network.protocol.version +``` + +示例含义: + +```text +http.server.requests + 类型:Counter + 作用:请求总量 + +http.server.request.duration + 类型:Histogram + 单位:s + 作用:请求耗时分布,可计算 P50/P95/P99 + +http.server.errors.5xx + 类型:Counter + 作用:服务端错误数量 + +http.server.active_requests + 类型:UpDownCounter + 作用:当前活跃请求数 +``` + +### 适合的需求 + +Metrics 适合做**监控和告警**: + +- QPS +- 错误率 +- P95/P99 延迟 +- 活跃请求数 +- 请求/响应大小分布 +- 4xx/5xx 趋势 +- 容量规划 +- SLO/SLA 面板 + +### 不适合的需求 + +Metrics 不适合做逐请求审计: + +- 不包含完整 request_id +- 不应该带 api_key、user_id、asset_id 这类高基数字段 +- 不记录每次请求的完整业务明细 +- 周期性导出会产生重复时间序列点,不能用行数代表请求数 + +## 5. 三者对比 + +| 信号 | 粒度 | 数据完整性 | 成本 | 主要用途 | ClickHouse 表 | +| --- | --- | --- | --- | --- | --- | +| Logs | 单请求明细 | 高 | 中到高 | 审计、账单、问题回溯 | `otel_logs` | +| Traces | 调用链路 | 取决于采样 | 中 | 慢请求诊断、链路分析 | `otel_traces` | +| Metrics | 聚合数据 | 聚合后数据 | 低到中 | 监控、告警、趋势 | `otel_metrics_*` | + +## 6. 审计中心推荐使用方式 + +审计中心建议以 logs 为主: + +```text +otel_logs + WHERE LogAttributes['event'] = 'http_access' +``` + +核心查询字段: + +```text +Timestamp +LogAttributes['request_id'] +LogAttributes['gateway'] +LogAttributes['method'] +LogAttributes['path'] +LogAttributes['status_code'] +LogAttributes['duration_ms'] +LogAttributes['telemetry'] +``` + +插件业务字段从 `telemetry` JSON 里解析: + +```sql +JSONExtractString(LogAttributes['telemetry'], 'ai.asset_id') +JSONExtractString(LogAttributes['telemetry'], 'ai.total_tokens') +JSONExtractString(LogAttributes['telemetry'], 'mcp.tool') +``` + +如果审计中心需要高频统计,例如按模型统计 token: + +```sql +SELECT + JSONExtractString(LogAttributes['telemetry'], 'ai.asset_id') AS asset_id, + sum(toUInt64OrZero(JSONExtractString(LogAttributes['telemetry'], 'ai.total_tokens'))) AS total_tokens +FROM otel_logs +WHERE LogAttributes['event'] = 'http_access' +GROUP BY asset_id; +``` + +生产上建议对常用字段建 ClickHouse 物化视图,把 JSON 字段抽成列,提升查询性能。 + +## 7. 监控系统推荐使用方式 + +监控面板建议使用 metrics: + +- `http.server.requests` 计算请求量 +- `http.server.errors` / `http.server.requests` 计算错误率 +- `http.server.request.duration` 计算延迟分位数 +- `http.server.active_requests` 观察并发压力 + +本地测试阶段如果只验证审计入库,可以关闭 metrics: + +```json +"metrics": { + "enabled": false, + "export_interval_ms": 60000 +} +``` + +生产建议使用较低频率: + +```json +"metrics": { + "enabled": true, + "export_interval_ms": 30000 +} +``` + +如果 metrics 长期写 ClickHouse,建议配置 TTL 或单独存入更适合时序数据的系统。 + +## 8. 推荐配置策略 + +### 本地审计验证 + +```text +logs: enabled +traces: enabled, sample_ratio = 1.0 +metrics: disabled +``` + +### 生产审计 + +```text +logs: enabled +traces: enabled, parent-based sampling +metrics: enabled, 30s 或 60s interval +``` + +### 生产监控 + +```text +metrics: enabled +logs: only access/audit logs +traces: sampling +``` + +## 9. 总结 + +- **Logs 是审计主数据**:每个请求一条 `http_access`,插件审计字段在 `telemetry` JSON 中。 +- **Traces 是诊断数据**:用 `TraceId` 追踪一次请求的链路和耗时。 +- **Metrics 是监控数据**:周期性聚合导出,用于 QPS、错误率、延迟、告警。 +- 不要把业务审计字段作为 metrics label。 +- 不要把 traces 当完整审计账本。 +- 审计中心主要查 `otel_logs`,监控系统主要用 `otel_metrics_*`。 diff --git a/docs/otlp/telemetry-pluginized-audit-plan.md b/docs/otlp/telemetry-pluginized-audit-plan.md new file mode 100644 index 00000000..49129614 --- /dev/null +++ b/docs/otlp/telemetry-pluginized-audit-plan.md @@ -0,0 +1,615 @@ +# Telemetry Pluginized Audit Fields Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** 将请求级业务审计字段改造成真正插件化的 `TelemetryContext`:插件按命名空间写入字段,网关只负责统一携带、校验、序列化到 access log,并通过 OTLP logs 入库。 + +**Architecture:** `kernel` 提供请求级 `TelemetryContext` 和字段校验/序列化能力;`plugin` 提供原生插件写入 API;`service.rs` 只输出通用 access log 字段和一个 JSON 字符串 `telemetry`,不再硬编码 AI/MCP/token 等业务字段。ClickHouse 查询侧从 `LogAttributes['telemetry']` JSON 解析插件自定义字段。 + +**Tech Stack:** Rust, hyper request extensions, tracing structured logs, OpenTelemetry logs, ClickHouse `Map`/JSON extraction. + +--- + +## 设计补充与风险点 + +当前方向基本正确,但还需要明确这些边界: + +- **命名空间冲突**:A/B 插件不能直接共用 `total_tokens` 这类裸 key,推荐 `ai.total_tokens`、`mcp.tool`、`auth.app_id`。 +- **保留前缀**:禁止插件写 `http.*`、`net.*`、`gateway.*`、`spacegate.*`、`otel.*`,避免和网关/OTEL 主字段混淆。 +- **字段结构**:`TelemetryContext` 保持扁平 `key/value`,不接受嵌套 JSON 对象,避免合并语义和查询复杂度失控。 +- **字段长度**:限制 key/value 大小,防止插件误写完整 prompt、response body 或超大错误堆栈。 +- **覆盖策略**:同 key 后写覆盖前写;这是同命名空间内插件自己的责任。跨插件通过 namespace 避免冲突。 +- **敏感信息**:不建议写完整 `api_key`,推荐写 `api_key_hash` 或脱敏值。 +- **metrics 边界**:业务审计字段只进入 logs/traces,不作为 metrics label,避免高基数。 +- **ClickHouse 性能**:低频查询可直接 `JSONExtract*`;高频统计建议建物化视图抽取常用字段。 +- **WASM 对齐**:WASM host function 也必须遵守同一套 key 校验、namespace 和长度限制。 + +## File Structure + +- Modify: `crates/kernel/src/observability.rs` + - 定义 telemetry 字段校验规则。 + - 提供 `TelemetryError`。 + - 提供 `TelemetryContext::insert_checked`。 + - 提供 `TelemetryContext::insert_namespaced`。 + - 提供 `telemetry_json`. + +- Modify: `crates/kernel/src/service.rs` + - 删除硬编码 `telemetry.asset_id`、`telemetry.total_tokens` 等业务字段。 + - access log 只输出一个 `telemetry` JSON 字符串。 + +- Modify: `crates/plugin/src/lib.rs` + - 保留 `set_telemetry_field`,内部走 checked insert。 + - 新增推荐 API `set_plugin_telemetry_field(req, namespace, key, value)`。 + - 返回 `Result<(), TelemetryError>`,让插件可感知字段被拒绝。 + +- Modify: `crates/plugin/tests/test_telemetry.rs` + - 覆盖原生插件 API、命名空间 API、非法 key、保留前缀。 + +- Modify: `scripts/otel-local/query-access-logs.sh` + - 从 `LogAttributes['telemetry']` JSON 中解析字段,不再查询 `LogAttributes['telemetry.asset_id']`。 + +- Modify: `docs/wasm-telemetry-host-function-plan.md` + - 对齐本计划中的校验规则和 telemetry JSON 入库形态。 + +--- + +### Task 1: Kernel Telemetry Validation + +**Files:** +- Modify: `crates/kernel/src/observability.rs` + +- [ ] **Step 1: Write failing tests for validation** + +Add these tests inside `#[cfg(test)] mod tests` in `crates/kernel/src/observability.rs`: + +```rust +#[test] +fn telemetry_key_validation_accepts_namespaced_keys() { + assert!(super::validate_telemetry_key("ai.total_tokens").is_ok()); + assert!(super::validate_telemetry_key("mcp.tool-name").is_ok()); + assert!(super::validate_telemetry_key("auth.api_key_hash").is_ok()); +} + +#[test] +fn telemetry_key_validation_rejects_bad_keys() { + assert_eq!(super::validate_telemetry_key(""), Err(super::TelemetryError::EmptyKey)); + assert_eq!(super::validate_telemetry_key("total_tokens"), Err(super::TelemetryError::MissingNamespace)); + assert_eq!(super::validate_telemetry_key("ai total_tokens"), Err(super::TelemetryError::InvalidKey)); + assert_eq!(super::validate_telemetry_key("http.status_code"), Err(super::TelemetryError::ReservedPrefix)); + assert_eq!(super::validate_telemetry_key("spacegate.internal"), Err(super::TelemetryError::ReservedPrefix)); +} + +#[test] +fn telemetry_value_validation_rejects_oversized_values() { + let value = "x".repeat(super::MAX_TELEMETRY_VALUE_LEN + 1); + assert_eq!(super::validate_telemetry_value(&value), Err(super::TelemetryError::ValueTooLong)); +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: + +```bash +cargo test -p spacegate-kernel telemetry_key_validation 2>&1 | head -c 12000 +``` + +Expected: FAIL because `validate_telemetry_key`, `validate_telemetry_value`, `TelemetryError`, or `MAX_TELEMETRY_VALUE_LEN` are missing. + +- [ ] **Step 3: Implement validation** + +Add near `TelemetryContext` in `crates/kernel/src/observability.rs`: + +```rust +pub const MAX_TELEMETRY_KEY_LEN: usize = 128; +pub const MAX_TELEMETRY_VALUE_LEN: usize = 4096; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TelemetryError { + EmptyKey, + MissingNamespace, + ReservedPrefix, + InvalidKey, + KeyTooLong, + ValueTooLong, +} + +pub fn validate_telemetry_key(key: &str) -> Result<(), TelemetryError> { + if key.is_empty() { + return Err(TelemetryError::EmptyKey); + } + if key.len() > MAX_TELEMETRY_KEY_LEN { + return Err(TelemetryError::KeyTooLong); + } + if !key.contains('.') { + return Err(TelemetryError::MissingNamespace); + } + if ["http.", "net.", "gateway.", "spacegate.", "otel."].iter().any(|prefix| key.starts_with(prefix)) { + return Err(TelemetryError::ReservedPrefix); + } + if !key.bytes().all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-')) { + return Err(TelemetryError::InvalidKey); + } + Ok(()) +} + +pub fn validate_telemetry_value(value: &str) -> Result<(), TelemetryError> { + if value.len() > MAX_TELEMETRY_VALUE_LEN { + return Err(TelemetryError::ValueTooLong); + } + Ok(()) +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: + +```bash +cargo test -p spacegate-kernel telemetry_key_validation 2>&1 | head -c 12000 +cargo test -p spacegate-kernel telemetry_value_validation 2>&1 | head -c 12000 +``` + +Expected: PASS. + +--- + +### Task 2: Checked TelemetryContext API + +**Files:** +- Modify: `crates/kernel/src/observability.rs` + +- [ ] **Step 1: Write failing tests for checked insertion** + +Add tests: + +```rust +#[test] +fn telemetry_context_checked_insert_rejects_invalid_key_without_mutating_context() { + let context = super::TelemetryContext::default(); + + let result = context.insert_checked("total_tokens", "37"); + + assert_eq!(result, Err(super::TelemetryError::MissingNamespace)); + assert!(context.snapshot().is_empty()); +} + +#[test] +fn telemetry_context_namespaced_insert_builds_stable_key() { + let context = super::TelemetryContext::default(); + + context.insert_namespaced("ai", "total_tokens", 37).expect("insert"); + + let fields = context.snapshot(); + assert_eq!(fields.get("ai.total_tokens").map(String::as_str), Some("37")); +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: + +```bash +cargo test -p spacegate-kernel telemetry_context_checked_insert 2>&1 | head -c 12000 +cargo test -p spacegate-kernel telemetry_context_namespaced_insert 2>&1 | head -c 12000 +``` + +Expected: FAIL because methods are missing. + +- [ ] **Step 3: Implement checked APIs** + +Replace or extend `impl TelemetryContext` with: + +```rust +impl TelemetryContext { + pub fn insert(&self, key: impl Into, value: impl Into) { + let Ok(mut fields) = self.fields.lock() else { + return; + }; + fields.insert(key.into(), value.into()); + } + + pub fn insert_checked(&self, key: impl Into, value: impl ToString) -> Result<(), TelemetryError> { + let key = key.into(); + let value = value.to_string(); + validate_telemetry_key(&key)?; + validate_telemetry_value(&value)?; + let Ok(mut fields) = self.fields.lock() else { + return Ok(()); + }; + fields.insert(key, value); + Ok(()) + } + + pub fn insert_namespaced(&self, namespace: &str, key: &str, value: impl ToString) -> Result<(), TelemetryError> { + self.insert_checked(format!("{namespace}.{key}"), value) + } + + pub fn snapshot(&self) -> BTreeMap { + self.fields.lock().map(|fields| fields.clone()).unwrap_or_default() + } + + pub fn is_empty(&self) -> bool { + self.fields.lock().map(|fields| fields.is_empty()).unwrap_or(true) + } +} +``` + +- [ ] **Step 4: Run tests** + +Run: + +```bash +cargo test -p spacegate-kernel telemetry_context_ 2>&1 | head -c 12000 +``` + +Expected: PASS. + +--- + +### Task 3: Access Log Uses Generic Telemetry JSON + +**Files:** +- Modify: `crates/kernel/src/observability.rs` +- Modify: `crates/kernel/src/service.rs` + +- [ ] **Step 1: Write failing JSON serialization test** + +Add test: + +```rust +#[test] +fn telemetry_json_serializes_plugin_defined_fields() { + let fields = BTreeMap::from([ + ("ai.asset_id".to_string(), "deepseek-chat".to_string()), + ("ai.total_tokens".to_string(), "37".to_string()), + ("mcp.tool".to_string(), "search".to_string()), + ]); + + let json = super::telemetry_json(&fields); + + assert!(json.contains("\"ai.asset_id\":\"deepseek-chat\"")); + assert!(json.contains("\"ai.total_tokens\":\"37\"")); + assert!(json.contains("\"mcp.tool\":\"search\"")); +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: + +```bash +cargo test -p spacegate-kernel telemetry_json_serializes_plugin_defined_fields 2>&1 | head -c 12000 +``` + +Expected: FAIL because `telemetry_json` is missing. + +- [ ] **Step 3: Add serde_json dependency if needed** + +Check `crates/kernel/Cargo.toml`. If `serde_json` is not already present, add: + +```toml +serde_json = { workspace = true } +``` + +- [ ] **Step 4: Implement telemetry_json** + +Add to `crates/kernel/src/observability.rs`: + +```rust +pub fn telemetry_json(fields: &BTreeMap) -> String { + serde_json::to_string(fields).unwrap_or_else(|_| "{}".to_string()) +} +``` + +- [ ] **Step 5: Refactor service.rs access log** + +In `crates/kernel/src/service.rs`: + +1. Replace import of `telemetry_log_attributes` with `telemetry_json`. +2. Delete all hardcoded `telemetry.asset_id`, `telemetry.total_tokens`, `telemetry.mcp_tool`, etc. +3. Emit only: + +```rust +let telemetry = telemetry_json(&access_log.telemetry); +tracing::info!( + event = "http_access", + gateway = %access_log.gateway, + method = %access_log.method, + path = %access_log.path, + host = %access_log.host, + protocol_name = %access_log.protocol_name, + protocol_version = %access_log.protocol_version, + status_code = access_log.status_code, + request_id = %access_log.request_id, + peer_addr = %access_log.peer_addr, + duration_ms = access_log.duration_ms, + request_body_size = ?access_log.request_body_size, + response_body_size = ?access_log.response_body_size, + telemetry = %telemetry, + "http access log" +); +``` + +- [ ] **Step 6: Run tests** + +Run: + +```bash +cargo test -p spacegate-kernel observability::tests 2>&1 | head -c 12000 +``` + +Expected: PASS. + +--- + +### Task 4: Plugin API Becomes Namespaced and Checked + +**Files:** +- Modify: `crates/plugin/src/lib.rs` +- Modify: `crates/plugin/tests/test_telemetry.rs` + +- [ ] **Step 1: Update plugin tests** + +Replace `crates/plugin/tests/test_telemetry.rs` content with: + +```rust +use spacegate_plugin::{set_plugin_telemetry_field, set_telemetry_field, SgBody}; + +fn request_with_telemetry() -> hyper::Request { + let mut req = hyper::Request::builder().body(SgBody::empty()).expect("request"); + req.extensions_mut().insert(spacegate_kernel::observability::TelemetryContext::default()); + req +} + +#[test] +fn set_telemetry_field_writes_checked_request_context() { + let req = request_with_telemetry(); + + set_telemetry_field(&req, "ai.asset_id", "deepseek-chat").expect("insert"); + set_telemetry_field(&req, "ai.total_tokens", 37).expect("insert"); + + let fields = req.extensions().get::().expect("telemetry context").snapshot(); + assert_eq!(fields.get("ai.asset_id").map(String::as_str), Some("deepseek-chat")); + assert_eq!(fields.get("ai.total_tokens").map(String::as_str), Some("37")); +} + +#[test] +fn set_plugin_telemetry_field_adds_namespace() { + let req = request_with_telemetry(); + + set_plugin_telemetry_field(&req, "mcp", "tool", "search").expect("insert"); + + let fields = req.extensions().get::().expect("telemetry context").snapshot(); + assert_eq!(fields.get("mcp.tool").map(String::as_str), Some("search")); +} + +#[test] +fn set_telemetry_field_rejects_unqualified_key() { + let req = request_with_telemetry(); + + let result = set_telemetry_field(&req, "total_tokens", 37); + + assert_eq!(result, Err(spacegate_kernel::observability::TelemetryError::MissingNamespace)); +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: + +```bash +cargo test -p spacegate-plugin test_telemetry 2>&1 | head -c 12000 +``` + +Expected: FAIL because `set_telemetry_field` currently returns `()` and `set_plugin_telemetry_field` is missing. + +- [ ] **Step 3: Implement plugin APIs** + +In `crates/plugin/src/lib.rs`, replace current `set_telemetry_field` with: + +```rust +pub fn set_telemetry_field( + req: &SgRequest, + key: impl Into, + value: impl ToString, +) -> Result<(), spacegate_kernel::observability::TelemetryError> { + if let Some(context) = req.extensions().get::() { + context.insert_checked(key, value)?; + } + Ok(()) +} + +pub fn set_plugin_telemetry_field( + req: &SgRequest, + namespace: &str, + key: &str, + value: impl ToString, +) -> Result<(), spacegate_kernel::observability::TelemetryError> { + if let Some(context) = req.extensions().get::() { + context.insert_namespaced(namespace, key, value)?; + } + Ok(()) +} +``` + +- [ ] **Step 4: Run tests** + +Run: + +```bash +cargo test -p spacegate-plugin test_telemetry 2>&1 | head -c 12000 +``` + +Expected: PASS. + +--- + +### Task 5: Update ClickHouse Query Script + +**Files:** +- Modify: `scripts/otel-local/query-access-logs.sh` + +- [ ] **Step 1: Update SQL to parse telemetry JSON** + +Replace selected telemetry fields with: + +```sql +JSONExtractString(LogAttributes['telemetry'], 'ai.asset_id') AS ai_asset_id, +JSONExtractString(LogAttributes['telemetry'], 'ai.asset_type') AS ai_asset_type, +JSONExtractString(LogAttributes['telemetry'], 'ai.total_tokens') AS ai_total_tokens, +JSONExtractString(LogAttributes['telemetry'], 'mcp.server') AS mcp_server, +JSONExtractString(LogAttributes['telemetry'], 'mcp.tool') AS mcp_tool, +JSONExtractString(LogAttributes['telemetry'], 'mcp.success') AS mcp_success, +JSONExtractString(LogAttributes['telemetry'], 'auth.app_id') AS auth_app_id, +JSONExtractString(LogAttributes['telemetry'], 'auth.api_key_hash') AS auth_api_key_hash +``` + +Keep base fields: + +```sql +Timestamp, +Body, +SeverityText, +LogAttributes['event'] AS event, +LogAttributes['gateway'] AS gateway, +LogAttributes['method'] AS method, +LogAttributes['path'] AS path, +LogAttributes['status_code'] AS status_code, +LogAttributes['request_id'] AS request_id, +LogAttributes['duration_ms'] AS duration_ms, +LogAttributes['telemetry'] AS telemetry +``` + +- [ ] **Step 2: Validate shell syntax** + +Run: + +```bash +bash -n scripts/otel-local/query-access-logs.sh +``` + +Expected: no output and exit code 0. + +--- + +### Task 6: Update WASM Plan + +**Files:** +- Modify: `docs/wasm-telemetry-host-function-plan.md` + +- [ ] **Step 1: Align ABI plan with namespaced telemetry** + +Update the plan so `spacegate_set_telemetry_field` requires a fully qualified key: + +```text +ai.total_tokens +mcp.tool +auth.app_id +``` + +And add optional convenience SDK wrapper: + +```rust +pub fn set_plugin_telemetry_field(namespace: &str, key: &str, value: impl ToString) -> Result<(), Status> { + set_telemetry_field(&format!("{namespace}.{key}"), value) +} +``` + +- [ ] **Step 2: Align validation section** + +Document the same rules: + +- key max 128 bytes +- value max 4096 bytes +- key must contain `.` +- allowed key chars: `[A-Za-z0-9_.-]` +- reserved prefixes rejected: `http.`, `net.`, `gateway.`, `spacegate.`, `otel.` + +--- + +### Task 7: Full Verification + +**Files:** +- No code changes. + +- [ ] **Step 1: Format touched Rust files** + +Run: + +```bash +rustfmt --edition 2021 crates/kernel/src/observability.rs crates/kernel/src/service.rs crates/plugin/src/lib.rs crates/plugin/tests/test_telemetry.rs +``` + +Expected: no output and exit code 0. + +- [ ] **Step 2: Run targeted tests** + +Run: + +```bash +cargo test -p spacegate-kernel observability::tests 2>&1 | head -c 12000 +cargo test -p spacegate-plugin test_telemetry 2>&1 | head -c 12000 +``` + +Expected: PASS. + +- [ ] **Step 3: Run integration compile check** + +Run: + +```bash +cargo check -p spacegate-shell --features fs,plugin-wasm 2>&1 | head -c 12000 +``` + +Expected: PASS. Existing unrelated warnings may remain. + +- [ ] **Step 4: Validate local scripts** + +Run: + +```bash +for f in scripts/otel-local/*.sh; do bash -n "$f" || exit 1; done +``` + +Expected: no output and exit code 0. + +--- + +## Expected Result + +插件写: + +```rust +set_plugin_telemetry_field(&req, "ai", "asset_id", "deepseek-chat")?; +set_plugin_telemetry_field(&req, "ai", "total_tokens", 37)?; +set_plugin_telemetry_field(&req, "mcp", "tool", "search")?; +``` + +access log 入库: + +```text +LogAttributes['event'] = 'http_access' +LogAttributes['telemetry'] = '{"ai.asset_id":"deepseek-chat","ai.total_tokens":"37","mcp.tool":"search"}' +``` + +审计查询: + +```sql +SELECT + Timestamp, + LogAttributes['request_id'] AS request_id, + JSONExtractString(LogAttributes['telemetry'], 'ai.asset_id') AS asset_id, + toUInt64OrZero(JSONExtractString(LogAttributes['telemetry'], 'ai.total_tokens')) AS total_tokens, + JSONExtractString(LogAttributes['telemetry'], 'mcp.tool') AS mcp_tool +FROM otel_logs +WHERE LogAttributes['event'] = 'http_access' +ORDER BY Timestamp DESC; +``` + +## Self-Review + +- Spec coverage: covers namespace, validation, generic JSON access log, plugin API, ClickHouse query, WASM plan alignment. +- Placeholder scan: no TBD/TODO placeholders. +- Type consistency: `TelemetryError` lives in kernel and is returned by plugin APIs; `TelemetryContext` remains request extension. +- Boundary check: no AI/MCP/token business semantics remain in `service.rs`; those appear only in docs/scripts as examples. diff --git a/docs/otlp/wasm-telemetry-host-function-plan.md b/docs/otlp/wasm-telemetry-host-function-plan.md new file mode 100644 index 00000000..d831b902 --- /dev/null +++ b/docs/otlp/wasm-telemetry-host-function-plan.md @@ -0,0 +1,251 @@ +# WASM 插件审计字段 Host Function 技术方案 + +## 目标 + +让 WASM 插件也能像原生插件一样写入请求级业务审计字段,例如: + +- `ai.asset_id` +- `ai.asset_type` +- `ai.prompt_tokens` +- `ai.completion_tokens` +- `ai.total_tokens` +- `auth.app_id` +- `auth.api_key_hash` +- `mcp.server` +- `mcp.tool` +- `mcp.success` +- `error.code` + +这些字段最终随请求结束时的 `http_access` 日志进入 OTLP logs,再由 Collector 写入 ClickHouse 的 `otel_logs`。 + +## 当前原生插件链路 + +原生插件调用: + +```rust +spacegate_plugin::set_plugin_telemetry_field(&req, "ai", "asset_id", "deepseek-chat")?; +spacegate_plugin::set_plugin_telemetry_field(&req, "ai", "total_tokens", 37)?; +``` + +数据流: + +```text +SgRequest.extensions.TelemetryContext + -> kernel 请求结束生成 http_access 日志 + -> telemetry JSON log attribute + -> OTLP logs + -> Collector + -> ClickHouse otel_logs +``` + +## 推荐 WASM ABI + +新增非 proxy-wasm 标准的 SpaceGate 扩展 host function: + +```text +env.spacegate_set_telemetry_field(key_ptr, key_len, value_ptr, value_len) -> status +``` + +参数: + +- `key_ptr: i32` +- `key_len: i32` +- `value_ptr: i32` +- `value_len: i32` + +返回: + +- `Status::Ok` +- `Status::BadArgument` +- `Status::InvalidMemoryAccess` +- `Status::NotFound` + +命名选择: + +- 不复用 `proxy_call_foreign_function`,避免把核心审计能力塞进不透明 FFI。 +- 使用 `spacegate_` 前缀,明确这是 SpaceGate 扩展,不污染 proxy-wasm 标准 ABI。 + +## Host 侧实现 + +### 1. HostState 增加请求级 telemetry 存储 + +在 `crates/plugin-wasm/src/host_state.rs` 的 `RequestContext` 增加: + +```rust +pub telemetry_fields: BTreeMap, +``` + +原因: + +- WASM `Vm::process` 目前会把 `SgRequest` 拆成 `parts/body`,再重建 `new_req` 给 `inner.call`。 +- host fn 执行期间拿不到原始 `SgRequest` 引用。 +- 因此 WASM 调 host fn 时先写到当前 `RequestContext`,请求结束或调用 inner 前再同步到 `SgRequest.extensions.TelemetryContext`。 + +### 2. 注册 host function + +在 `crates/plugin-wasm/src/host_fn.rs` 增加: + +```rust +fn register_spacegate_telemetry(linker: &mut Linker) -> Result<(), wasmtime::Error> +``` + +并在 `register_all` 中调用。 + +处理逻辑: + +1. 用 `MemoryHelper::from_caller` 读取 guest memory。 +2. 读取 `key` 和 `value` 字符串。 +3. 校验: + - key 非空 + - key 最大 128 字节 + - value 最大 4096 字节 + - key 必须包含命名空间分隔符 `.` + - key 只能包含 `[A-Za-z0-9_.-]` + - 禁止保留前缀:`http.`、`net.`、`gateway.`、`spacegate.`、`otel.` +4. 获取 `caller.data_mut().current_context_mut()`。 +5. 写入 `ctx.telemetry_fields.insert(key, value)`。 +6. 返回 `Status::Ok`。 + +### 3. 同步到 SgRequest + +在 `crates/plugin-wasm/src/vm.rs` 的 `Vm::process` 中: + +- 重建 `new_req` 后、调用 `inner.call(new_req).await` 前: + +```rust +if let Some(kernel_ctx) = new_req.extensions().get::() { + for (key, value) in wasm_ctx.telemetry_fields { + kernel_ctx.insert_checked(key, value)?; + } +} +``` + +注意: + +- 需要把 `let new_req = ...` 改成 `let mut new_req = ...` 或在构造前保留 extensions。 +- 当前 request parts 来自原始 `SgRequest`,extensions 会保留,所以 kernel 插入的 `TelemetryContext` 可以继续存在。 + +### 4. 本地响应短路场景 + +如果 WASM 在 request 阶段通过 `proxy_send_local_response` 直接返回,不会调用 `inner.call`。 + +这种情况下也需要把 telemetry 同步回 access log: + +- 方案 A:短路前直接从原始 request extensions 同步。 +- 方案 B:在 `Vm::process` 开始时把 `TelemetryContext` clone 存进 `HostState` 或当前 `RequestContext`。 + +推荐方案 B: + +```rust +RequestContext { + telemetry_sink: Option, +} +``` + +在 `Vm::process` 开始时: + +```rust +let telemetry_sink = parts.extensions.get::().cloned(); +ctx.telemetry_sink = telemetry_sink; +``` + +host fn 写字段时: + +```rust +if let Some(sink) = &ctx.telemetry_sink { + sink.insert_checked(key.clone(), value.clone())?; +} +ctx.telemetry_fields.insert(key, value); +``` + +这样即使本地响应短路,kernel 请求结束时也能读到审计字段。 + +## Guest SDK 封装 + +WASM 插件侧建议提供一个薄封装: + +```rust +#[link(wasm_import_module = "env")] +extern "C" { + fn spacegate_set_telemetry_field( + key_ptr: i32, + key_len: i32, + value_ptr: i32, + value_len: i32, + ) -> i32; +} + +pub fn set_telemetry_field(key: &str, value: impl ToString) -> Result<(), Status> { + let value = value.to_string(); + let status = unsafe { + spacegate_set_telemetry_field( + key.as_ptr() as i32, + key.len() as i32, + value.as_ptr() as i32, + value.len() as i32, + ) + }; + Status::from_i32(status) +} + +pub fn set_plugin_telemetry_field(namespace: &str, key: &str, value: impl ToString) -> Result<(), Status> { + set_telemetry_field(&format!("{namespace}.{key}"), value) +} +``` + +插件使用: + +```rust +set_plugin_telemetry_field("ai", "asset_id", "deepseek-chat")?; +set_plugin_telemetry_field("ai", "prompt_tokens", 24)?; +set_plugin_telemetry_field("ai", "completion_tokens", 13)?; +set_plugin_telemetry_field("ai", "total_tokens", 37)?; +set_plugin_telemetry_field("mcp", "tool", "search")?; +``` + +## 测试计划 + +### 单元测试 + +- `host_fn` 能读取 guest memory 中的 key/value。 +- 非法 key 返回 `BadArgument`。 +- 空 key 返回 `BadArgument`。 +- 超长 value 返回 `BadArgument`。 +- 无当前 HTTP context 返回 `NotFound`。 + +### WASM 集成测试 + +新增一个测试 wasm: + +- 在 `proxy_on_request_headers` 写 `ai.asset_id`。 +- 在 `proxy_on_response_body` 写 token 字段。 +- 请求结束后断言 `TelemetryContext.snapshot()` 包含这些字段。 + +### 端到端测试 + +本地脚本启动: + +```bash +scripts/otel-local/start-clickhouse.sh +scripts/otel-local/start-collector.sh +scripts/otel-local/start-mock-ac.sh +scripts/otel-local/start-spacegate.sh +scripts/otel-local/request.sh +scripts/otel-local/query-access-logs.sh +``` + +确认 ClickHouse `otel_logs` 中包含: + +```text +JSONExtractString(LogAttributes['telemetry'], 'ai.asset_id') +JSONExtractString(LogAttributes['telemetry'], 'ai.total_tokens') +JSONExtractString(LogAttributes['telemetry'], 'mcp.tool') +``` + +## 风险与边界 + +- 这是 SpaceGate 扩展 ABI,不是 proxy-wasm 标准函数。 +- 字段 key 必须限制字符集和长度,避免 ClickHouse 查询侧难以治理。 +- 不建议把 `request_id`、用户 ID、完整 prompt、完整 response body 写入 telemetry 字段。 +- token、MCP、模型 ID 属于审计日志字段,不应作为 metrics label。 +- WASM 当前单 VM 串行处理请求,字段必须存放在 `RequestContext`,不能放全局 map。 diff --git a/examples/wasm-hello/Cargo.toml b/examples/wasm-hello/Cargo.toml new file mode 100644 index 00000000..58a93142 --- /dev/null +++ b/examples/wasm-hello/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "spacegate_wasm_hello" +version = "0.0.0" +edition = "2021" +publish = false +description = "Minimal Proxy-Wasm hello world plugin for Spacegate." + +# Standalone workspace: this crate builds for wasm32-wasip1 and is not part of +# the host Spacegate workspace. +[workspace] + +[lib] +crate-type = ["cdylib"] + +[dependencies] +proxy-wasm = "0.2" + +[profile.release] +codegen-units = 1 +opt-level = "z" +lto = "fat" +strip = true +panic = "abort" diff --git a/examples/wasm-hello/README.md b/examples/wasm-hello/README.md new file mode 100644 index 00000000..67312486 --- /dev/null +++ b/examples/wasm-hello/README.md @@ -0,0 +1,49 @@ +# Spacegate Wasm Hello World + +This is a minimal Proxy-Wasm guest plugin for Spacegate. + +Build the wasm: + +```bash +cd examples/wasm-hello +cargo build --release --target wasm32-wasip1 +cd ../.. +cp examples/wasm-hello/target/wasm32-wasip1/release/spacegate_wasm_hello.wasm resource/wasm/spacegate_wasm_hello.wasm +``` + +If you rebuild the wasm, update `resource/wasm-hello-demo/plugin/wasm.hello-world.json` +with the new digest: + +```bash +shasum -a 256 resource/wasm/spacegate_wasm_hello.wasm +``` + +The wasm host also supports remote loading: + +```json +{ + "url": "https://example.com/plugins/spacegate_wasm_hello.wasm", + "sha256": "sha256:<64-char-hex-digest>", + "module_cache_key": "spacegate-wasm-hello:v1", + "use_cache": true +} +``` + +Run Spacegate with the demo config from the repository root: + +```bash +RUST_LOG=info cargo run -p spacegate --features wasm -- -c file:resource/wasm-hello-demo +``` + +On startup, Spacegate should log: + +```text +hello world from spacegate wasm plugin +hello world wasm plugin configured +``` + +The demo route also lets the plugin return a direct response: + +```bash +curl http://127.0.0.1:18082/hello-world +``` diff --git a/examples/wasm-hello/src/lib.rs b/examples/wasm-hello/src/lib.rs new file mode 100644 index 00000000..bbdda714 --- /dev/null +++ b/examples/wasm-hello/src/lib.rs @@ -0,0 +1,52 @@ +use proxy_wasm::hostcalls; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +const HELLO: &str = "hello world from spacegate wasm plugin"; + +proxy_wasm::main! {{ + proxy_wasm::set_log_level(LogLevel::Info); + proxy_wasm::set_root_context(|_| -> Box { Box::new(HelloRoot) }); +}} + +struct HelloRoot; + +impl Context for HelloRoot {} + +impl RootContext for HelloRoot { + fn on_vm_start(&mut self, _: usize) -> bool { + let _ = hostcalls::log(LogLevel::Info, HELLO); + true + } + + fn on_configure(&mut self, _: usize) -> bool { + let _ = hostcalls::log(LogLevel::Info, "hello world wasm plugin configured"); + true + } + + fn create_http_context(&self, _: u32) -> Option> { + Some(Box::new(HelloHttp)) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } +} + +struct HelloHttp; + +impl Context for HelloHttp {} + +impl HttpContext for HelloHttp { + fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { + let _ = hostcalls::log(LogLevel::Info, "hello world request reached wasm plugin"); + self.add_http_request_header("x-wasm-hello", "hello-world"); + + if self.get_http_request_header(":path").as_deref() == Some("/hello-world") { + self.send_http_response(200, vec![("content-type", "text/plain"), ("x-powered-by", "spacegate-wasm")], Some(b"hello world\n")); + return Action::Pause; + } + + Action::Continue + } +} diff --git a/plugins/wasm/.cargo/config.toml b/plugins/wasm/.cargo/config.toml new file mode 100644 index 00000000..9b923aa9 --- /dev/null +++ b/plugins/wasm/.cargo/config.toml @@ -0,0 +1,3 @@ +[build] +target = "wasm32-wasip1" + diff --git a/plugins/wasm/.gitignore b/plugins/wasm/.gitignore new file mode 100644 index 00000000..a788e5e5 --- /dev/null +++ b/plugins/wasm/.gitignore @@ -0,0 +1,4 @@ +/target/ +**/*.wasm +!README.md + diff --git a/plugins/wasm/Cargo.toml b/plugins/wasm/Cargo.toml new file mode 100644 index 00000000..c7aa5bec --- /dev/null +++ b/plugins/wasm/Cargo.toml @@ -0,0 +1,22 @@ +[workspace] +members = [ + "hello-world", + "ai-gateway-queue", +] +resolver = "2" + +[workspace.package] +version = "0.0.0" +edition = "2021" +publish = false + +[workspace.dependencies] +proxy-wasm = "0.2" +serde_json = "1" + +[profile.release] +codegen-units = 1 +opt-level = "z" +lto = "fat" +strip = true +panic = "abort" diff --git a/plugins/wasm/README.md b/plugins/wasm/README.md new file mode 100644 index 00000000..4e9db871 --- /dev/null +++ b/plugins/wasm/README.md @@ -0,0 +1,70 @@ +# Spacegate Wasm Plugins + +This directory is the dedicated development workspace for Spacegate Proxy-Wasm plugins. + +## Layout + +```text +plugins/wasm/ + Cargo.toml + hello-world/ + Cargo.toml + src/lib.rs + plugin.yaml + ai-gateway-queue/ + Cargo.toml + src/lib.rs + plugin.yaml +``` + +Use this directory for plugin source code. Keep compiled `.wasm` files in `resource/wasm/` for local demos, or publish them as OCI artifacts/images for Kubernetes usage. + +## Build + +Install the wasm target once: + +```bash +rustup target add wasm32-wasip1 +``` + +Build all plugins: + +```bash +cargo build --release --target wasm32-wasip1 --manifest-path plugins/wasm/Cargo.toml +``` + +The output for `hello-world` is: + +```text +plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_hello_world.wasm +``` + +The AI gateway queue plugin output is: + +```text +plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm +``` + +If you run commands from inside `plugins/wasm/`, the local `.cargo/config.toml` already sets the wasm target: + +```bash +cd plugins/wasm +cargo build --release +``` + +For a local file-based demo, copy or package the built wasm into `resource/wasm/` and reference it with `file://...`. + +For production-style delivery, publish the wasm as an OCI artifact/image and reference it from a Higress-compatible `WasmPlugin`: + +```yaml +spec: + url: oci://registry.example.com/spacegate/plugins/hello-world:v1 +``` + +## Adding A Plugin + +1. Create `plugins/wasm//`. +2. Add it to `plugins/wasm/Cargo.toml` under `workspace.members`. +3. Set the crate type to `cdylib`. +4. Implement the Proxy-Wasm entry point with `proxy_wasm::main!`. +5. Add a `plugin.yaml` example that shows the intended `WasmPlugin` config. diff --git a/plugins/wasm/ai-gateway-queue/Cargo.toml b/plugins/wasm/ai-gateway-queue/Cargo.toml new file mode 100644 index 00000000..711a2531 --- /dev/null +++ b/plugins/wasm/ai-gateway-queue/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "spacegate_plugin_ai_gateway_queue" +version.workspace = true +edition.workspace = true +publish.workspace = true +description = "AI gateway rate-limit and queue Proxy-Wasm plugin for SpaceGate." + +[lib] +crate-type = ["cdylib", "rlib"] + +[dependencies] +proxy-wasm.workspace = true +serde_json.workspace = true diff --git a/plugins/wasm/ai-gateway-queue/README.md b/plugins/wasm/ai-gateway-queue/README.md new file mode 100644 index 00000000..9ae621b8 --- /dev/null +++ b/plugins/wasm/ai-gateway-queue/README.md @@ -0,0 +1,351 @@ +# ai-gateway-queue + +`ai-gateway-queue` 是一个运行在 SpaceGate Wasm 里的 **AI 请求队列网关**插件:在入口处对 AI 请求按租户做令牌桶限流,再根据 `X-RateLimit-Policy` 分流。**三种策略均先调用 `/v1/ratelimit/check`**;配额内直通上游,超额时分别返回 429 / 202 入队 / 入队并阻塞等待。 + +支持三种队列模式(通过 `X-RateLimit-Policy` 请求头选择,名字保留兼容历史): + +- `abandon`:配额内直通上游;超额返回 429(不入队) +- `queue`:配额内直通上游;超额入队后立即返回 `202`,结果通过回调或轮询拿到 +- `wait`:配额内直通上游;超额入队后同步等待结果(类长轮询),超时返回 `504` + +插件本身不直接访问 Redis,而是通过 `dispatch_http_call` 调用外部队列后端(`ai-gateway-service`),再由该后端处理 Redis Streams、worker 消费、回调重试、结果回收等队列基础设施。 + +## 架构 + +```text +Client + -> SpaceGate / ai-gateway-queue wasm plugin + -> ai-gateway-service + -> Redis / Worker / Upstream AI Service +``` + +## 依赖 + +- SpaceGate 已启用 Wasm 支持 +- Rust 工具链 +- `wasm32-wasip1` 目标 +- Redis +- `ai-gateway-service` + +安装 wasm 目标: + +```bash +rustup target add wasm32-wasip1 +``` + +## 构建 + +在 `spacegate` 目录下执行: + +```bash +cargo build --release --target wasm32-wasip1 --manifest-path plugins/wasm/Cargo.toml -p spacegate_plugin_ai_gateway_queue +``` + +编译产物: + +```text +plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm +``` + +## 制作 OCI 制品 + +生产环境建议将 `.wasm` 以 **OCI Artifact** 形式推送到镜像仓库(Harbor、GHCR、ACR 等),SpaceGate 通过 `oci://` / `oci+http://` URL 拉取,而不是挂载本地文件或 HTTP 分发服务。 + +> OCI 制品是单层 Wasm 文件,用 **oras** 推送,不是 `docker build` 容器镜像。Docker Hub 通常不支持 Wasm OCI Artifact,请用 Harbor / GHCR / ACR 等。 + +### 前置条件 + +```bash +# 安装 oras(OCI 推送/拉取工具) +brew install oras + +# 确保已编译 wasm(见上一节「构建」) +rustup target add wasm32-wasip1 +``` + +### 方式一:推送到本地 Harbor(推荐联调) + +本地 Harbor 示例:`http://localhost:9081`,默认账号 `admin` / `Harbor12345`。 + +**手动推送**: + +```bash +WASM_DIR=plugins/wasm/target/wasm32-wasip1/release +WASM=spacegate_plugin_ai_gateway_queue.wasm + +# 1. 创建 Harbor 项目(已存在可跳过) +curl -u 'admin:Harbor12345' -X POST 'http://localhost:9081/api/v2.0/projects' \ + -H 'Content-Type: application/json' \ + -d '{"project_name":"ai-gateway","public":false}' + +# 2. 登录 Harbor(HTTP 需加 --plain-http) +echo 'Harbor12345' | oras login localhost:9081 -u admin --password-stdin --plain-http + +# 3. 计算 digest(生产建议写入插件配置) +shasum -a 256 "$WASM_DIR/$WASM" + +# 4. 推送 OCI Artifact +cd "$WASM_DIR" +oras push localhost:9081/ai-gateway/ai-gateway-queue:v1.0.0 --plain-http \ + --artifact-type application/vnd.module.wasm.content.layer.v1+wasm \ + "${WASM}:application/wasm" +``` + +推送成功后可在 Harbor UI **项目 → ai-gateway → ai-gateway-queue** 查看制品。 + +### 方式二:推送到任意 OCI 仓库(GHCR / ACR / 私有 Harbor) + +在 `spacegate` 仓库根目录: + +```bash +# 登录目标仓库 +oras login ghcr.io -u YOUR_USER + +# 使用仓库自带脚本 +REGISTRY=ghcr.io/your-org TAG=v1.0.0 ./deploy/push-wasm-oci.sh + +# 或从 ai-gateway-dev 工作区根目录 +REGISTRY=ghcr.io/your-org ./scripts/deploy.sh oci push +``` + +### 插件配置中引用 OCI + +SpaceGate 支持的 URL 形式: + +| 场景 | 示例 | +|------|------| +| HTTPS 仓库 | `oci://ghcr.io/your-org/ai-gateway-queue:v1.0.0` | +| 本地 Harbor(HTTP) | `oci+http://localhost:9081/ai-gateway/ai-gateway-queue:v1.0.0` | +| K8s 拉取宿主机 Harbor(Docker Desktop) | `oci+http://host.docker.internal:9081/ai-gateway/ai-gateway-queue:v1.0.0` | + +完整配置示例(含校验与私有仓库凭证)可参考工作区 `open-source/harbor/plugins/wasm.ai-gateway-queue.oci.json`(与 `ai-gateway-dev` 同级目录下的 Harbor 联调配置): + +```json +{ + "url": "oci+http://host.docker.internal:9081/ai-gateway/ai-gateway-queue:v1.0.0", + "sha256": "sha256:<编译产物 shasum -a 256 输出>", + "oci_auth": { + "registry": "host.docker.internal:9081", + "username": "admin", + "password": "Harbor12345" + }, + "fail_strategy": "fail_close", + "plugin_name": "ai-gateway-queue", + "plugin_config": { "...": "..." }, + "clusters": { + "ai-gateway-service": "http://ai-gateway-service:18080" + } +} +``` + +K8s **SgFilter** 中在 `spec.filters[].config` 写入上述字段即可;`clusters` 仍指向 `ai-gateway-service`(与 Wasm 存储方式无关)。 + +### 版本更新 + +1. 修改代码后重新 `cargo build --release --target wasm32-wasip1 ...` +2. 用新 tag 执行 `oras push`(如 `v1.0.1`) +3. 更新 SpaceGate 配置中的 `url` 与 `sha256`(或 `module_cache_key`) + +### 注意事项 + +- **`sha256`**:建议生产开启,防止同 tag 被覆盖后加载错误版本 +- **私有仓库**:配置 `oci_auth`,或 K8s WasmPlugin 使用 `imagePullSecret` +- **K8s 网络**:Pod 内勿用 `localhost:9081` 指宿主机 Harbor,Docker Desktop 用 `host.docker.internal:9081` +- 更多细节见 [`deploy/README.md`](../../../deploy/README.md) 第 7 节 + +## 启动外部服务 + +`ai-gateway-queue` 依赖外部服务来完成限流、入队、等待和回调。 + +```bash +cargo run -p ai-gateway-service -- \ + --redis-url redis://127.0.0.1/ \ + --upstream-base-url http://127.0.0.1:9000 +``` + +常用环境变量: + +```bash +REDIS_URL=redis://127.0.0.1/ +AI_UPSTREAM_BASE_URL=http://127.0.0.1:9000 +AI_RATE_LIMIT_RPS=100 +AI_RATE_LIMIT_BURST=200 +AI_WAIT_TIMEOUT_SECS=60 +AI_WORKER_CONCURRENCY=4 +AI_MAX_BODY_BYTES=33554432 +AI_INLINE_THRESHOLD=131072 +AI_QUEUE_MAX_LEN=100000 +AI_RECLAIM_INTERVAL_SECS=30 +AI_RECLAIM_MIN_IDLE_SECS=30 +``` + +如果不设置 `AI_UPSTREAM_BASE_URL`,队列任务仍会写入 Redis,但不会由本地 worker 消费。 + +本地调试如果使用 HTTP 回调地址,可以临时加上: + +```bash +AI_REQUIRE_HTTPS_CALLBACK=false +``` + +## SpaceGate 配置 + +可参考: + +- 文件模板:[`resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json`](../../resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json) +- **管理界面操作步骤**:[`docs/ai-gateway-queue-admin-ui-guide.md`](../../docs/ai-gateway-queue-admin-ui-guide.md) + +关键配置项: + +```json +{ + "url": "plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm", + "fail_strategy": "fail_close", + "plugin_name": "ai-gateway-queue", + "vm_pool_size": 4, + "wait_vm_pool_size": 4, + "limits": { + "max_memory_pages": 64, + "fuel_per_call": 20000000, + "epoch_timeout_millis": 50, + "max_body_bytes": 33554432, + "max_pending_calls": 1 + }, + "plugin_config": { + "service": { + "cluster": "ai-gateway-service", + "authority": "ai-gateway-service", + "timeout_ms": 65000 + }, + "paths": { + "rate_limit": "/v1/ratelimit/check", + "enqueue": "/v1/queue/enqueue", + "wait": "/v1/queue/enqueue-and-wait" + }, + "headers": { + "policy": "x-ratelimit-policy", + "tenant": "x-tenant-id", + "model": "x-model", + "priority": "x-queue-priority" + }, + "policies": { + "require": true, + "default": null + }, + "priority": { + "enabled": true, + "default": "normal", + "high_models": ["gpt-4o"], + "low_tenants": ["free"] + } + }, + "clusters": { + "ai-gateway-service": "http://127.0.0.1:18080" + } +} +``` + +### `plugin_config` 说明 + +- `service_cluster`:外部服务所在 cluster 名称 +- `service_authority`:转发时使用的 `:authority` +- `rate_limit_path`:限流检查接口 +- `enqueue_path`:入队接口 +- `wait_path`:入队并等待接口 +- `service_timeout_ms`:调用外部服务超时 +- `require_policy`:是否强制要求请求头携带策略 +- `headers.*`:自定义客户端侧策略、租户、模型、优先级 header;插件会转成外部服务统一使用的 `x-ratelimit-policy`、`x-tenant-id`、`x-model`、`x-queue-priority` +- `policies.default`:未携带策略 header 时使用的默认策略;为空且 `require=true` 时会返回 `400` +- `priority.*`:插件侧优先级推导规则,支持按模型或租户自动设置 `high` / `low` + +插件配置优先支持上面的结构化 JSON;旧的扁平字段仍兼容,例如 `service_cluster`、`rate_limit_path`、`tenant_header`、`default_policy`、`high_priority_models`。 + +## 请求头 + +插件依赖下列请求头: + +- `X-RateLimit-Policy`:必填,取值为 `abandon`、`queue`、`wait` +- `X-Tenant-Id`:必填 +- `X-Callback-URL`:`queue` 场景下必填,默认要求 HTTPS +- `X-Request-Timeout`:`wait` 场景下可选,单位为秒 +- `X-Model`:可选,透传给外部服务 +- `X-Queue-Priority`:可选,启用优先级队列后可传 `high` 或 `low` + +Header 名称大小写不敏感;`X-RateLimit-Policy` 的值请使用小写。 + +## 三种模式 + +### 1. `abandon` + +先调用限流接口,允许则继续转发到后端,拒绝则返回 `429`。 + +示例: + +```bash +curl -i http://localhost:9080/your/api \ + -H 'X-RateLimit-Policy: abandon' \ + -H 'X-Tenant-Id: demo' \ + -H 'X-Model: gpt-4o-mini' \ + -d '{"prompt":"hello"}' +``` + +### 2. `queue` + +先调用限流接口。配额内继续转发到上游(与 `abandon` 相同);超额时请求体进入队列,插件返回 `202 Accepted`,响应里会带 `X-Job-Id`。 + +示例: + +```bash +curl -i http://localhost:9080/your/api \ + -H 'X-RateLimit-Policy: queue' \ + -H 'X-Tenant-Id: demo' \ + -H 'X-Callback-URL: https://example.com/callback' \ + -d '{"prompt":"hello"}' +``` + +### 3. `wait` + +先调用限流接口。配额内继续转发到上游;超额时请求体进入队列后等待结果返回。成功时直接返回上游响应,超时则返回 `504`。 + +示例: + +```bash +curl -i http://localhost:9080/your/api \ + -H 'X-RateLimit-Policy: wait' \ + -H 'X-Tenant-Id: demo' \ + -H 'X-Request-Timeout: 60' \ + -d '{"prompt":"hello"}' +``` + +## 返回行为 + +- `400`:缺少必要请求头或策略非法(无 `X-RateLimit-Policy` 且无 `default_policy` 时也会 400) +- `429`:`abandon` 超额限流 +- `202`:`queue` 超额入队已接收 +- `200`/`4xx`/`5xx`:配额内三种策略均由上游返回;`wait` 超额完成后也由外部服务返回上游响应 +- `502`:外部服务不可达或调用失败 + +`wait` 成功响应会带 `X-Job-Id` 和 `X-Queue-Wait-Ms`;`queue` 响应会带 `X-Job-Id` 和 `Location`。 + +## 生产化能力 + +- Redis Stream 支持 `MAXLEN ~` 裁剪,通过 `AI_QUEUE_MAX_LEN` 控制 +- 租户限流支持按租户、模型、路由、策略多维覆盖,并支持单请求 cost +- 优先级队列支持 header、模型、租户规则推导,并由 worker 按权重消费高/普通/低优先级 Stream +- Worker 崩溃后通过 `XAUTOCLAIM` 重认领 pending job,并通过 Redis 处理租约避免长任务被重复执行 +- 回调失败会进入 `AI_CALLBACK_RETRY_STREAM`,按指数退避重试,超过最大次数后进入 `AI_CALLBACK_DLQ_STREAM` +- 大 body 可通过 `AI_OBJECT_STORE_ENDPOINT` 走 S3-compatible multipart 卸载,Redis Stream 中只保留 `ref`;未配置 S3 且 body 超过 `AI_INLINE_THRESHOLD` 时返回 `413` +- 租户限流令牌桶按 `X-Tenant-Id` 隔离;可通过 Admin API `PUT /v1/admin/tenant-rate-limits` 或 Redis 配置键覆盖每租户 RPS/Burst(支持 model/path/policy 维度 lookup,桶 key 仍为 tenant-only) +- `/metrics` 暴露 Prometheus 文本指标,包含队列深度、PEL、DLQ、入队延迟、body 大小、wait 超时、回调重试和 worker 处理耗时 + +## 调试建议 + +- 先确认 `ai-gateway-service` 已启动并能连上 Redis +- 再确认 SpaceGate 的 `clusters.ai-gateway-service` 指向正确地址 +- `wait` 模式建议单独使用 `wait_vm_pool_size`,避免拖垮普通请求 +- 如果请求一直返回 `400 missing_or_invalid_rate_limit_policy`,先检查 `X-RateLimit-Policy` + +## 备注 + +- 这个插件当前是面向 OpenAI 风格 AI 请求的队列和限流入口 +- Redis 相关逻辑被放在 wasm 外部服务中,便于隔离和演进 +- 具体协议和接口字段,以 `ai-gateway-service` 的实现为准 diff --git a/plugins/wasm/ai-gateway-queue/plugin.yaml b/plugins/wasm/ai-gateway-queue/plugin.yaml new file mode 100644 index 00000000..49a04291 --- /dev/null +++ b/plugins/wasm/ai-gateway-queue/plugin.yaml @@ -0,0 +1,25 @@ +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-gateway-queue + namespace: spacegate +spec: + url: oci://registry.example.com/spacegate/plugins/ai-gateway-queue:v1 + pluginName: ai-gateway-queue + phase: AUTHN + priority: 90 + failStrategy: FAIL_CLOSE + defaultConfig: + service_cluster: ai-gateway-service + service_authority: ai-gateway-service + rate_limit_path: /v1/ratelimit/check + enqueue_path: /v1/queue/enqueue + wait_path: /v1/queue/enqueue-and-wait + service_timeout_ms: 65000 + require_policy: true + policy_header: x-ratelimit-policy + tenant_header: x-tenant-id + model_header: x-model + priority_header: x-queue-priority + priority_enabled: true + default_priority: normal diff --git a/plugins/wasm/ai-gateway-queue/src/lib.rs b/plugins/wasm/ai-gateway-queue/src/lib.rs new file mode 100644 index 00000000..eeb68519 --- /dev/null +++ b/plugins/wasm/ai-gateway-queue/src/lib.rs @@ -0,0 +1,552 @@ +use std::time::Duration; + +use proxy_wasm::hostcalls; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; +use serde_json::Value; + +mod policy; +use policy::{contains_allowed_true, extract_json_number, normalize_policy, normalize_priority}; + +proxy_wasm::main! {{ + proxy_wasm::set_log_level(LogLevel::Info); + proxy_wasm::set_root_context(|_| -> Box { Box::new(AiGatewayRoot::default()) }); +}} + +#[derive(Clone)] +struct AiGatewayConfig { + service_cluster: String, + service_authority: String, + rate_limit_path: String, + enqueue_path: String, + wait_path: String, + service_timeout_ms: u64, + require_policy: bool, + policy_header: String, + tenant_header: String, + model_header: String, + priority_header: String, + default_policy: Option, + priority_enabled: bool, + default_priority: String, + high_priority_models: Vec, + low_priority_models: Vec, + high_priority_tenants: Vec, + low_priority_tenants: Vec, +} + +impl Default for AiGatewayConfig { + fn default() -> Self { + Self { + service_cluster: "ai-gateway-service".to_string(), + service_authority: "ai-gateway-service".to_string(), + rate_limit_path: "/v1/ratelimit/check".to_string(), + enqueue_path: "/v1/queue/enqueue".to_string(), + wait_path: "/v1/queue/enqueue-and-wait".to_string(), + service_timeout_ms: 65_000, + require_policy: true, + policy_header: "x-ratelimit-policy".to_string(), + tenant_header: "x-tenant-id".to_string(), + model_header: "x-model".to_string(), + priority_header: "x-queue-priority".to_string(), + default_policy: None, + priority_enabled: true, + default_priority: "normal".to_string(), + high_priority_models: Vec::new(), + low_priority_models: Vec::new(), + high_priority_tenants: Vec::new(), + low_priority_tenants: Vec::new(), + } + } +} + +impl AiGatewayConfig { + fn parse(raw: &[u8]) -> Self { + let mut cfg = Self::default(); + if raw.is_empty() { + return cfg; + } + + if let Ok(value) = serde_json::from_slice::(raw) { + cfg.apply_json(&value); + return cfg.normalized(); + } + + let text = String::from_utf8_lossy(raw); + cfg.apply_legacy_lines(&text); + cfg.normalized() + } + + fn apply_json(&mut self, value: &Value) { + set_string(value, &["service_cluster"], &mut self.service_cluster); + set_string(value, &["service", "cluster"], &mut self.service_cluster); + set_string(value, &["service_authority"], &mut self.service_authority); + set_string(value, &["service", "authority"], &mut self.service_authority); + set_string(value, &["rate_limit_path"], &mut self.rate_limit_path); + set_string(value, &["paths", "rate_limit"], &mut self.rate_limit_path); + set_string(value, &["enqueue_path"], &mut self.enqueue_path); + set_string(value, &["paths", "enqueue"], &mut self.enqueue_path); + set_string(value, &["wait_path"], &mut self.wait_path); + set_string(value, &["paths", "wait"], &mut self.wait_path); + set_u64(value, &["service_timeout_ms"], &mut self.service_timeout_ms); + set_u64(value, &["service", "timeout_ms"], &mut self.service_timeout_ms); + set_bool(value, &["require_policy"], &mut self.require_policy); + set_bool(value, &["policies", "require"], &mut self.require_policy); + self.default_policy = string_at(value, &["default_policy"]).or_else(|| string_at(value, &["policies", "default"])); + set_string(value, &["policy_header"], &mut self.policy_header); + set_string(value, &["headers", "policy"], &mut self.policy_header); + set_string(value, &["tenant_header"], &mut self.tenant_header); + set_string(value, &["headers", "tenant"], &mut self.tenant_header); + set_string(value, &["model_header"], &mut self.model_header); + set_string(value, &["headers", "model"], &mut self.model_header); + set_string(value, &["priority_header"], &mut self.priority_header); + set_string(value, &["headers", "priority"], &mut self.priority_header); + set_bool(value, &["priority_enabled"], &mut self.priority_enabled); + set_bool(value, &["priority", "enabled"], &mut self.priority_enabled); + set_string(value, &["default_priority"], &mut self.default_priority); + set_string(value, &["priority", "default"], &mut self.default_priority); + self.high_priority_models = string_vec_at(value, &["priority", "high_models"]).or_else(|| string_vec_at(value, &["high_priority_models"])).unwrap_or_default(); + self.low_priority_models = string_vec_at(value, &["priority", "low_models"]).or_else(|| string_vec_at(value, &["low_priority_models"])).unwrap_or_default(); + self.high_priority_tenants = string_vec_at(value, &["priority", "high_tenants"]).or_else(|| string_vec_at(value, &["high_priority_tenants"])).unwrap_or_default(); + self.low_priority_tenants = string_vec_at(value, &["priority", "low_tenants"]).or_else(|| string_vec_at(value, &["low_priority_tenants"])).unwrap_or_default(); + } + + fn apply_legacy_lines(&mut self, text: &str) { + for line in text.lines() { + let Some((key, value)) = line.split_once(':') else { continue }; + let key = key.trim().trim_matches(['"', '\'', '{', ',', ' '].as_ref()); + let value = value.trim().trim_matches(['"', '\'', ',', ' '].as_ref()); + match key { + "service_cluster" => self.service_cluster = value.to_string(), + "service_authority" => self.service_authority = value.to_string(), + "rate_limit_path" => self.rate_limit_path = value.to_string(), + "enqueue_path" => self.enqueue_path = value.to_string(), + "wait_path" => self.wait_path = value.to_string(), + "service_timeout_ms" => self.service_timeout_ms = value.parse().unwrap_or(self.service_timeout_ms), + "require_policy" => self.require_policy = value.parse().unwrap_or(self.require_policy), + "policy_header" => self.policy_header = value.to_string(), + "tenant_header" => self.tenant_header = value.to_string(), + "model_header" => self.model_header = value.to_string(), + "priority_header" => self.priority_header = value.to_string(), + "default_policy" => self.default_policy = Some(value.to_string()), + "priority_enabled" => self.priority_enabled = value.parse().unwrap_or(self.priority_enabled), + "default_priority" => self.default_priority = value.to_string(), + "high_priority_models" => self.high_priority_models = parse_csv(value), + "low_priority_models" => self.low_priority_models = parse_csv(value), + "high_priority_tenants" => self.high_priority_tenants = parse_csv(value), + "low_priority_tenants" => self.low_priority_tenants = parse_csv(value), + _ => {} + } + } + } + + fn normalized(mut self) -> Self { + self.policy_header = normalize_header_name(&self.policy_header, "x-ratelimit-policy"); + self.tenant_header = normalize_header_name(&self.tenant_header, "x-tenant-id"); + self.model_header = normalize_header_name(&self.model_header, "x-model"); + self.priority_header = normalize_header_name(&self.priority_header, "x-queue-priority"); + self.default_priority = normalize_priority(&self.default_priority).unwrap_or_else(|| "normal".to_string()); + self.default_policy = self.default_policy.and_then(|value| normalize_policy(&value)); + self + } +} + +#[derive(Default)] +struct AiGatewayRoot { + cfg: AiGatewayConfig, +} + +impl Context for AiGatewayRoot {} + +impl RootContext for AiGatewayRoot { + fn on_vm_start(&mut self, _: usize) -> bool { + let _ = hostcalls::log(LogLevel::Info, "ai-gateway-queue wasm plugin started"); + true + } + + fn on_configure(&mut self, _: usize) -> bool { + let raw = self.get_plugin_configuration().unwrap_or_default(); + self.cfg = AiGatewayConfig::parse(&raw); + true + } + + fn create_http_context(&self, _: u32) -> Option> { + Some(Box::new(AiGatewayHttp { + cfg: self.cfg.clone(), + pending: None, + rate_limited_enqueue: None, + body_pending: false, + })) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum Policy { + Abandon, + Queue, + Wait, +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum Pending { + RateLimit { policy: Policy }, + Queue, + Wait, +} + +struct AiGatewayHttp { + cfg: AiGatewayConfig, + pending: Option<(u32, Pending)>, + /// 限流拒绝后等待 request body 再入队(queue / wait)。 + rate_limited_enqueue: Option, + /// 请求体尚未读完,需在 body 回调中继续处理。 + body_pending: bool, +} + +impl Context for AiGatewayHttp { + fn on_http_call_response(&mut self, token_id: u32, _num_headers: usize, body_size: usize, _num_trailers: usize) { + let Some((pending_token, pending)) = self.pending else { + return; + }; + if token_id != pending_token { + return; + } + self.pending = None; + + match pending { + Pending::RateLimit { policy } => self.handle_rate_limit_response(body_size, policy), + Pending::Queue | Pending::Wait => self.forward_service_response(body_size), + } + } +} + +impl HttpContext for AiGatewayHttp { + fn on_http_request_headers(&mut self, _: usize, end_of_stream: bool) -> Action { + let Some(policy) = self.request_policy() else { + // 设计文档要求 Policy 必填;仅允许 default_policy 兜底,禁止无策略 bypass。 + self.send_json(400, r#"{"error":"missing_or_invalid_rate_limit_policy"}"#); + return Action::Pause; + }; + + if self.tenant_id().is_none() { + self.send_json(400, r#"{"error":"missing_x_tenant_id"}"#); + return Action::Pause; + } + + // 三种策略统一先走令牌桶(DOC-01/02 定稿)。 + self.body_pending = !end_of_stream && matches!(policy, Policy::Queue | Policy::Wait); + if self.dispatch_service_call(Pending::RateLimit { policy }, &self.cfg.rate_limit_path.clone(), None) { + Action::Pause + } else { + self.send_json(502, r#"{"error":"rate_limit_service_unavailable"}"#); + Action::Pause + } + } + + fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + if !self.body_pending && self.rate_limited_enqueue.is_none() { + return Action::Continue; + } + if !end_of_stream { + return Action::Pause; + } + self.body_pending = false; + + let Some(policy) = self.rate_limited_enqueue.take() else { + return Action::Continue; + }; + let body = self.get_http_request_body(0, body_size).unwrap_or_default(); + self.dispatch_enqueue(policy, &body); + Action::Pause + } +} + +impl AiGatewayHttp { + fn request_policy(&self) -> Option { + let value = self.get_http_request_header(&self.cfg.policy_header).or_else(|| self.cfg.default_policy.clone())?; + match normalize_policy(&value).as_deref() { + Some("abandon") => Some(Policy::Abandon), + Some("queue") => Some(Policy::Queue), + Some("wait") => Some(Policy::Wait), + _ => None, + } + } + + fn dispatch_service_call(&mut self, pending: Pending, path: &str, body: Option<&[u8]>) -> bool { + let headers = self.service_headers(path); + let refs = headers.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect::>(); + match self.dispatch_http_call(&self.cfg.service_cluster, refs, body, vec![], Duration::from_millis(self.cfg.service_timeout_ms)) { + Ok(token) => { + self.pending = Some((token, pending)); + true + } + Err(status) => { + let _ = hostcalls::log(LogLevel::Warn, &format!("dispatch service call failed: {status:?}")); + false + } + } + } + + fn dispatch_enqueue(&mut self, policy: Policy, body: &[u8]) { + let pending = if policy == Policy::Queue { Pending::Queue } else { Pending::Wait }; + let path = if policy == Policy::Queue { + self.cfg.enqueue_path.clone() + } else { + self.cfg.wait_path.clone() + }; + if !self.dispatch_service_call(pending, &path, Some(body)) { + self.send_json(502, r#"{"error":"queue_service_unavailable"}"#); + } + } + + fn service_headers(&self, path: &str) -> Vec<(String, String)> { + let policy = self.request_policy().map(policy_name).unwrap_or("abandon").to_string(); + let tenant_id = self.tenant_id().unwrap_or_default(); + let model = self.model().unwrap_or_else(|| "default".to_string()); + let priority = self.queue_priority(&tenant_id, &model); + let mut out = vec![ + (":method".to_string(), "POST".to_string()), + (":path".to_string(), path.to_string()), + (":authority".to_string(), self.cfg.service_authority.clone()), + ( + "x-original-method".to_string(), + self.get_http_request_header(":method").unwrap_or_else(|| "POST".to_string()), + ), + ("x-original-path".to_string(), self.get_http_request_header(":path").unwrap_or_else(|| "/".to_string())), + ("x-ratelimit-policy".to_string(), policy), + ("x-tenant-id".to_string(), tenant_id), + ("x-model".to_string(), model), + ]; + if let Some(priority) = priority { + out.push(("x-queue-priority".to_string(), priority)); + } + + for (name, value) in self.get_http_request_headers() { + if should_forward_to_service(&name) { + out.push((name, value)); + } + } + out + } + + fn tenant_id(&self) -> Option { + self.get_http_request_header(&self.cfg.tenant_header).filter(|value| !value.trim().is_empty()) + } + + fn model(&self) -> Option { + self.get_http_request_header(&self.cfg.model_header).filter(|value| !value.trim().is_empty()) + } + + fn queue_priority(&self, tenant_id: &str, model: &str) -> Option { + if !self.cfg.priority_enabled { + return None; + } + if let Some(priority) = self.get_http_request_header(&self.cfg.priority_header).and_then(|value| normalize_priority(&value)) { + return Some(priority); + } + if contains_value(&self.cfg.high_priority_tenants, tenant_id) || contains_value(&self.cfg.high_priority_models, model) { + return Some("high".to_string()); + } + if contains_value(&self.cfg.low_priority_tenants, tenant_id) || contains_value(&self.cfg.low_priority_models, model) { + return Some("low".to_string()); + } + normalize_priority(&self.cfg.default_priority) + } + + fn handle_rate_limit_response(&mut self, body_size: usize, policy: Policy) { + let status = self.service_status(); + let body = self.get_http_call_response_body(0, body_size).unwrap_or_default(); + let text = String::from_utf8_lossy(&body); + if status == 200 && contains_allowed_true(&text) { + // 配额内:三种策略均直通上游。 + self.resume_http_request(); + return; + } + if status == 200 { + match policy { + Policy::Abandon => self.send_rate_limited_response(&text), + Policy::Queue | Policy::Wait => { + if self.body_pending { + self.rate_limited_enqueue = Some(policy); + } else { + self.dispatch_enqueue(policy, &[]); + } + } + } + return; + } + self.send_json(502, r#"{"error":"rate_limit_service_error"}"#); + } + + fn send_rate_limited_response(&self, text: &str) { + let retry_after_ms = extract_json_number(text, "retry_after_ms").unwrap_or(1000); + let retry_after_secs = ((retry_after_ms + 999) / 1000).max(1).to_string(); + let response_body = format!(r#"{{"error":"rate_limited","retry_after_ms":{retry_after_ms}}}"#); + let headers = [ + ("content-type".to_string(), "application/json".to_string()), + ("retry-after".to_string(), retry_after_secs), + ("x-ratelimit-retry-after-ms".to_string(), retry_after_ms.to_string()), + ]; + let headers = headers.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect::>(); + self.send_http_response(429, headers, Some(response_body.as_bytes())); + } + + fn forward_service_response(&mut self, body_size: usize) { + let status = self.service_status(); + let body = self.get_http_call_response_body(0, body_size).unwrap_or_default(); + let header_storage = self.response_headers_for_client(); + let headers = header_storage.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect::>(); + self.send_http_response(status as u32, headers, Some(&body)); + } + + fn response_headers_for_client(&self) -> Vec<(String, String)> { + let mut out = Vec::new(); + for name in ["content-type", "x-job-id", "x-queue-wait-ms", "x-gateway-job-id", "retry-after", "location"] { + if let Some(value) = self.get_http_call_response_header(name) { + out.push((name.to_string(), value)); + } + } + out + } + + fn service_status(&self) -> u16 { + self.get_http_call_response_header(":status").and_then(|v| v.parse().ok()).unwrap_or(502) + } + + fn send_json(&self, status: u32, body: &str) { + self.send_http_response(status, vec![("content-type", "application/json")], Some(body.as_bytes())); + } +} + +fn should_forward_to_service(name: &str) -> bool { + let lower = name.to_ascii_lowercase(); + !lower.starts_with(':') + && !matches!( + lower.as_str(), + "host" + | "connection" + | "content-length" + | "transfer-encoding" + | "x-original-method" + | "x-original-path" + | "x-ratelimit-policy" + | "x-tenant-id" + | "x-model" + | "x-queue-priority" + ) +} + +fn policy_name(policy: Policy) -> &'static str { + match policy { + Policy::Abandon => "abandon", + Policy::Queue => "queue", + Policy::Wait => "wait", + } +} + +fn contains_value(values: &[String], needle: &str) -> bool { + values.iter().any(|value| value.eq_ignore_ascii_case(needle)) +} + +fn value_at<'a>(value: &'a Value, path: &[&str]) -> Option<&'a Value> { + let mut current = value; + for key in path { + current = current.get(*key)?; + } + Some(current) +} + +fn string_at(value: &Value, path: &[&str]) -> Option { + value_at(value, path).and_then(|value| value.as_str().map(ToOwned::to_owned)) +} + +fn set_string(value: &Value, path: &[&str], target: &mut String) { + if let Some(value) = string_at(value, path).filter(|value| !value.trim().is_empty()) { + *target = value; + } +} + +fn set_u64(value: &Value, path: &[&str], target: &mut u64) { + if let Some(value) = value_at(value, path).and_then(|value| value.as_u64()) { + *target = value; + } +} + +fn set_bool(value: &Value, path: &[&str], target: &mut bool) { + if let Some(value) = value_at(value, path).and_then(|value| value.as_bool()) { + *target = value; + } +} + +fn string_vec_at(value: &Value, path: &[&str]) -> Option> { + let value = value_at(value, path)?; + if let Some(raw) = value.as_str() { + return Some(parse_csv(raw)); + } + let values = value.as_array()?; + Some(values.iter().filter_map(|value| value.as_str().map(ToOwned::to_owned)).collect()) +} + +fn parse_csv(value: &str) -> Vec { + value.split(',').map(str::trim).filter(|value| !value.is_empty()).map(ToOwned::to_owned).collect() +} + +fn normalize_header_name(value: &str, fallback: &str) -> String { + let value = value.trim().to_ascii_lowercase(); + if value.is_empty() || value.starts_with(':') { + fallback.to_string() + } else { + value + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_nested_json_config() { + let cfg = AiGatewayConfig::parse( + br#"{ + "service": {"cluster": "svc", "authority": "svc.local", "timeout_ms": 1200}, + "paths": {"rate_limit": "/rl", "enqueue": "/q", "wait": "/w"}, + "headers": {"policy": "X-Policy", "tenant": "X-Org", "model": "X-LLM", "priority": "X-Priority"}, + "policies": {"require": false, "default": "queue"}, + "priority": {"default": "low", "high_models": ["gpt-4"], "low_tenants": "free,basic"} + }"#, + ); + + assert_eq!(cfg.service_cluster, "svc"); + assert_eq!(cfg.service_timeout_ms, 1200); + assert_eq!(cfg.rate_limit_path, "/rl"); + assert_eq!(cfg.policy_header, "x-policy"); + assert!(!cfg.require_policy); + assert_eq!(cfg.default_policy.as_deref(), Some("queue")); + assert_eq!(cfg.default_priority, "low"); + assert_eq!(cfg.high_priority_models, vec!["gpt-4"]); + assert_eq!(cfg.low_priority_tenants, vec!["free", "basic"]); + } + + #[test] + fn parses_legacy_config_lines() { + let cfg = AiGatewayConfig::parse( + br#" +service_cluster: ai-gateway +service_timeout_ms: 3000 +tenant_header: X-Org +default_policy: wait +high_priority_models: qwen-max, deepseek-chat +"#, + ); + + assert_eq!(cfg.service_cluster, "ai-gateway"); + assert_eq!(cfg.service_timeout_ms, 3000); + assert_eq!(cfg.tenant_header, "x-org"); + assert_eq!(cfg.default_policy.as_deref(), Some("wait")); + assert_eq!(cfg.high_priority_models, vec!["qwen-max", "deepseek-chat"]); + } +} diff --git a/plugins/wasm/ai-gateway-queue/src/policy.rs b/plugins/wasm/ai-gateway-queue/src/policy.rs new file mode 100644 index 00000000..d57ba364 --- /dev/null +++ b/plugins/wasm/ai-gateway-queue/src/policy.rs @@ -0,0 +1,68 @@ +//! 可在 host 侧单测的策略/JSON 纯逻辑(TC-GW / TC-HDR 相关)。 + +/// 规范化 X-RateLimit-Policy 取值。 +pub fn normalize_policy(value: &str) -> Option { + match value.trim().to_ascii_lowercase().as_str() { + "abandon" => Some("abandon".to_string()), + "queue" => Some("queue".to_string()), + "wait" => Some("wait".to_string()), + _ => None, + } +} + +/// 规范化队列优先级 header。 +pub fn normalize_priority(value: &str) -> Option { + match value.trim().to_ascii_lowercase().as_str() { + "high" => Some("high".to_string()), + "normal" | "default" | "medium" => Some("normal".to_string()), + "low" => Some("low".to_string()), + _ => None, + } +} + +/// 解析限流 check 响应中的 allowed=true。 +pub fn contains_allowed_true(text: &str) -> bool { + text.contains(r#""allowed":true"#) || text.contains(r#""allowed": true"#) +} + +/// 从 JSON 文本提取数字字段(如 retry_after_ms)。 +pub fn extract_json_number(text: &str, key: &str) -> Option { + let quoted = format!("\"{key}\""); + for needle in [format!("{quoted}:"), format!("{quoted}:\"")] { + let Some(pos) = text.find(&needle) else { continue }; + let digits = text[pos + needle.len()..] + .chars() + .skip_while(|c| c.is_whitespace() || *c == '"') + .take_while(|c| c.is_ascii_digit()) + .collect::(); + if let Ok(v) = digits.parse() { + return Some(v); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tc_hdr_02_rejects_invalid_policy() { + assert!(normalize_policy("invalid").is_none()); + assert_eq!(normalize_policy("QUEUE").as_deref(), Some("queue")); + } + + #[test] + fn tc_gw_rate_limit_response_parsing() { + assert!(contains_allowed_true(r#"{"allowed":true,"retry_after_ms":0}"#)); + assert!(!contains_allowed_true(r#"{"allowed":false}"#)); + assert_eq!(extract_json_number(r#"{"retry_after_ms":3000}"#, "retry_after_ms"), Some(3000)); + } + + #[test] + fn normalize_priority_values() { + assert_eq!(normalize_priority("HIGH").as_deref(), Some("high")); + assert_eq!(normalize_priority("medium").as_deref(), Some("normal")); + assert!(normalize_priority("urgent").is_none()); + } +} diff --git a/plugins/wasm/hello-world/Cargo.toml b/plugins/wasm/hello-world/Cargo.toml new file mode 100644 index 00000000..9489d84b --- /dev/null +++ b/plugins/wasm/hello-world/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "spacegate_plugin_hello_world" +version.workspace = true +edition.workspace = true +publish.workspace = true +description = "Hello World Proxy-Wasm plugin template for Spacegate." + +[lib] +crate-type = ["cdylib"] + +[dependencies] +proxy-wasm.workspace = true + diff --git a/plugins/wasm/hello-world/plugin.yaml b/plugins/wasm/hello-world/plugin.yaml new file mode 100644 index 00000000..68c87c7a --- /dev/null +++ b/plugins/wasm/hello-world/plugin.yaml @@ -0,0 +1,13 @@ +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: hello-world + namespace: spacegate +spec: + url: oci://registry.example.com/spacegate/plugins/hello-world:v1 + pluginName: hello-world + phase: AUTHN + priority: 100 + failStrategy: FAIL_OPEN + defaultConfig: + message: hello world diff --git a/plugins/wasm/hello-world/src/lib.rs b/plugins/wasm/hello-world/src/lib.rs new file mode 100644 index 00000000..af93d62e --- /dev/null +++ b/plugins/wasm/hello-world/src/lib.rs @@ -0,0 +1,50 @@ +use proxy_wasm::hostcalls; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +proxy_wasm::main! {{ + proxy_wasm::set_log_level(LogLevel::Info); + proxy_wasm::set_root_context(|_| -> Box { Box::new(HelloWorldRoot) }); +}} + +struct HelloWorldRoot; + +impl Context for HelloWorldRoot {} + +impl RootContext for HelloWorldRoot { + fn on_vm_start(&mut self, _: usize) -> bool { + let _ = hostcalls::log(LogLevel::Info, "hello world wasm plugin started"); + true + } + + fn on_configure(&mut self, _: usize) -> bool { + let _ = hostcalls::log(LogLevel::Info, "hello world wasm plugin configured"); + true + } + + fn create_http_context(&self, _: u32) -> Option> { + Some(Box::new(HelloWorldHttp)) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } +} + +struct HelloWorldHttp; + +impl Context for HelloWorldHttp {} + +impl HttpContext for HelloWorldHttp { + fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { + let _ = hostcalls::log(LogLevel::Info, "hello world request reached wasm plugin"); + self.add_http_request_header("x-spacegate-wasm-plugin", "hello-world"); + + if self.get_http_request_header(":path").as_deref() == Some("/hello-world") { + self.send_http_response(200, vec![("content-type", "text/plain"), ("x-powered-by", "spacegate-wasm")], Some(b"hello world\n")); + return Action::Pause; + } + + Action::Continue + } +} diff --git a/resource/ai-gateway-demo/config.json b/resource/ai-gateway-demo/config.json new file mode 100644 index 00000000..6195e0f0 --- /dev/null +++ b/resource/ai-gateway-demo/config.json @@ -0,0 +1,5 @@ +{ + "gateways": {}, + "plugins": {}, + "api_port": 19880 +} diff --git a/resource/ai-gateway-demo/gateway/ai-demo/config.json b/resource/ai-gateway-demo/gateway/ai-demo/config.json new file mode 100644 index 00000000..844ec103 --- /dev/null +++ b/resource/ai-gateway-demo/gateway/ai-demo/config.json @@ -0,0 +1,23 @@ +{ + "gateway": { + "name": "ai-demo", + "parameters": {}, + "listeners": [ + { + "name": "http", + "ip": "127.0.0.1", + "port": 9993, + "protocol": { + "type": "http" + } + } + ], + "plugins": [ + { + "code": "wasm", + "kind": "named", + "name": "ai-gateway-queue" + } + ] + } +} diff --git a/resource/ai-gateway-demo/gateway/ai-demo/route/ai.json b/resource/ai-gateway-demo/gateway/ai-demo/route/ai.json new file mode 100644 index 00000000..061381cc --- /dev/null +++ b/resource/ai-gateway-demo/gateway/ai-demo/route/ai.json @@ -0,0 +1,33 @@ +{ + "route_name": "ai", + "rules": [ + { + "matches": [ + { + "path": { + "kind": "Prefix", + "value": "/v1/" + } + } + ], + "plugins": [ + { + "code": "wasm", + "kind": "named", + "name": "ai-gateway-queue" + } + ], + "backends": [ + { + "host": { + "kind": "Host", + "host": "127.0.0.1" + }, + "port": 9000, + "weight": 1 + } + ] + } + ], + "priority": 0 +} diff --git a/resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json b/resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json new file mode 100644 index 00000000..977ff1ec --- /dev/null +++ b/resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.json @@ -0,0 +1,29 @@ +{ + "url": "file:///Users/sh.zhang/Workspace/huayun/jiyan/ai-gateway-dev/spacegate/plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm", + "validate_on_create": false, + "fail_strategy": "fail_close", + "plugin_name": "ai-gateway-queue", + "plugin_root_id": "ai-gateway-queue-root", + "plugin_vm_id": "ai-gateway-queue-vm", + "vm_pool_size": 4, + "wait_vm_pool_size": 4, + "limits": { + "max_memory_pages": 64, + "fuel_per_call": 20000000, + "epoch_timeout_millis": 50, + "max_body_bytes": 33554432, + "max_pending_calls": 1 + }, + "plugin_config": { + "service_cluster": "ai-gateway-service", + "service_authority": "ai-gateway-service", + "rate_limit_path": "/v1/ratelimit/check", + "enqueue_path": "/v1/queue/enqueue", + "wait_path": "/v1/queue/enqueue-and-wait", + "service_timeout_ms": 65000, + "require_policy": true + }, + "clusters": { + "ai-gateway-service": "http://127.0.0.1:18080" + } +} diff --git a/resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.local.json b/resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.local.json new file mode 100644 index 00000000..977ff1ec --- /dev/null +++ b/resource/ai-gateway-demo/plugin/wasm.ai-gateway-queue.local.json @@ -0,0 +1,29 @@ +{ + "url": "file:///Users/sh.zhang/Workspace/huayun/jiyan/ai-gateway-dev/spacegate/plugins/wasm/target/wasm32-wasip1/release/spacegate_plugin_ai_gateway_queue.wasm", + "validate_on_create": false, + "fail_strategy": "fail_close", + "plugin_name": "ai-gateway-queue", + "plugin_root_id": "ai-gateway-queue-root", + "plugin_vm_id": "ai-gateway-queue-vm", + "vm_pool_size": 4, + "wait_vm_pool_size": 4, + "limits": { + "max_memory_pages": 64, + "fuel_per_call": 20000000, + "epoch_timeout_millis": 50, + "max_body_bytes": 33554432, + "max_pending_calls": 1 + }, + "plugin_config": { + "service_cluster": "ai-gateway-service", + "service_authority": "ai-gateway-service", + "rate_limit_path": "/v1/ratelimit/check", + "enqueue_path": "/v1/queue/enqueue", + "wait_path": "/v1/queue/enqueue-and-wait", + "service_timeout_ms": 65000, + "require_policy": true + }, + "clusters": { + "ai-gateway-service": "http://127.0.0.1:18080" + } +} diff --git a/resource/docker/Dockerfile.admin-server b/resource/docker/Dockerfile.admin-server new file mode 100644 index 00000000..4a598229 --- /dev/null +++ b/resource/docker/Dockerfile.admin-server @@ -0,0 +1,14 @@ +# 仅重建 spacegate-admin-server(含插件增量写入修复) +FROM rust:1.88-bookworm AS builder +WORKDIR /app +COPY . . +RUN cargo build --release -p spacegate-admin-server + +FROM debian:bookworm-slim +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates tini \ + && rm -rf /var/lib/apt/lists/* +COPY --from=builder /app/target/release/spacegate-admin-server /usr/local/bin/spacegate-admin-server +ENV CONFIG=file:/etc/spacegate RUST_LOG=info +EXPOSE 19992 +ENTRYPOINT ["/usr/bin/tini", "--", "/usr/local/bin/spacegate-admin-server"] +CMD ["-c", "file:/etc/spacegate", "-p", "19992", "-H", "0.0.0.0"] diff --git a/resource/hai-wasm-demo/config.json b/resource/hai-wasm-demo/config.json new file mode 100644 index 00000000..20d4e030 --- /dev/null +++ b/resource/hai-wasm-demo/config.json @@ -0,0 +1,5 @@ +{ + "gateways": {}, + "plugins": {}, + "api_port": 19876 +} diff --git a/resource/hai-wasm-demo/gateway/hai-demo/config.json b/resource/hai-wasm-demo/gateway/hai-demo/config.json new file mode 100644 index 00000000..c0b15c81 --- /dev/null +++ b/resource/hai-wasm-demo/gateway/hai-demo/config.json @@ -0,0 +1,16 @@ +{ + "gateway": { + "name": "hai-demo", + "parameters": {}, + "listeners": [ + { + "name": "http", + "ip": "127.0.0.1", + "port": 18080, + "protocol": { + "type": "http" + } + } + ] + } +} diff --git a/resource/hai-wasm-demo/gateway/hai-demo/route/demo.json b/resource/hai-wasm-demo/gateway/hai-demo/route/demo.json new file mode 100644 index 00000000..17448cef --- /dev/null +++ b/resource/hai-wasm-demo/gateway/hai-demo/route/demo.json @@ -0,0 +1,33 @@ +{ + "route_name": "demo", + "rules": [ + { + "matches": [ + { + "path": { + "kind": "Prefix", + "value": "/" + } + } + ], + "plugins": [ + { + "code": "wasm", + "kind": "named", + "name": "hai-mix" + } + ], + "backends": [ + { + "host": { + "kind": "Host", + "host": "127.0.0.1" + }, + "port": 18099, + "weight": 1 + } + ] + } + ], + "priority": 0 +} diff --git a/resource/hai-wasm-demo/mock_backends.py b/resource/hai-wasm-demo/mock_backends.py new file mode 100644 index 00000000..3a36bc26 --- /dev/null +++ b/resource/hai-wasm-demo/mock_backends.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +"""一个进程内启动三个 mock HTTP 服务,配合 hai-process-mix.wasm 联调: + +- 18091 ac-service:API Key 鉴权(返回 ApiKeyRecord JSON) +- 18092 asset-service:资产查询(返回 AssetRecord JSON) +- 18099 upstream-echo:扮演 hai-gw-server / 任意业务上游,回声请求头与方法 + +每个服务都在独立线程内跑标准库 http.server.HTTPServer,无第三方依赖。 +启动方式:python3 mock_backends.py +""" + +import json +import threading +from datetime import datetime, timedelta, timezone +from http.server import BaseHTTPRequestHandler, HTTPServer + + +# 简单的 API Key → ApiKeyRecord 字典(按 demo 需要可继续扩) +API_KEYS = { + "demo-key": { + "app_id": "demo-app", + # 包含 demo-asset 在内,hai 才会放行 + "asset_ids": ["demo-asset"], + "allow_ips": [], + "deny_ips": [], + "allow_mac_addrs": [], + "deny_mac_addrs": [], + # ISO 8601 UTC,预留够久 + "expired_at": (datetime.now(tz=timezone.utc) + timedelta(days=3650)).strftime("%Y-%m-%dT%H:%M:%SZ"), + } +} + +ASSETS = { + "demo-asset": { + "asset_id": "demo-asset", + "asset_type": "tool", + "asset_status": "published", + # 让 hai 走"分支 A 转发":写 Hai-Upstream-URL 等头,路由到 backend + "runtime_endpoint": "http://upstream-echo.demo/echo", + "runtime_endpoint_method": ["POST"], + "asset_content": None, + "asset_url": None, + "max_concurrent": 16, + "timeout_sec": 30, + "qps_limit": 100, + "asset_secret_params": [], + "asset_secret_values": {}, + "allowed_output_targets": [], + } +} + + +class AcHandler(BaseHTTPRequestHandler): + # 静默日志(避免刷屏,可改成 print 来调试) + def log_message(self, format, *args): + return + + def do_GET(self): + if self.path != "/ai-agent/internal/v1/ac/auth": + self.send_response(404) + self.end_headers() + return + # hai 把 API Key 放到 hai-api-key 请求头 + api_key = self.headers.get("hai-api-key") or self.headers.get("Hai-Api-Key") + rec = API_KEYS.get((api_key or "").strip()) + if not rec: + self.send_response(401) + self.send_header("Content-Type", "application/json") + body = json.dumps({"code": "invalid_api_key", "message": "unknown key"}).encode() + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + return + body = json.dumps(rec).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + +class AssetHandler(BaseHTTPRequestHandler): + def log_message(self, format, *args): + return + + def do_GET(self): + # 路径形如 /ai-agent/internal/v1/am/assets/ + prefix = "/ai-agent/internal/v1/am/assets/" + if not self.path.startswith(prefix): + self.send_response(404) + self.end_headers() + return + asset_id = self.path[len(prefix):].split("?", 1)[0].split("/", 1)[0] + rec = ASSETS.get(asset_id) + if not rec: + self.send_response(404) + self.end_headers() + return + body = json.dumps(rec).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + +class EchoHandler(BaseHTTPRequestHandler): + def log_message(self, format, *args): + return + + def _echo(self): + # 回声:把请求头(重点是 Hai-* / x-*)放在响应 JSON 内,便于断言注入 + seen_headers = {k.lower(): v for k, v in self.headers.items()} + payload = { + "method": self.command, + "path": self.path, + "headers": seen_headers, + } + body = json.dumps(payload).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("X-Upstream-Echo", "ok") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def do_GET(self): + self._echo() + + def do_POST(self): + # 排掉 body(不影响 echo) + _ = self.rfile.read(int(self.headers.get("Content-Length") or 0) or 0) + self._echo() + + +def serve(port, handler): + httpd = HTTPServer(("127.0.0.1", port), handler) + print(f"[mock] listen 127.0.0.1:{port} ({handler.__name__})") + httpd.serve_forever() + + +def main(): + threads = [ + threading.Thread(target=serve, args=(18091, AcHandler), daemon=True), + threading.Thread(target=serve, args=(18092, AssetHandler), daemon=True), + threading.Thread(target=serve, args=(18099, EchoHandler), daemon=True), + ] + for t in threads: + t.start() + print("[mock] all three services up; Ctrl-C to stop") + for t in threads: + t.join() + + +if __name__ == "__main__": + main() diff --git a/resource/hai-wasm-demo/plugin/wasm.hai-mix.json b/resource/hai-wasm-demo/plugin/wasm.hai-mix.json new file mode 100644 index 00000000..e898942d --- /dev/null +++ b/resource/hai-wasm-demo/plugin/wasm.hai-mix.json @@ -0,0 +1,18 @@ +{ + "url": "/Users/sh.zhang/Workspace/huayun/jiyan/gateway/spacegate/resource/wasm/hai_process_mix.wasm", + "validate_on_create": false, + "fail_strategy": "fail_close", + "plugin_config": { + "ac_service_host": "ac-service.static", + "ac_service_port": 18091, + "asset_service_host": "asset-service.static", + "asset_service_port": 18092, + "ac_auth_path": "/ai-agent/internal/v1/ac/auth", + "asset_lookup_path": "/ai-agent/internal/v1/am/assets/{asset_id}", + "model_error_keywords": ["error", "ERR_", "FAILED"] + }, + "clusters": { + "outbound|18091||ac-service.static": "http://127.0.0.1:18091", + "outbound|18092||asset-service.static": "http://127.0.0.1:18092" + } +} diff --git a/resource/kube-manifests/higress-wasmplugin-crd.yaml b/resource/kube-manifests/higress-wasmplugin-crd.yaml new file mode 100644 index 00000000..acd09c3d --- /dev/null +++ b/resource/kube-manifests/higress-wasmplugin-crd.yaml @@ -0,0 +1,81 @@ +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + name: wasmplugins.extensions.higress.io +spec: + group: extensions.higress.io + scope: Namespaced + names: + plural: wasmplugins + singular: wasmplugin + kind: WasmPlugin + listKind: WasmPluginList + versions: + - name: v1alpha1 + served: true + storage: true + subresources: + status: {} + schema: + openAPIV3Schema: + type: object + properties: + spec: + type: object + required: + - url + properties: + url: + type: string + pluginName: + type: string + sha256: + type: string + phase: + type: string + priority: + type: integer + format: int32 + imagePullPolicy: + type: string + imagePullSecret: + type: string + defaultConfigDisable: + type: boolean + failStrategy: + type: string + defaultConfig: + x-kubernetes-preserve-unknown-fields: true + matchRules: + type: array + items: + type: object + properties: + ingress: + type: array + items: + type: string + domain: + type: array + items: + type: string + service: + type: array + items: + type: string + configDisable: + type: boolean + config: + x-kubernetes-preserve-unknown-fields: true + status: + type: object + properties: + observedGeneration: + type: integer + format: int64 + phase: + type: string + digest: + type: string + message: + type: string diff --git a/resource/kube-manifests/spacegate-admin-server.yaml b/resource/kube-manifests/spacegate-admin-server.yaml index 3666ae20..c106c180 100644 --- a/resource/kube-manifests/spacegate-admin-server.yaml +++ b/resource/kube-manifests/spacegate-admin-server.yaml @@ -83,6 +83,26 @@ rules: - list - watch - delete + - apiGroups: + - extensions.higress.io + resources: + - wasmplugins + verbs: + - get + - create + - update + - patch + - list + - watch + - delete + - apiGroups: + - extensions.higress.io + resources: + - wasmplugins/status + verbs: + - get + - update + - patch --- kind: ClusterRoleBinding apiVersion: rbac.authorization.k8s.io/v1 @@ -122,4 +142,4 @@ spec: hostPort: 9080 env: - name: CONFIG - value: k8s:spacegate \ No newline at end of file + value: k8s:spacegate diff --git a/resource/kube-manifests/spacegate-gateway.yaml b/resource/kube-manifests/spacegate-gateway.yaml index f0bab316..cac7e4c5 100644 --- a/resource/kube-manifests/spacegate-gateway.yaml +++ b/resource/kube-manifests/spacegate-gateway.yaml @@ -126,6 +126,21 @@ rules: - get - list - watch + - apiGroups: + - extensions.higress.io + resources: + - wasmplugins + verbs: + - get + - list + - watch + - apiGroups: + - extensions.higress.io + resources: + - wasmplugins/status + verbs: + - get + - update --- kind: RoleBinding apiVersion: rbac.authorization.k8s.io/v1 diff --git a/resource/kube-manifests/wasmplugin-hello-example.yaml b/resource/kube-manifests/wasmplugin-hello-example.yaml new file mode 100644 index 00000000..5e3ffeb9 --- /dev/null +++ b/resource/kube-manifests/wasmplugin-hello-example.yaml @@ -0,0 +1,19 @@ +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: hello-world + namespace: spacegate +spec: + url: https://example.com/plugins/spacegate_wasm_hello.wasm + sha256: sha256:6b9dacbcbf5a2d9de9795737aeecd434dc6b261476486803419e9d62084e651c + pluginName: hello-world + phase: AUTHN + priority: 100 + failStrategy: FAIL_CLOSE + defaultConfig: + message: hello world + matchRules: + - domain: + - api.example.com + config: + message: hello api diff --git a/resource/wasm-hello-demo/config.json b/resource/wasm-hello-demo/config.json new file mode 100644 index 00000000..20d4e030 --- /dev/null +++ b/resource/wasm-hello-demo/config.json @@ -0,0 +1,5 @@ +{ + "gateways": {}, + "plugins": {}, + "api_port": 19876 +} diff --git a/resource/wasm-hello-demo/gateway/wasm-hello/config.json b/resource/wasm-hello-demo/gateway/wasm-hello/config.json new file mode 100644 index 00000000..a751fe77 --- /dev/null +++ b/resource/wasm-hello-demo/gateway/wasm-hello/config.json @@ -0,0 +1,23 @@ +{ + "gateway": { + "name": "wasm-hello", + "parameters": {}, + "listeners": [ + { + "name": "http", + "ip": "127.0.0.1", + "port": 18082, + "protocol": { + "type": "http" + } + } + ], + "plugins": [ + { + "code": "wasm", + "kind": "named", + "name": "hello-world" + } + ] + } +} diff --git a/resource/wasm-hello-demo/gateway/wasm-hello/route/hello.json b/resource/wasm-hello-demo/gateway/wasm-hello/route/hello.json new file mode 100644 index 00000000..78e75bb4 --- /dev/null +++ b/resource/wasm-hello-demo/gateway/wasm-hello/route/hello.json @@ -0,0 +1,26 @@ +{ + "route_name": "hello", + "rules": [ + { + "matches": [ + { + "path": { + "kind": "Prefix", + "value": "/" + } + } + ], + "backends": [ + { + "host": { + "kind": "Host", + "host": "127.0.0.1" + }, + "port": 18099, + "weight": 1 + } + ] + } + ], + "priority": 0 +} diff --git a/resource/wasm-hello-demo/plugin/wasm.hello-world.json b/resource/wasm-hello-demo/plugin/wasm.hello-world.json new file mode 100644 index 00000000..392941ee --- /dev/null +++ b/resource/wasm-hello-demo/plugin/wasm.hello-world.json @@ -0,0 +1,13 @@ +{ + "url": "resource/wasm/spacegate_wasm_hello.wasm", + "sha256": "sha256:6b9dacbcbf5a2d9de9795737aeecd434dc6b261476486803419e9d62084e651c", + "module_cache_key": "spacegate-wasm-hello:v1", + "use_cache": true, + "validate_on_create": false, + "fail_strategy": "fail_close", + "plugin_name": "hello-world", + "plugin_root_id": "hello-world-root", + "plugin_vm_id": "hello-world-vm", + "plugin_config": {}, + "clusters": {} +} diff --git a/resource/wasm/hai_process_mix.wasm b/resource/wasm/hai_process_mix.wasm new file mode 100755 index 00000000..ddf0fdff Binary files /dev/null and b/resource/wasm/hai_process_mix.wasm differ diff --git a/resource/wasm/spacegate_wasm_hello.wasm b/resource/wasm/spacegate_wasm_hello.wasm new file mode 100755 index 00000000..0841d0ea Binary files /dev/null and b/resource/wasm/spacegate_wasm_hello.wasm differ