diff --git a/Cargo.lock b/Cargo.lock index 6bde39e..8a4e740 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -521,9 +521,9 @@ checksum = "1e4b40c7323adcfc0a41c4b88143ed58346ff65a288fc144329c5c45e05d70c6" [[package]] name = "bitflags" -version = "2.11.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" [[package]] name = "bitstream-io" @@ -661,6 +661,31 @@ dependencies = [ "time", ] +[[package]] +name = "bon" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" +dependencies = [ + "darling 0.23.0", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.117", +] + [[package]] name = "borsh" version = "1.6.1" @@ -785,44 +810,11 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" -[[package]] -name = "camino" -version = "1.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e629a66d692cb9ff1a1c664e41771b3dcaf961985a9774c0eb0bd1b51cf60a48" -dependencies = [ - "serde_core", -] - -[[package]] -name = "cargo-platform" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd0061da739915fae12ea00e16397555ed4371a6bb285431aab930f61b0aa4ba" -dependencies = [ - "serde", - "serde_core", -] - -[[package]] -name = "cargo_metadata" -version = "0.23.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef987d17b0a113becdd19d3d0022d04d7ef41f9efe4f3fb63ac44ba61df3ade9" -dependencies = [ - "camino", - "cargo-platform", - "semver", - "serde", - "serde_json", - "thiserror 2.0.18", -] - [[package]] name = "cc" -version = "1.2.62" +version = "1.2.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +checksum = "556e016178bb5662a08681bbe0f00f8e17631781a4dfc8c45e466e4b185ec27f" dependencies = [ "find-msvc-tools", "jobserver", @@ -909,9 +901,9 @@ dependencies = [ [[package]] name = "chrono" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" dependencies = [ "iana-time-zone", "js-sys", @@ -982,9 +974,9 @@ dependencies = [ [[package]] name = "cmov" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f88a43d011fc4a6876cb7344703e297c71dda42494fee094d5f7c76bf13f746" +checksum = "0c9ea0ac24bc397ab3c98583a3c9ba74fa56b09a4449bbe172b9b1ddb016027a" [[package]] name = "color_quant" @@ -1156,6 +1148,16 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-skiplist" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df29de440c58ca2cc6e587ec3d22347551a32435fbde9d2bff64e78a9ffa151b" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -1938,6 +1940,7 @@ dependencies = [ "bytes", "chrono", "clap", + "crossbeam-skiplist", "foldhash 0.2.0", "futures", "futures-lite", @@ -2304,9 +2307,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.1" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" +checksum = "6970f50e31d6fc17d3fa27329444bfa74e196cf62e95052a3f6fee181dba6425" dependencies = [ "bytes", "itoa", @@ -2358,9 +2361,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb92f162bf56536459fc83c79b974bb12837acfed43d6bc370a7916d0ae15ecc" +checksum = "55281c53a1894c864990125767da440a4e630446785086f52523b20033b74498" dependencies = [ "atomic-waker", "bytes", @@ -2637,9 +2640,9 @@ dependencies = [ [[package]] name = "imgref" -version = "1.12.1" +version = "1.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40fac9d56ed6437b198fddba683305e8e2d651aa42647f00f5ae542e7f5c94a2" +checksum = "89194689a993ab15268672e99e7b0e19da2da3268ac682e8f02d29d4d1434cd7" [[package]] name = "indexmap" @@ -2666,9 +2669,9 @@ dependencies = [ [[package]] name = "inotify" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5b3eaf1a28b758ac0faa5a4254e8ab2705605496f1b1f3fbbc3988ad73d199" +checksum = "533e68a5842e734946fe159fb03fc9bbbb254f590dd0d8ad321ae5ff7beca2c1" dependencies = [ "bitflags", "inotify-sys", @@ -2830,13 +2833,12 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.99" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" +checksum = "f2025f20d7a4fa7785846e7b63d10a76d3f1cee98ee5cb79ea59703f95e42162" dependencies = [ "cfg-if", "futures-util", - "once_cell", "wasm-bindgen", ] @@ -2851,9 +2853,9 @@ dependencies = [ [[package]] name = "kqueue" -version = "1.1.1" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac30106d7dce88daf4a3fcb4879ea939476d5074a9b7ddd0fb97fa4bed5596a" +checksum = "273c0752728918e0ac4976f2b275b6fefb9ecd400585dec929419f3844cd87b5" dependencies = [ "kqueue-sys", "libc", @@ -2898,9 +2900,9 @@ checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" [[package]] name = "libfuzzer-sys" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f12a681b7dd8ce12bff52488013ba614b869148d54dd79836ab85aafdd53f08d" +checksum = "a9fd2f41a1cba099f79a0b6b6c35656cf7c03351a7bae8ff0f28f25270f929d2" dependencies = [ "arbitrary", "cc", @@ -2964,9 +2966,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.30" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" +checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a" [[package]] name = "loop9" @@ -3086,9 +3088,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +checksum = "02bd0af71c67b473010cbbc60715ee815645a4dc942899111f494b4b737d6fda" dependencies = [ "libc", "log", @@ -3856,9 +3858,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +checksum = "528ac67416ff8646872a3c02cad9cc4ee5dc9f9540c9b10771855c95cb2e5ae1" dependencies = [ "bytes", "prost-derive", @@ -3866,9 +3868,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +checksum = "b570b25f7617e43d59005d0990ccb79e950a423952cea19671b7a876da390adf" dependencies = [ "anyhow", "itertools", @@ -3879,9 +3881,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +checksum = "f94967dc7688f3054c7fac87473ffae4cc4c3904800e2d9f5b857246d8963b0a" dependencies = [ "prost", ] @@ -4225,9 +4227,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.3" +version = "1.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +checksum = "f1292b7759ae1cb9ec195452d1390a074f0cd8541ab7a5a8c31cd6db45d4a6ba" dependencies = [ "aho-corasick", "memchr", @@ -4248,9 +4250,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +checksum = "d6f6ff9a378485b298a5286656da665ba74413d36db0979633275d2e708145d4" [[package]] name = "rend" @@ -4483,9 +4485,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +checksum = "dab5152771c58876a2146916e53e35057e1a4dfa2b9df0f0305b07f611fdea4d" dependencies = [ "openssl-probe", "rustls-pki-types", @@ -4790,9 +4792,9 @@ dependencies = [ [[package]] name = "scc" -version = "3.7.1" +version = "3.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bcd12b6caff5213cc3c03123cde8c3db5e413008a63b0c0ba35e6275825ea92" +checksum = "40ba4937978d960f5b3a970432ad59095f048fe5398812501afb416d4a1aef26" dependencies = [ "saa", "sdd", @@ -4913,10 +4915,6 @@ name = "semver" version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" -dependencies = [ - "serde", - "serde_core", -] [[package]] name = "serde" @@ -5017,9 +5015,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.20.0" +version = "3.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e72c1c2cb7b223fafb600a619537a871c2818583d619401b785e7c0b746ccde2" +checksum = "76a5c54c7310e7b8b9577c286d7e399ddd876c3e12b3ed917a8aabc4b96e9e8c" dependencies = [ "base64", "bs58", @@ -5037,9 +5035,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.20.0" +version = "3.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b90c488738ecb4fb0262f41f43bc40efc5868d9fb744319ddf5f5317f417bfac" +checksum = "84d57bc0c8b9a17920c178daa6bb924850d54a9c97ab45194bb8c17ad66bb660" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -5102,9 +5100,9 @@ dependencies = [ [[package]] name = "shlex" -version = "1.3.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" [[package]] name = "signal-hook-registry" @@ -5177,9 +5175,9 @@ checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "socket2" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", "windows-sys 0.61.2", @@ -5892,9 +5890,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.20.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" [[package]] name = "ulid" @@ -6033,9 +6031,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.23.1" +version = "1.23.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -6071,14 +6069,12 @@ checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "vergen" -version = "9.1.0" +version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b849a1f6d8639e8de261e81ee0fc881e3e3620db1af9f2e0da015d4382ceaf75" +checksum = "7bdf18a54cf91b4d98a8e8b67f6321606539fbcdcac02536286ad1de37b53fd2" dependencies = [ "anyhow", - "cargo_metadata", - "derive_builder", - "regex", + "bon", "rustc_version", "rustversion", "time", @@ -6087,12 +6083,12 @@ dependencies = [ [[package]] name = "vergen-gitcl" -version = "9.1.0" +version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ff3b5300a085d6bcd8fc96a507f706a28ae3814693236c9b409db71a1d15b9" +checksum = "4961429ed12888cb3c6dd20f7dc9508c821091a3ba5fec0156ed5a654c1c4572" dependencies = [ "anyhow", - "derive_builder", + "bon", "rustversion", "time", "vergen", @@ -6101,12 +6097,12 @@ dependencies = [ [[package]] name = "vergen-lib" -version = "9.1.0" +version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b34a29ba7e9c59e62f229ae1932fb1b8fb8a6fdcc99215a641913f5f5a59a569" +checksum = "910e8471e27130bbc019e9bfa6bda16dfc4c6dd7c5d0793da70a9256caeae984" dependencies = [ "anyhow", - "derive_builder", + "bon", "rustversion", ] @@ -6185,9 +6181,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" +checksum = "a254a4b10c19a76f09a27640e7ffbf9bc30bf67e16a3bf28aaefa4920fe81563" dependencies = [ "cfg-if", "once_cell", @@ -6198,9 +6194,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.72" +version = "0.4.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" +checksum = "54568702fabf5d4849ce2b90fadfa64168a097eaf4b351ce9df8b687a0086aaf" dependencies = [ "js-sys", "wasm-bindgen", @@ -6208,9 +6204,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" +checksum = "24a40fc75b0ec6f3746ceb10d36f53a93dcd68a93b11b6445983945d79eba0dc" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6218,9 +6214,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" +checksum = "908f34bd9b9ce3d4caf07b72dfab63d61504d156856c6bd3cd87fa350cf3985b" dependencies = [ "bumpalo", "proc-macro2", @@ -6231,9 +6227,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.122" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" +checksum = "7acbf7616c27b194bbb550bf77ed0c2c3e5b7fd1260a93082b95fb7f47959b92" dependencies = [ "unicode-ident", ] @@ -6287,9 +6283,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.99" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" +checksum = "6e0871acf327f283dc6da28a1696cdc64fb355ba9f935d052021fa77f35cce69" dependencies = [ "js-sys", "wasm-bindgen", @@ -6860,9 +6856,9 @@ dependencies = [ [[package]] name = "yoke" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +checksum = "709fe23a0424b6a435d82152b1bd3fdfb0833487d5fa90d05d42762a9891fef5" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -6883,18 +6879,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.49" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bce33a6288fa3f072a8c2c7d0f2fdbb90e28298f0135c1f99b96c3db2efcc60b" +checksum = "3b065d4f0e55f82fae73202e189638116a87c55ab6b8e6c2721e13dd9d854ad1" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.49" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fd425244944f4ab65ccff928e7323354c5a018c75838362fdce749dfad2ee1e" +checksum = "0b631b19d36a892ab55420c92dbc83ccd79274f25be714855d3074aa71cab639" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 43cc378..40d4975 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ spz-lib = { git = "https://github.com/404-Repo/spz-rs.git", rev = "227c6cb6cdc25 "async", ] } multer = "3.1.0" -tokio = { version = "1.52.2", features = [ +tokio = { version = "1.52.3", features = [ "rt-multi-thread", "macros", "sync", @@ -43,14 +43,14 @@ async-trait = "0.1.89" quinn = { version = "0.11.9", features = ["rustls", "runtime-tokio", "log"] } h3 = { version = "0.0.8", features = ["tracing"] } h3-quinn = { version = "0.0.10" } -http = "1.4.1" +http = "1.4.2" backon = "1.6.0" openraft = { version = "0.9.24", features = ["storage-v2", "serde"] } tokio-postgres-rustls = "0.14.0" rustls = { version = "0.23.40" } rustls-platform-verifier = "0.7.0" toml = "1.1.2" -uuid = { version = "1.23.0", features = ["v4", "serde"] } +uuid = { version = "1.23.3", features = ["v4", "serde"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.150" rmp-serde = "1.3.1" @@ -59,7 +59,7 @@ bytes = "1.11.1" rcgen = { version = "0.14.8", features = ["pem"] } foldhash = "0.2.0" portable-atomic = "1.13.1" -scc = "3.7.1" +scc = "3.7.2" sdd = "4.8.6" schnorrkel = "0.11.5" base64-simd = "0.8" @@ -69,7 +69,7 @@ prometheus = "0.14.0" blake3 = "1.8.5" moka = { version = "0.12.15", features = ["sync"] } mimalloc = { version = "0.1.52", default-features = false } -regex = "1.12.3" +regex = "1.12.4" image = "0.25.10" chrono = { version = "0.4", features = ["serde", "clock"] } itoa = "1.0.18" @@ -77,9 +77,10 @@ zstd = "0.13" tempfile = "3.27.0" notify = "8.2.0" rand = "0.10.1" +crossbeam-skiplist = "0.1.3" [dev-dependencies] -hyper = "1.9.0" +hyper = "1.10.1" http-body-util = "0.1.3" futures-lite = { version = "2.6.1", default-features = false, features = [ "std", @@ -107,7 +108,7 @@ required-features = ["test-support"] [build-dependencies] anyhow = "1.0.100" -vergen-gitcl = { version = "9.1.0", features = ["build", "cargo", "rustc"] } +vergen-gitcl = { version = "10.0.0", features = ["build", "cargo", "rustc"] } [profile.release] codegen-units = 1 diff --git a/dev-env/init-scripts/init-schema.sql b/dev-env/init-scripts/init-schema.sql index cc85b8e..807502d 100644 --- a/dev-env/init-scripts/init-schema.sql +++ b/dev-env/init-scripts/init-schema.sql @@ -249,6 +249,7 @@ CREATE TABLE IF NOT EXISTS companies ( task_limit_concurrent INTEGER NOT NULL DEFAULT 1 CHECK (task_limit_concurrent >= 0), task_limit_daily INTEGER NOT NULL CHECK (task_limit_daily >= 0), ownership_state TEXT NOT NULL DEFAULT 'unassigned' CHECK (ownership_state IN ('unassigned', 'owned')), + worker_tags TEXT[] NOT NULL DEFAULT '{}'::TEXT[], created_by_user_id BIGINT REFERENCES users(id) ON DELETE SET NULL, created_at BIGINT NOT NULL, updated_at BIGINT NOT NULL diff --git a/src/api/request.rs b/src/api/request.rs index 5af0ee2..ee78c47 100644 --- a/src/api/request.rs +++ b/src/api/request.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use foldhash::{HashSet, fast::RandomState}; use serde::Deserialize; use serde_json::Value; use std::sync::Arc; @@ -8,6 +9,36 @@ use uuid::Uuid; pub use super::gateway_info::{GatewayInfoExt, GatewayInfoExtRef}; use crate::crypto::hotkey::Hotkey; +pub const MAX_WORKER_TAG_LENGTH: usize = 32; + +pub fn normalize_worker_tags(tags: &[String]) -> Result, String> { + let mut normalized = Vec::with_capacity(tags.len()); + let mut seen = HashSet::with_capacity_and_hasher(tags.len(), RandomState::default()); + for tag in tags { + let tag = tag.trim().to_ascii_lowercase(); + if tag.is_empty() { + return Err("worker_tags cannot contain empty tags".to_string()); + } + if tag.len() > MAX_WORKER_TAG_LENGTH { + return Err(format!( + "worker_tags entries cannot exceed {MAX_WORKER_TAG_LENGTH} characters" + )); + } + if !tag.bytes().all(|byte| { + byte.is_ascii_lowercase() || byte.is_ascii_digit() || byte == b'_' || byte == b'-' + }) { + return Err( + "worker_tags entries may only contain lowercase letters, numbers, underscores, or dashes" + .to_string(), + ); + } + if seen.insert(tag.clone()) { + normalized.push(tag); + } + } + Ok(normalized) +} + #[derive(Debug, Clone, Deserialize)] pub struct AddTaskRequest { pub seed: Option, @@ -57,6 +88,8 @@ pub struct GetTasksRequest { pub timestamp: String, pub requested_task_count: usize, pub model: ModelFilter, + #[serde(default)] + pub worker_tags: Vec, } #[derive(Debug, Deserialize)] diff --git a/src/build.rs b/src/build.rs index e669cee..f062d5a 100644 --- a/src/build.rs +++ b/src/build.rs @@ -1,12 +1,17 @@ use anyhow::Result; -use vergen_gitcl::{BuildBuilder, CargoBuilder, Emitter, GitclBuilder, RustcBuilder}; +use vergen_gitcl::{Build, Cargo, Emitter, Gitcl, Rustc}; fn main() -> Result<()> { + let build = Build::all_build(); + let cargo = Cargo::all_cargo(); + let gitcl = Gitcl::all_git(); + let rustc = Rustc::all_rustc(); + Emitter::default() - .add_instructions(&BuildBuilder::all_build()?)? - .add_instructions(&CargoBuilder::all_cargo()?)? - .add_instructions(&GitclBuilder::all_git()?)? - .add_instructions(&RustcBuilder::all_rustc()?)? + .add_instructions(&build)? + .add_instructions(&cargo)? + .add_instructions(&gitcl)? + .add_instructions(&rustc)? .emit() } diff --git a/src/common/queue/mod.rs b/src/common/queue/mod.rs index 3b6e753..6ca1de6 100644 --- a/src/common/queue/mod.rs +++ b/src/common/queue/mod.rs @@ -1,9 +1,10 @@ +use crossbeam_skiplist::SkipMap; use foldhash::fast::RandomState; use foldhash::{HashMap as FoldHashMap, HashSet as FoldHashSet}; use moka::sync::Cache; +use prometheus::IntGauge; use scc::HashMap; use scc::hash_map::Entry as SccEntry; -use sdd::Queue; use std::marker::PhantomData; use std::sync::Arc; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; @@ -19,17 +20,84 @@ const ACTIVE_ID_START_CAPACITY: usize = 1024; const DEFAULT_DUP_COUNT: usize = 1; const DEFAULT_TTL_SECS: u64 = 300; const DEFAULT_CLEANUP_INTERVAL_SECS: u64 = 1; -const EXHAUSTED_CACHE_MAX_CAPACITY: u64 = 16_384; +// Soft aggregate targets for per-model routing caches. The budget is divided +// across configured model buckets, with a minimum of one entry per bucket, so +// pathological model counts can exceed this by requiring one cache slot each. +const EXHAUSTED_CACHE_SOFT_TOTAL_CAPACITY: u64 = 16_384; +const SCAN_CURSOR_CACHE_SOFT_TOTAL_CAPACITY: u64 = 16_384; const MAX_CLEANUP_SCAN_PER_TICK: usize = 64; +const DEFAULT_RESERVATION_SCAN_CAP: usize = 512; const ENTRY_COUNT_BITS: u64 = 30; const ENTRY_COUNT_MASK: u64 = (1u64 << ENTRY_COUNT_BITS) - 1; const ENTRY_LEASED_SHIFT: u64 = ENTRY_COUNT_BITS; -const ENTRY_QUEUED_FLAG: u64 = 1u64 << 60; const ENTRY_RETIRED_FLAG: u64 = 1u64 << 61; const MAX_DUP_SLOTS: usize = ENTRY_COUNT_MASK as usize; type SentKey = (Uuid, usize); +#[derive(Clone, Eq, Hash, PartialEq)] +struct WorkerRoutingKey { + tags: Arc<[String]>, +} + +#[derive(Clone, Eq, Hash, PartialEq)] +struct ExhaustedKey { + hotkey: Hotkey, + routing: WorkerRoutingKey, +} + +#[derive(Clone, Default)] +pub struct TaskRouting { + required_worker_tags: Arc<[String]>, +} + +#[derive(Copy, Clone, Default)] +pub struct WorkerRouting<'a> { + tags: &'a [String], +} + +impl TaskRouting { + pub fn with_required_worker_tags(required_worker_tags: Vec) -> Self { + Self { + required_worker_tags: Arc::from(required_worker_tags), + } + } + + fn matches_worker(&self, worker: WorkerRouting<'_>) -> bool { + if self.required_worker_tags.is_empty() { + worker.tags.is_empty() + } else { + self.required_worker_tags + .iter() + .any(|required| worker.tags.iter().any(|tag| tag == required)) + } + } +} + +impl<'a> WorkerRouting<'a> { + pub fn from_tags(tags: &'a [String]) -> Self { + Self { tags } + } + + fn cache_key(self) -> WorkerRoutingKey { + let mut tags = self.tags.to_vec(); + tags.sort(); + tags.dedup(); + WorkerRoutingKey { + tags: Arc::from(tags), + } + } +} + +impl ExhaustedKey { + fn new(hotkey: &Hotkey, routing: WorkerRoutingKey) -> Self { + Self { + hotkey: hotkey.clone(), + routing, + } + } +} + #[derive(Copy, Clone)] struct ActiveTaskState { count: usize, @@ -38,6 +106,7 @@ struct ActiveTaskState { struct QueueEntry { item: Arc, + routing: TaskRouting, state: AtomicU64, timestamp: Instant, generation: usize, @@ -45,30 +114,25 @@ struct QueueEntry { #[derive(Copy, Clone)] struct LeaseAcquireOutcome { - available_after: usize, final_candidate: bool, } #[derive(Copy, Clone, Default)] struct LeaseReleaseOutcome { generation_bump: bool, - requeue: bool, retire_now: bool, } impl QueueEntry { fn initial_state(available: usize) -> u64 { debug_assert!(available <= MAX_DUP_SLOTS); - Self::pack_state(available, 0, false, false) + Self::pack_state(available, 0, false) } - fn pack_state(available: usize, leased: usize, queued: bool, retired: bool) -> u64 { + fn pack_state(available: usize, leased: usize, retired: bool) -> u64 { debug_assert!(available <= MAX_DUP_SLOTS); debug_assert!(leased <= MAX_DUP_SLOTS); let mut state = (available as u64) | ((leased as u64) << ENTRY_LEASED_SHIFT); - if queued { - state |= ENTRY_QUEUED_FLAG; - } if retired { state |= ENTRY_RETIRED_FLAG; } @@ -83,10 +147,6 @@ impl QueueEntry { ((state >> ENTRY_LEASED_SHIFT) & ENTRY_COUNT_MASK) as usize } - fn is_queued(state: u64) -> bool { - state & ENTRY_QUEUED_FLAG != 0 - } - fn is_retired(state: u64) -> bool { state & ENTRY_RETIRED_FLAG != 0 } @@ -104,45 +164,6 @@ impl QueueEntry { self.timestamp.elapsed() > ttl } - fn clear_queued_flag(&self) { - let mut state = self.load_state(); - loop { - if !Self::is_queued(state) { - return; - } - match self.state.compare_exchange_weak( - state, - state & !ENTRY_QUEUED_FLAG, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => return, - Err(actual) => state = actual, - } - } - } - - fn try_mark_queued(&self) -> bool { - let mut state = self.load_state(); - loop { - if Self::is_retired(state) - || Self::is_queued(state) - || Self::available_slots(state) == 0 - { - return false; - } - match self.state.compare_exchange_weak( - state, - state | ENTRY_QUEUED_FLAG, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => return true, - Err(actual) => state = actual, - } - } - } - fn try_acquire_lease(&self) -> Option { let mut state = self.load_state(); loop { @@ -156,14 +177,13 @@ impl QueueEntry { return None; } - let next = Self::pack_state(available - 1, leased + 1, Self::is_queued(state), false); + let next = Self::pack_state(available - 1, leased + 1, false); match self .state .compare_exchange_weak(state, next, Ordering::AcqRel, Ordering::Acquire) { Ok(_) => { return Some(LeaseAcquireOutcome { - available_after: available - 1, final_candidate: available == 1 && leased == 0, }); } @@ -172,16 +192,16 @@ impl QueueEntry { } } - fn retire_after_pop(&self) -> bool { + fn retire_if_idle(&self) -> bool { let mut state = self.load_state(); loop { + let leased = Self::leased_slots(state); if Self::is_retired(state) { - return false; + return leased == 0; } - let leased = Self::leased_slots(state); let retire_now = leased == 0; - let next = Self::pack_state(0, leased, false, retire_now); + let next = Self::pack_state(0, leased, true); match self .state .compare_exchange_weak(state, next, Ordering::AcqRel, Ordering::Acquire) @@ -192,20 +212,46 @@ impl QueueEntry { } } + /// Retire the entry only if it is genuinely drained: no available capacity + /// and nothing currently leased, observed atomically. Unlike + /// [`retire_if_idle`], this never discards an available slot, so it is safe + /// to call from the scan path where a concurrent rollback/requeue may have + /// restored capacity after an earlier `available == 0` observation. Returns + /// `true` only when this call performed the drained -> retired transition. + fn retire_if_drained(&self) -> bool { + let mut state = self.load_state(); + loop { + if Self::is_retired(state) { + return false; + } + if Self::available_slots(state) != 0 || Self::leased_slots(state) != 0 { + return false; + } + + let next = Self::pack_state(0, 0, true); + match self + .state + .compare_exchange_weak(state, next, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => return true, + Err(actual) => state = actual, + } + } + } + fn finish_commit(&self) -> bool { let mut state = self.load_state(); loop { let available = Self::available_slots(state); let leased = Self::leased_slots(state); - let queued = Self::is_queued(state); let retired = Self::is_retired(state); if leased == 0 { return false; } let next_leased = leased - 1; - let retire_now = !retired && available == 0 && next_leased == 0 && !queued; - let next = Self::pack_state(available, next_leased, queued, retired || retire_now); + let retire_now = next_leased == 0 && (retired || available == 0); + let next = Self::pack_state(available, next_leased, retired || retire_now); match self .state .compare_exchange_weak(state, next, Ordering::AcqRel, Ordering::Acquire) @@ -221,19 +267,18 @@ impl QueueEntry { loop { let available = Self::available_slots(state); let leased = Self::leased_slots(state); - let queued = Self::is_queued(state); let retired = Self::is_retired(state); if leased == 0 { return LeaseReleaseOutcome::default(); } let next_leased = leased - 1; - let (next_available, generation_bump, requeue, retire_now) = if expired { - (0, false, false, !retired && next_leased == 0 && !queued) + let (next_available, generation_bump, retire_now) = if expired || retired { + (0, false, next_leased == 0) } else { - (available + 1, available == 0, !queued, false) + (available + 1, available == 0, false) }; - let next = Self::pack_state(next_available, next_leased, queued, retired || retire_now); + let next = Self::pack_state(next_available, next_leased, retired || expired); match self .state .compare_exchange_weak(state, next, Ordering::AcqRel, Ordering::Acquire) @@ -241,7 +286,6 @@ impl QueueEntry { Ok(_) => { return LeaseReleaseOutcome { generation_bump, - requeue, retire_now, }; } @@ -252,117 +296,112 @@ impl QueueEntry { } struct ModelBucket { - q: Queue>, + entries: SkipMap>, + next_seq: AtomicU64, live: AtomicUsize, - in_flight: AtomicUsize, activity: AtomicUsize, generation: AtomicUsize, - exhausted_by_hotkey: Cache, + exhausted_by_worker: Cache, + scan_cursor_by_worker: Cache, } +#[cfg(test)] impl Default for ModelBucket { fn default() -> Self { + Self::with_cache_capacity( + EXHAUSTED_CACHE_SOFT_TOTAL_CAPACITY, + SCAN_CURSOR_CACHE_SOFT_TOTAL_CAPACITY, + ) + } +} + +impl ModelBucket { + fn with_cache_capacity(exhausted_capacity: u64, scan_cursor_capacity: u64) -> Self { Self { - q: Queue::default(), + entries: SkipMap::new(), + next_seq: AtomicU64::new(0), live: AtomicUsize::new(0), - in_flight: AtomicUsize::new(0), activity: AtomicUsize::new(0), generation: AtomicUsize::new(0), - exhausted_by_hotkey: Cache::builder() - .max_capacity(EXHAUSTED_CACHE_MAX_CAPACITY) + exhausted_by_worker: Cache::builder() + .max_capacity(exhausted_capacity.max(1)) + .build_with_hasher(RandomState::default()), + scan_cursor_by_worker: Cache::builder() + .max_capacity(scan_cursor_capacity.max(1)) .build_with_hasher(RandomState::default()), } } } impl ModelBucket { - fn pop_visible_entry(&self) -> Option> { - self.pop_visible_entry_inner(|| {}) + #[cfg(test)] + fn snapshot_entries(&self, limit: usize) -> Vec<(u64, Arc)> { + self.entries + .iter() + .take(limit) + .map(|entry| (*entry.key(), Arc::clone(entry.value()))) + .collect() } - fn pop_visible_entry_inner(&self, after_pop: F) -> Option> { - self.in_flight.fetch_add(1, Ordering::AcqRel); - match self.q.pop() { - Some(shared) => { - let entry = (**shared).clone(); - entry.clear_queued_flag(); - after_pop(); - self.activity.fetch_add(1, Ordering::AcqRel); - Some(entry) - } - None => { - self.in_flight.fetch_sub(1, Ordering::AcqRel); - None - } - } + fn bump_generation(&self) { + self.generation.fetch_add(1, Ordering::AcqRel); } - #[cfg(test)] - fn pop_visible_entry_with_hook(&self, after_pop: F) -> Option> { - self.pop_visible_entry_inner(after_pop) + fn clear_exhausted(&self, key: &ExhaustedKey) { + self.exhausted_by_worker.invalidate(key); } - fn push_ready(&self, entry: Arc) -> bool { - if !entry.try_mark_queued() { - return false; - } - self.q.push(entry); - self.activity.fetch_add(1, Ordering::AcqRel); - true + fn scan_cursor(&self, key: &ExhaustedKey) -> u64 { + self.scan_cursor_by_worker.get(key).unwrap_or(0) } - fn finish_entry(&self, entry: Arc, requeue: bool) -> bool { - let requeued = requeue && self.push_ready(entry); - self.in_flight.fetch_sub(1, Ordering::AcqRel); - requeued + fn update_scan_cursor(&self, key: &ExhaustedKey, seq: u64) { + self.scan_cursor_by_worker.insert(key.clone(), seq); } - fn bump_generation(&self) { - self.generation.fetch_add(1, Ordering::AcqRel); - } - - fn clear_exhausted_hotkey(&self, hotkey: &Hotkey) { - self.exhausted_by_hotkey.invalidate(hotkey); + fn rewind_scan_cursor(&self, key: &ExhaustedKey, seq: u64) { + self.scan_cursor_by_worker.insert(key.clone(), seq); } fn enqueue_new(&self, entry: Arc) { - let _ = self.push_ready(entry); + let seq = self.next_seq.fetch_add(1, Ordering::Relaxed); + self.entries.insert(seq, entry); self.live.fetch_add(1, Ordering::Relaxed); // Readers use the generation acquire-load as the publication point for // the widened live/activity state. + self.activity.fetch_add(1, Ordering::AcqRel); self.bump_generation(); } - fn exhausted_generation(&self, hotkey: &Hotkey) -> Option { - self.exhausted_by_hotkey.get(hotkey) + fn remove_entry(&self, seq: u64) -> bool { + if self.entries.remove(&seq).is_none() { + return false; + } + self.live.fetch_sub(1, Ordering::Relaxed); + self.activity.fetch_add(1, Ordering::AcqRel); + self.bump_generation(); + true + } + + fn exhausted_generation(&self, key: &ExhaustedKey) -> Option { + self.exhausted_by_worker.get(key) } fn mark_exhausted_if_stable( &self, - hotkey: &Hotkey, + key: &ExhaustedKey, generation: usize, start_activity: usize, local_activity: usize, - start_in_flight: usize, ) { - if start_in_flight != 0 { - return; - } if self.generation.load(Ordering::Acquire) != generation { return; } - if self.in_flight.load(Ordering::Acquire) != 0 { - return; - } if self.activity.load(Ordering::Acquire) != start_activity.wrapping_add(local_activity) { return; } - self.exhausted_by_hotkey.insert(hotkey.clone(), generation); - } - - fn finish_live_entry(&self) { - self.live.fetch_sub(1, Ordering::Relaxed); + self.exhausted_by_worker.insert(key.clone(), generation); } } @@ -376,8 +415,10 @@ struct Inner { active_ids: HashMap, ttl: Duration, len: AtomicUsize, + queue_len_gauge: Option, next_bucket: AtomicUsize, next_generation: AtomicUsize, + reservation_scan_cap: AtomicUsize, } #[derive(Clone)] @@ -394,10 +435,12 @@ pub struct TaskQueueReservation { pub struct TaskQueueDelivery { inner: Arc, bucket: Arc, + seq: u64, entry: Option>, key: Option, hotkey: Option, - task: Task, + exhausted_key: ExhaustedKey, + task: Option, duration: Option, } @@ -407,6 +450,8 @@ pub struct TaskQueueBuilder { cleanup_interval: Duration, default_model: Option, models: Vec, + queue_len_gauge: Option, + reservation_scan_cap: usize, } impl TaskQueueBuilder { @@ -439,6 +484,16 @@ impl TaskQueueBuilder { self } + pub fn queue_len_gauge(mut self, gauge: IntGauge) -> Self { + self.queue_len_gauge = Some(gauge); + self + } + + pub fn reservation_scan_cap(mut self, cap: usize) -> Self { + self.reservation_scan_cap = cap.max(1); + self + } + pub fn build(self) -> TaskQueue { assert!( self.dup.max(1) <= MAX_DUP_SLOTS, @@ -460,9 +515,19 @@ impl TaskQueueBuilder { } } + let bucket_count = bucket_order.len().max(1) as u64; + let exhausted_cache_capacity = (EXHAUSTED_CACHE_SOFT_TOTAL_CAPACITY / bucket_count).max(1); + let scan_cursor_cache_capacity = + (SCAN_CURSOR_CACHE_SOFT_TOTAL_CAPACITY / bucket_count).max(1); let mut buckets = FoldHashMap::default(); for model in &bucket_order { - buckets.insert(model.clone(), Arc::new(ModelBucket::default())); + buckets.insert( + model.clone(), + Arc::new(ModelBucket::with_cache_capacity( + exhausted_cache_capacity, + scan_cursor_cache_capacity, + )), + ); } let inner = Arc::new(Inner { @@ -477,9 +542,12 @@ impl TaskQueueBuilder { ), ttl: self.ttl, len: AtomicUsize::new(0), + queue_len_gauge: self.queue_len_gauge, next_bucket: AtomicUsize::new(0), next_generation: AtomicUsize::new(1), + reservation_scan_cap: AtomicUsize::new(self.reservation_scan_cap.max(1)), }); + inner.observe_len(0); let cleanup_interval = self.cleanup_interval; let weak_inner = Arc::downgrade(&inner); task::spawn(async move { @@ -491,21 +559,21 @@ impl TaskQueueBuilder { .live .load(Ordering::Acquire) .min(MAX_CLEANUP_SCAN_PER_TICK); - for _ in 0..scan_limit { - let Some(entry) = bucket.pop_visible_entry() else { - break; - }; - - if entry.is_expired(inner.ttl) || !entry.has_visible_capacity() { - let retired_now = entry.retire_after_pop(); + for index_entry in bucket.entries.iter().take(scan_limit) { + let seq = *index_entry.key(); + let entry = Arc::clone(index_entry.value()); + drop(index_entry); + + let state = entry.load_state(); + let idle_without_capacity = QueueEntry::available_slots(state) == 0 + && QueueEntry::leased_slots(state) == 0; + if entry.is_expired(inner.ttl) || idle_without_capacity { + let retired_now = entry.retire_if_idle(); let task_id = *entry.item.id(); let generation = entry.generation; - bucket.finish_entry(entry, false); if retired_now { - inner.retire_entry(bucket.as_ref(), task_id, generation); + inner.retire_entry(bucket.as_ref(), seq, task_id, generation); } - } else { - bucket.finish_entry(entry, true); } } } @@ -520,13 +588,22 @@ impl TaskQueueReservation { self.inner.push_internal(task); self.committed = true; } + + pub fn push_with_routing(mut self, task: Task, routing: TaskRouting) { + self.inner.push_internal_with_routing(task, routing); + self.committed = true; + } } impl TaskQueueDelivery { pub fn task(&self) -> &Task { - &self.task + self.task + .as_ref() + .expect("delivery owns a task until commit") } + /// Returns a pre-commit duration hint for the final lease. The duration + /// returned by `commit()` is authoritative because retirement happens there. pub fn duration(&self) -> Option { self.duration } @@ -540,16 +617,20 @@ impl TaskQueueDelivery { .key .take() .expect("delivery must own a dedupe key before commit"); + let task = self + .task + .take() + .expect("delivery must own a task before commit"); let generation = entry.generation; let retired_now = entry.finish_commit(); if retired_now { self.inner - .retire_entry(self.bucket.as_ref(), self.task.id, generation); + .retire_entry(self.bucket.as_ref(), self.seq, task.id, generation); } let duration = retired_now.then(|| entry.timestamp.elapsed()); - (self.task.clone(), duration.or(self.duration)) + (task, duration.or(self.duration)) } pub fn rollback(mut self) { @@ -568,25 +649,30 @@ impl TaskQueueDelivery { let Some(entry) = self.entry.take() else { return; }; + let expired = force_retire || entry.is_expired(self.inner.ttl); if let Some(key) = self.key.take() && let Some(hotkey) = self.hotkey.take() && clear_hotkey_delivery { self.inner.clear_sent_for_hotkey(key, &hotkey); - self.bucket.clear_exhausted_hotkey(&hotkey); + self.bucket.clear_exhausted(&self.exhausted_key); + if !expired { + self.bucket + .rewind_scan_cursor(&self.exhausted_key, self.seq); + } } - let expired = force_retire || entry.is_expired(self.inner.ttl); let outcome = entry.finish_rollback(expired); if outcome.generation_bump { self.bucket.bump_generation(); } - if outcome.requeue { - let _ = self.bucket.push_ready(Arc::clone(&entry)); - } if outcome.retire_now { - self.inner - .retire_entry(self.bucket.as_ref(), self.task.id, entry.generation); + self.inner.retire_entry( + self.bucket.as_ref(), + self.seq, + *entry.item.id(), + entry.generation, + ); } } } @@ -594,7 +680,7 @@ impl TaskQueueDelivery { impl Drop for TaskQueueReservation { fn drop(&mut self) { if !self.committed { - self.inner.len.fetch_sub(1, Ordering::Relaxed); + self.inner.decrement_len(); } } } @@ -606,6 +692,52 @@ impl Drop for TaskQueueDelivery { } impl Inner { + fn observe_len(&self, len: usize) { + if let Some(gauge) = self.queue_len_gauge.as_ref() { + gauge.set(len.try_into().unwrap_or(i64::MAX)); + } + } + + fn increment_len(&self) { + let len = self.len.fetch_add(1, Ordering::Relaxed) + 1; + self.observe_len(len); + } + + fn decrement_len(&self) { + loop { + let len = self.len.load(Ordering::Relaxed); + if len == 0 { + self.observe_len(0); + return; + } + if self + .len + .compare_exchange_weak(len, len - 1, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + self.observe_len(len - 1); + return; + } + } + } + + fn try_increment_len(&self, max_len: usize) -> bool { + loop { + let len = self.len.load(Ordering::Relaxed); + if len >= max_len { + return false; + } + if self + .len + .compare_exchange_weak(len, len + 1, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + self.observe_len(len + 1); + return true; + } + } + } + fn clear_sent_for_task_generation(&self, task_id: Uuid, generation: usize) { let _ = self.sent.remove_sync(&(task_id, generation)); } @@ -623,10 +755,11 @@ impl Inner { } } - fn retire_entry(&self, bucket: &ModelBucket, task_id: Uuid, generation: usize) { - bucket.finish_live_entry(); - self.len.fetch_sub(1, Ordering::Relaxed); - self.release_task_id(task_id, generation); + fn retire_entry(&self, bucket: &ModelBucket, seq: u64, task_id: Uuid, generation: usize) { + if bucket.remove_entry(seq) { + self.decrement_len(); + self.release_task_id(task_id, generation); + } } fn acquire_task_generation(&self, task_id: Uuid) -> usize { @@ -687,11 +820,16 @@ impl Inner { } fn push_internal(&self, task: Task) { + self.push_internal_with_routing(task, TaskRouting::default()); + } + + fn push_internal_with_routing(&self, task: Task, routing: TaskRouting) { let (task, model) = self.normalize_task(task); let task_id = *task.id(); let generation = self.acquire_task_generation(task_id); let entry = Arc::new(QueueEntry { item: Arc::new(task), + routing, state: AtomicU64::new(QueueEntry::initial_state(self.dup)), timestamp: Instant::now(), generation, @@ -708,33 +846,25 @@ impl TaskQueue { cleanup_interval: Duration::from_secs(DEFAULT_CLEANUP_INTERVAL_SECS), default_model: None, models: Vec::new(), + queue_len_gauge: None, + reservation_scan_cap: DEFAULT_RESERVATION_SCAN_CAP, } } pub fn push(&self, task: Task) { - self.inner.len.fetch_add(1, Ordering::Relaxed); + self.inner.increment_len(); self.inner.push_internal(task); } pub fn try_reserve(&self, max_len: usize) -> Option { - loop { - let len = self.inner.len.load(Ordering::Relaxed); - if len >= max_len { - return None; - } - if self - .inner - .len - .compare_exchange_weak(len, len + 1, Ordering::Relaxed, Ordering::Relaxed) - .is_ok() - { - return Some(TaskQueueReservation { - inner: Arc::clone(&self.inner), - committed: false, - _marker: PhantomData, - }); - } + if !self.inner.try_increment_len(max_len) { + return None; } + Some(TaskQueueReservation { + inner: Arc::clone(&self.inner), + committed: false, + _marker: PhantomData, + }) } #[allow(dead_code)] @@ -754,6 +884,7 @@ impl TaskQueue { self.reserve_from_bucket_refs( num, hotkey, + WorkerRouting::default(), self.inner .bucket_order .iter() @@ -785,18 +916,38 @@ impl TaskQueue { hotkey: &Hotkey, models: &[S], ) -> Vec + where + S: AsRef, + { + self.reserve_for_models_with_routing(num, hotkey, models, WorkerRouting::default()) + } + + pub fn reserve_for_models_with_routing( + &self, + num: usize, + hotkey: &Hotkey, + models: &[S], + routing: WorkerRouting<'_>, + ) -> Vec where S: AsRef, { self.reserve_from_bucket_refs( num, hotkey, + routing, models .iter() .filter_map(|model| self.inner.buckets.get(model.as_ref()).cloned()), ) } + pub fn set_reservation_scan_cap(&self, cap: usize) { + self.inner + .reservation_scan_cap + .store(cap.max(1), Ordering::Release); + } + fn pop_from_bucket_refs( &self, num: usize, @@ -806,7 +957,7 @@ impl TaskQueue { where I: IntoIterator>, { - self.reserve_from_bucket_refs(num, hotkey, buckets_iter) + self.reserve_from_bucket_refs(num, hotkey, WorkerRouting::default(), buckets_iter) .into_iter() .map(TaskQueueDelivery::commit) .collect() @@ -816,6 +967,7 @@ impl TaskQueue { &self, num: usize, hotkey: &Hotkey, + routing: WorkerRouting<'_>, buckets_iter: I, ) -> Vec where @@ -830,6 +982,11 @@ impl TaskQueue { return Vec::new(); } + // The exhausted key only depends on the hotkey and worker routing, not + // on the bucket, so compute it (and its routing-tag allocation) once per + // reservation instead of once per bucket scan attempt. + let exhausted_key = ExhaustedKey::new(hotkey, routing.cache_key()); + let mut result = Vec::with_capacity(num); while result.len() < num { let start = self.inner.next_bucket.fetch_add(1, Ordering::Relaxed) % buckets.len(); @@ -840,7 +997,9 @@ impl TaskQueue { break; } let bucket = &buckets[(start + offset) % buckets.len()]; - if let Some(item) = self.reserve_one_from_bucket(bucket, hotkey) { + if let Some(item) = + self.reserve_one_from_bucket(bucket, hotkey, routing, &exhausted_key) + { progressed = true; result.push(item); } @@ -858,101 +1017,173 @@ impl TaskQueue { &self, bucket: &Arc, hotkey: &Hotkey, + routing: WorkerRouting<'_>, + exhausted_key: &ExhaustedKey, ) -> Option { let generation = bucket.generation.load(Ordering::Acquire); - if bucket.exhausted_generation(hotkey) == Some(generation) { + if bucket.exhausted_generation(exhausted_key) == Some(generation) { return None; } let start_activity = bucket.activity.load(Ordering::Acquire); - let start_in_flight = bucket.in_flight.load(Ordering::Acquire); - let scan_limit = bucket.live.load(Ordering::Acquire); - let mut local_activity = 0usize; - - for _ in 0..scan_limit { - let Some(entry) = bucket.pop_visible_entry() else { - bucket.mark_exhausted_if_stable( - hotkey, - generation, - start_activity, - local_activity, - start_in_flight, - ); - return None; - }; - local_activity = local_activity.wrapping_add(1); - - if entry.is_expired(self.inner.ttl) || !entry.has_visible_capacity() { - let task_id = *entry.item.id(); - let generation = entry.generation; - let retired_now = entry.retire_after_pop(); - bucket.finish_entry(entry, false); - if retired_now { - self.inner - .retire_entry(bucket.as_ref(), task_id, generation); - } - continue; + let live = bucket.live.load(Ordering::Acquire); + let scan_cap = self.inner.reservation_scan_cap.load(Ordering::Acquire); + let scan_limit = live.min(scan_cap); + let local_activity = 0usize; + + if scan_limit == 0 { + bucket.mark_exhausted_if_stable( + exhausted_key, + generation, + start_activity, + local_activity, + ); + return None; + } + + let start_seq = bucket.scan_cursor(exhausted_key); + let mut scanned_any = false; + let mut scanned_count = 0usize; + + for index_entry in bucket.entries.range(start_seq..) { + scanned_any = true; + scanned_count += 1; + let seq = *index_entry.key(); + let entry = Arc::clone(index_entry.value()); + drop(index_entry); + bucket.update_scan_cursor(exhausted_key, seq.saturating_add(1)); + + if let Some(delivery) = + self.reserve_scanned_entry(bucket, exhausted_key, hotkey, routing, seq, entry) + { + return Some(delivery); + } + if scanned_count >= scan_limit { + break; } + } - let key = (*entry.item.id(), entry.generation); - let already_sent = match self.inner.sent.entry_sync(key) { - SccEntry::Occupied(mut sent_entry) => { - let sent_hotkeys = sent_entry.get_mut(); - !sent_hotkeys.insert(hotkey.clone()) - } - SccEntry::Vacant(sent_entry) => { - let mut sent_hotkeys = FoldHashSet::default(); - sent_hotkeys.insert(hotkey.clone()); - sent_entry.insert_entry(sent_hotkeys); - false + if scanned_count < scan_limit && start_seq != 0 { + for index_entry in bucket.entries.iter() { + let seq = *index_entry.key(); + if seq >= start_seq { + break; } - }; - - if already_sent { - if bucket.finish_entry(entry, true) { - local_activity = local_activity.wrapping_add(1); + scanned_any = true; + scanned_count += 1; + let entry = Arc::clone(index_entry.value()); + drop(index_entry); + bucket.update_scan_cursor(exhausted_key, seq.saturating_add(1)); + + if let Some(delivery) = + self.reserve_scanned_entry(bucket, exhausted_key, hotkey, routing, seq, entry) + { + return Some(delivery); } - continue; - } - - let Some(lease) = entry.try_acquire_lease() else { - self.inner.clear_sent_for_hotkey(key, hotkey); - let task_id = *entry.item.id(); - let generation = entry.generation; - let retired_now = entry.retire_after_pop(); - bucket.finish_entry(entry, false); - if retired_now { - self.inner - .retire_entry(bucket.as_ref(), task_id, generation); + if scanned_count >= scan_limit { + break; } - continue; - }; + } + } - bucket.clear_exhausted_hotkey(hotkey); - let _ = bucket.finish_entry(Arc::clone(&entry), lease.available_after > 0); - let task = (*entry.item).clone(); - let duration = lease.final_candidate.then(|| entry.timestamp.elapsed()); - return Some(TaskQueueDelivery { - inner: Arc::clone(&self.inner), - bucket: Arc::clone(bucket), - entry: Some(entry), - key: Some(key), - hotkey: Some(hotkey.clone()), - task, - duration, - }); - } - - bucket.mark_exhausted_if_stable( - hotkey, - generation, - start_activity, - local_activity, - start_in_flight, - ); + if !scanned_any { + bucket.mark_exhausted_if_stable( + exhausted_key, + generation, + start_activity, + local_activity, + ); + return None; + } + + if live <= scan_cap && scanned_count >= live { + bucket.mark_exhausted_if_stable( + exhausted_key, + generation, + start_activity, + local_activity, + ); + } None } + fn reserve_scanned_entry( + &self, + bucket: &Arc, + exhausted_key: &ExhaustedKey, + hotkey: &Hotkey, + routing: WorkerRouting<'_>, + seq: u64, + entry: Arc, + ) -> Option { + if entry.is_expired(self.inner.ttl) { + let task_id = *entry.item.id(); + let generation = entry.generation; + let retired_now = entry.retire_if_idle(); + if retired_now { + self.inner + .retire_entry(bucket.as_ref(), seq, task_id, generation); + } + return None; + } + + if !entry.has_visible_capacity() { + let task_id = *entry.item.id(); + let generation = entry.generation; + // Retire only if the entry is genuinely drained (no available + // capacity and nothing leased) as observed atomically. A concurrent + // rollback/requeue can restore an available slot between the + // has_visible_capacity() check above and this point; retiring then + // would destroy a live, still-deliverable slot and strand the task. + if entry.retire_if_drained() { + self.inner + .retire_entry(bucket.as_ref(), seq, task_id, generation); + } + return None; + } + + if !entry.routing.matches_worker(routing) { + return None; + } + + let key = (*entry.item.id(), entry.generation); + let already_sent = match self.inner.sent.entry_sync(key) { + SccEntry::Occupied(mut sent_entry) => { + let sent_hotkeys = sent_entry.get_mut(); + !sent_hotkeys.insert(hotkey.clone()) + } + SccEntry::Vacant(sent_entry) => { + let mut sent_hotkeys = FoldHashSet::default(); + sent_hotkeys.insert(hotkey.clone()); + sent_entry.insert_entry(sent_hotkeys); + false + } + }; + if already_sent { + return None; + } + + let Some(lease) = entry.try_acquire_lease() else { + self.inner.clear_sent_for_hotkey(key, hotkey); + return None; + }; + + bucket.clear_exhausted(exhausted_key); + let task = (*entry.item).clone(); + let duration = lease.final_candidate.then(|| entry.timestamp.elapsed()); + Some(TaskQueueDelivery { + inner: Arc::clone(&self.inner), + bucket: Arc::clone(bucket), + seq, + entry: Some(entry), + key: Some(key), + hotkey: Some(hotkey.clone()), + exhausted_key: exhausted_key.clone(), + task: Some(task), + duration, + }) + } + #[allow(dead_code)] pub fn dup(&self) -> usize { self.inner.dup diff --git a/src/common/queue/tests/capacity.rs b/src/common/queue/tests/capacity.rs new file mode 100644 index 0000000..4fa09cd --- /dev/null +++ b/src/common/queue/tests/capacity.rs @@ -0,0 +1,93 @@ +use super::*; + +#[test] +#[should_panic(expected = "task queue requires a non-empty default_model")] +fn build_requires_non_empty_default_model() { + let _ = TaskQueue::builder().default_model(" ").build(); +} + +#[test] +#[should_panic(expected = "task queue dup exceeds supported internal slot count")] +fn build_rejects_unsupported_dup_count() { + let _ = TaskQueue::builder() + .dup(usize::MAX) + .default_model("404-3dgs") + .build(); +} + +#[tokio::test] +async fn queue_len_gauge_tracks_atomic_len_changes() { + let (queue, gauge) = build_queue_with_len_gauge(); + assert_eq!(gauge.get(), 0); + + let uncommitted_slot = queue.try_reserve(usize::MAX).expect("reserve queue slot"); + assert_eq!(queue.len(), 1); + assert_eq!(gauge.get(), 1); + drop(uncommitted_slot); + assert_eq!(queue.len(), 0); + assert_eq!(gauge.get(), 0); + + let task = create_task("observed", Some("404-3dgs")); + queue.push(task.clone()); + assert_eq!(queue.len(), 1); + assert_eq!(gauge.get(), 1); + + let deliveries = queue.reserve_for_models(1, &Hotkey::from_bytes(&[26u8; 32]), &["404-3dgs"]); + assert_eq!(deliveries.len(), 1); + let (delivered, _) = deliveries.into_iter().next().expect("delivery").commit(); + assert_eq!(delivered, task); + assert_eq!(queue.len(), 0); + assert_eq!(gauge.get(), 0); +} + +#[tokio::test] +async fn try_reserve_respects_capacity() { + let queue = build_queue_with_config(1, 300, 1); + let task = create_task("Task A", None); + + let first_slot = queue.try_reserve(1).expect("first reservation should fit"); + assert_eq!(queue.len(), 1); + assert!( + queue.try_reserve(1).is_none(), + "second reservation should exceed the bound" + ); + + first_slot.push(task.clone()); + assert_eq!(queue.len(), 1); + + let hk: Hotkey = Hotkey::from_bytes(&[24u8; 32]); + let popped = queue.pop(1, &hk); + assert_eq!(popped.len(), 1); + assert_eq!(popped[0].0, task); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_try_reserve_respects_capacity() { + let queue = Arc::new(build_queue_with_config(1, 300, 1)); + let max_len = 4; + let contenders = 16; + let barrier = Arc::new(ThreadBarrier::new(contenders)); + let mut handles = Vec::with_capacity(contenders); + + for _ in 0..contenders { + let q = Arc::clone(&queue); + let barrier = Arc::clone(&barrier); + handles.push(thread::spawn(move || { + barrier.wait(); + q.try_reserve(max_len) + })); + } + + let mut reservations = Vec::new(); + for handle in handles { + if let Some(reservation) = handle.join().expect("reservation racer should complete") { + reservations.push(reservation); + } + } + + assert_eq!(reservations.len(), max_len); + assert_eq!(queue.len(), max_len); + + drop(reservations); + assert_eq!(queue.len(), 0); +} diff --git a/src/common/queue/tests/concurrency.rs b/src/common/queue/tests/concurrency.rs new file mode 100644 index 0000000..dd0770c --- /dev/null +++ b/src/common/queue/tests/concurrency.rs @@ -0,0 +1,757 @@ +use super::*; + +#[tokio::test] +async fn bucket_activity_without_generation_change_does_not_mark_bucket_exhausted() { + let bucket = Arc::new(ModelBucket::default()); + let entry = raw_entry(create_task("robot", Some("404-3dgs")), 2, 0, 7); + bucket.enqueue_new(entry); + + let hotkey = Hotkey::from_bytes(&[26u8; 32]); + let generation = bucket.generation.load(Ordering::Acquire); + let start_activity = bucket.activity.load(Ordering::Acquire); + + bucket.activity.fetch_add(1, Ordering::AcqRel); + let key = exhausted_key(&hotkey, &[]); + bucket.mark_exhausted_if_stable(&key, generation, start_activity, 0); + assert!( + bucket.exhausted_generation(&key).is_none(), + "bucket activity must block exhausted-cache writes even when generation is unchanged" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_push_pop_returns_unique_tasks_per_hotkey() { + const HOTKEYS: usize = 3; + const TASKS: usize = 64; + + let queue = Arc::new(build_queue_with_config(HOTKEYS, 300, 1)); + let start = Arc::new(ThreadBarrier::new(HOTKEYS + 1)); + let delivered_count = Arc::new(AtomicUsize::new(0)); + let delivered_by_task = Arc::new(Mutex::new(HashMap::>::new())); + + let producer = { + let queue = Arc::clone(&queue); + let start = Arc::clone(&start); + thread::spawn(move || { + start.wait(); + for i in 0..TASKS { + queue.push(create_task(&format!("Task {}", i), None)); + } + }) + }; + + let mut handles = Vec::new(); + for worker_idx in 0..HOTKEYS { + let queue = Arc::clone(&queue); + let start = Arc::clone(&start); + let delivered_count = Arc::clone(&delivered_count); + let delivered_by_task = Arc::clone(&delivered_by_task); + handles.push(thread::spawn(move || { + let hotkey = Hotkey::from_bytes(&[20u8 + worker_idx as u8; 32]); + let deadline = Instant::now() + Duration::from_secs(10); + start.wait(); + + while delivered_count.load(Ordering::Acquire) < TASKS * HOTKEYS { + let res = queue.pop(5, &hotkey); + if res.is_empty() { + assert!( + Instant::now() < deadline, + "concurrent push/pop test stalled with {} deliveries", + delivered_count.load(Ordering::Relaxed) + ); + thread::yield_now(); + continue; + } + + let set: HashSet<_> = res.iter().map(|(task, _)| task.id).collect(); + assert_eq!(set.len(), res.len()); + for (task, _) in res { + { + let mut delivered_by_task = + delivered_by_task.lock().expect("delivered map lock"); + let workers = delivered_by_task.entry(task.id).or_default(); + assert!( + workers.insert(worker_idx), + "worker received the same task twice" + ); + assert!( + workers.len() <= HOTKEYS, + "task was delivered more times than expected" + ); + } + delivered_count.fetch_add(1, Ordering::Release); + } + } + })); + } + + producer.join().expect("producer thread should not panic"); + for handle in handles { + handle.join().expect("consumer thread should not panic"); + } + + assert_eq!(queue.len(), 0); + assert_eq!(delivered_count.load(Ordering::Relaxed), TASKS * HOTKEYS); + let delivered_by_task = delivered_by_task.lock().expect("delivered map lock"); + assert_eq!(delivered_by_task.len(), TASKS); + for workers in delivered_by_task.values() { + assert_eq!(workers.len(), HOTKEYS); + } +} +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn multi_threaded_scan_pick_stress_preserves_correctness() { + const PRODUCERS: usize = 2; + const CONSUMERS: usize = 8; + const DUP: usize = 2; + const RUN_FOR: Duration = Duration::from_secs(3); + + let queue = Arc::new(build_queue_with_config(DUP, 60, 1)); + let state = Arc::new(Mutex::new(MtDeliveryState::default())); + let stop = Arc::new(AtomicBool::new(false)); + let produced_count = Arc::new(AtomicUsize::new(0)); + let committed_count = Arc::new(AtomicUsize::new(0)); + let start = Arc::new(ThreadBarrier::new(PRODUCERS + CONSUMERS + 1)); + let mut handles = Vec::with_capacity(PRODUCERS + CONSUMERS); + + for producer_idx in 0..PRODUCERS { + let queue = Arc::clone(&queue); + let state = Arc::clone(&state); + let stop = Arc::clone(&stop); + let produced_count = Arc::clone(&produced_count); + let start = Arc::clone(&start); + handles.push(thread::spawn(move || { + start.wait(); + let mut local_sequence = 0usize; + + while !stop.load(Ordering::Acquire) { + let sequence = producer_idx * 1_000_000 + local_sequence; + let required_tag = mt_required_tag(sequence); + let task = create_task( + &format!( + "mt-producer-{producer_idx}-task-{local_sequence}-tag-{}", + required_tag.unwrap_or("public") + ), + Some("404-3dgs"), + ); + + { + let mut state = state.lock().expect("mt delivery state lock"); + state.produced.insert(task.id, required_tag); + } + + match required_tag { + Some(tag) => push_task_with_required_tags(&queue, task, &[tag]), + None => queue.push(task), + } + produced_count.fetch_add(1, Ordering::Relaxed); + local_sequence += 1; + thread::sleep(Duration::from_millis(1)); + } + })); + } + + for worker_idx in 0..CONSUMERS { + let queue = Arc::clone(&queue); + let state = Arc::clone(&state); + let stop = Arc::clone(&stop); + let committed_count = Arc::clone(&committed_count); + let start = Arc::clone(&start); + handles.push(thread::spawn(move || { + let hotkey = Hotkey::from_bytes(&[worker_idx as u8 + 80; 32]); + let worker_tags = mt_worker_tags(worker_idx); + start.wait(); + + while !stop.load(Ordering::Acquire) { + let deliveries = queue.reserve_for_models_with_routing( + 8, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + if deliveries.is_empty() { + thread::yield_now(); + continue; + } + + for delivery in deliveries { + let (task, _) = delivery.commit(); + record_mt_delivery(&state, worker_idx, &task, DUP); + committed_count.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + start.wait(); + thread::sleep(RUN_FOR); + stop.store(true, Ordering::Release); + + for handle in handles { + handle.join().expect("mt stress thread should not panic"); + } + + let drain_deadline = Instant::now() + Duration::from_secs(10); + while !queue.is_empty() { + let mut progressed = false; + for worker_idx in 0..CONSUMERS { + let hotkey = Hotkey::from_bytes(&[worker_idx as u8 + 80; 32]); + let worker_tags = mt_worker_tags(worker_idx); + let deliveries = queue.reserve_for_models_with_routing( + 16, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + + if deliveries.is_empty() { + continue; + } + progressed = true; + + for delivery in deliveries { + let (task, _) = delivery.commit(); + record_mt_delivery(&state, worker_idx, &task, DUP); + committed_count.fetch_add(1, Ordering::Relaxed); + } + } + + if !progressed { + assert!( + Instant::now() < drain_deadline, + "mt stress drain stalled with {} queued tasks", + queue.len() + ); + thread::yield_now(); + } + } + + let produced_count = produced_count.load(Ordering::Relaxed); + assert!( + produced_count >= 200, + "mt stress should produce enough tasks to exercise concurrent queue paths" + ); + + let state = state.lock().expect("mt delivery state lock"); + assert_eq!(state.produced.len(), produced_count); + assert_eq!( + committed_count.load(Ordering::Relaxed), + produced_count * DUP + ); + + for task_id in state.produced.keys() { + assert_eq!( + state.delivered_counts.get(task_id).copied(), + Some(DUP), + "produced task was not committed exactly dup times" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn multi_threaded_commit_only_scan_pick_preserves_per_worker_ordering() { + const CONSUMERS: usize = 16; + const DUP: usize = 2; + const RUN_FOR: Duration = Duration::from_secs(3); + + let queue = Arc::new(build_queue_with_config(DUP, 60, 60)); + let state = Arc::new(Mutex::new(MtOrderingState { + produced_sequences: HashMap::new(), + delivered_counts: HashMap::new(), + last_sequence_by_worker: vec![None; CONSUMERS], + })); + let stop = Arc::new(AtomicBool::new(false)); + let produced_count = Arc::new(AtomicUsize::new(0)); + let committed_count = Arc::new(AtomicUsize::new(0)); + let start = Arc::new(ThreadBarrier::new(CONSUMERS + 2)); + let mut handles = Vec::with_capacity(CONSUMERS + 1); + + { + let queue = Arc::clone(&queue); + let state = Arc::clone(&state); + let stop = Arc::clone(&stop); + let produced_count = Arc::clone(&produced_count); + let start = Arc::clone(&start); + handles.push(thread::spawn(move || { + start.wait(); + let mut sequence = 0usize; + + while !stop.load(Ordering::Acquire) { + let required_tag = mt_required_tag(sequence); + let task = create_task( + &format!( + "mt-order-task-{sequence}-tag-{}", + required_tag.unwrap_or("public") + ), + Some("404-3dgs"), + ); + + { + let mut state = state.lock().expect("mt ordering state lock"); + state.produced_sequences.insert(task.id, sequence); + } + + match required_tag { + Some(tag) => push_task_with_required_tags(&queue, task, &[tag]), + None => queue.push(task), + } + produced_count.fetch_add(1, Ordering::Relaxed); + sequence += 1; + thread::sleep(Duration::from_millis(1)); + } + })); + } + + for worker_idx in 0..CONSUMERS { + let queue = Arc::clone(&queue); + let state = Arc::clone(&state); + let stop = Arc::clone(&stop); + let committed_count = Arc::clone(&committed_count); + let start = Arc::clone(&start); + handles.push(thread::spawn(move || { + let hotkey = Hotkey::from_bytes(&[worker_idx as u8 + 100; 32]); + let worker_tags = mt_worker_tags(worker_idx); + start.wait(); + + while !stop.load(Ordering::Acquire) { + let deliveries = queue.reserve_for_models_with_routing( + 8, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + if deliveries.is_empty() { + thread::yield_now(); + continue; + } + + for delivery in deliveries { + let (task, _) = delivery.commit(); + record_mt_ordered_delivery(&state, worker_idx, &task, DUP); + committed_count.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + start.wait(); + thread::sleep(RUN_FOR); + stop.store(true, Ordering::Release); + + for handle in handles { + handle + .join() + .expect("mt ordering stress thread should not panic"); + } + + let drain_deadline = Instant::now() + Duration::from_secs(30); + while !queue.is_empty() { + let mut progressed = false; + for worker_idx in 0..CONSUMERS { + let hotkey = Hotkey::from_bytes(&[worker_idx as u8 + 100; 32]); + let worker_tags = mt_worker_tags(worker_idx); + let deliveries = queue.reserve_for_models_with_routing( + 16, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + + if deliveries.is_empty() { + continue; + } + progressed = true; + + for delivery in deliveries { + let (task, _) = delivery.commit(); + record_mt_ordered_delivery(&state, worker_idx, &task, DUP); + committed_count.fetch_add(1, Ordering::Relaxed); + } + } + + if !progressed { + assert!( + Instant::now() < drain_deadline, + "mt ordering stress drain stalled with {} queued tasks", + queue.len() + ); + thread::yield_now(); + } + } + + let produced_count = produced_count.load(Ordering::Relaxed); + assert!( + produced_count >= 1_000, + "mt ordering stress should produce enough ordered tasks to prove no reshuffling" + ); + + let state = state.lock().expect("mt ordering state lock"); + assert_eq!(state.produced_sequences.len(), produced_count); + assert_eq!( + committed_count.load(Ordering::Relaxed), + produced_count * DUP + ); + + for task_id in state.produced_sequences.keys() { + assert_eq!( + state.delivered_counts.get(task_id).copied(), + Some(DUP), + "ordered stress task was not committed exactly dup times" + ); + } +} + +#[derive(Copy, Clone)] +enum MixedDeliveryAction { + Commit, + Rollback, + Requeue, +} + +#[derive(Default)] +struct MixedMtState { + produced: HashMap>, + attempts: HashMap<(Uuid, usize), usize>, + committed_pairs: HashSet<(Uuid, usize)>, + committed_counts: HashMap, +} + +fn next_mixed_delivery_action( + state: &Mutex, + worker_idx: usize, + task: &Task, +) -> MixedDeliveryAction { + let worker_tag = mt_worker_tag(worker_idx); + let mut state = state.lock().expect("mixed mt state lock"); + let required_tag = *state + .produced + .get(&task.id) + .expect("delivered task must have been produced by the mixed stress test"); + + assert_eq!( + worker_tag, required_tag, + "worker received a task outside its requested tag routing" + ); + + let attempts = state.attempts.entry((task.id, worker_idx)).or_default(); + *attempts += 1; + + if *attempts == 1 && worker_idx == 0 { + MixedDeliveryAction::Requeue + } else if *attempts == 1 && worker_idx == 1 { + MixedDeliveryAction::Rollback + } else { + MixedDeliveryAction::Commit + } +} + +fn record_mixed_commit(state: &Mutex, worker_idx: usize, task: &Task, dup: usize) { + let worker_tag = mt_worker_tag(worker_idx); + let mut state = state.lock().expect("mixed mt state lock"); + let required_tag = *state + .produced + .get(&task.id) + .expect("committed task must have been produced by the mixed stress test"); + + assert_eq!( + worker_tag, required_tag, + "worker committed a task outside its requested tag routing" + ); + + assert!( + state.committed_pairs.insert((task.id, worker_idx)), + "task was committed more than once by the same worker" + ); + + let committed_count = state.committed_counts.entry(task.id).or_default(); + *committed_count += 1; + assert!( + *committed_count <= dup, + "task was committed more times than the configured dup count" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn multi_threaded_mixed_rollback_requeue_stress_commits_every_slot_exactly_once() { + const CONSUMERS: usize = 8; + const DUP: usize = 1; + const TASKS: usize = 6_000; + + let queue = Arc::new(build_queue_with_config(DUP, 60, 60)); + let state = Arc::new(Mutex::new(MixedMtState::default())); + let committed_count = Arc::new(AtomicUsize::new(0)); + let rollback_count = Arc::new(AtomicUsize::new(0)); + let requeue_count = Arc::new(AtomicUsize::new(0)); + let start = Arc::new(ThreadBarrier::new(CONSUMERS + 1)); + let mut handles = Vec::with_capacity(CONSUMERS); + + for sequence in 0..TASKS { + let required_tag = mt_required_tag(sequence); + let task = create_task( + &format!( + "mt-mixed-preloaded-task-{sequence}-tag-{}", + required_tag.unwrap_or("public") + ), + Some("404-3dgs"), + ); + + { + let mut state = state.lock().expect("mixed mt state lock"); + state.produced.insert(task.id, required_tag); + } + + match required_tag { + Some(tag) => push_task_with_required_tags(&queue, task, &[tag]), + None => queue.push(task), + } + } + + for worker_idx in 0..CONSUMERS { + let queue = Arc::clone(&queue); + let state = Arc::clone(&state); + let committed_count = Arc::clone(&committed_count); + let rollback_count = Arc::clone(&rollback_count); + let requeue_count = Arc::clone(&requeue_count); + let start = Arc::clone(&start); + handles.push(thread::spawn(move || { + let hotkey = Hotkey::from_bytes(&[worker_idx as u8 + 120; 32]); + let worker_tags = mt_worker_tags(worker_idx); + start.wait(); + let deadline = Instant::now() + Duration::from_secs(30); + + while committed_count.load(Ordering::Acquire) < TASKS * DUP && Instant::now() < deadline + { + let deliveries = queue.reserve_for_models_with_routing( + 8, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + if deliveries.is_empty() { + if queue.is_empty() { + break; + } + thread::yield_now(); + continue; + } + + for delivery in deliveries { + match next_mixed_delivery_action(&state, worker_idx, delivery.task()) { + MixedDeliveryAction::Commit => { + let (task, _) = delivery.commit(); + record_mixed_commit(&state, worker_idx, &task, DUP); + committed_count.fetch_add(1, Ordering::Relaxed); + } + MixedDeliveryAction::Rollback => { + rollback_count.fetch_add(1, Ordering::Relaxed); + delivery.rollback(); + } + MixedDeliveryAction::Requeue => { + requeue_count.fetch_add(1, Ordering::Relaxed); + delivery.requeue_for_other_hotkeys(); + } + } + } + } + })); + } + + start.wait(); + + for handle in handles { + handle + .join() + .expect("mixed mt stress thread should not panic"); + } + + assert_eq!( + TASKS, + state.lock().expect("mixed mt state lock").produced.len() + ); + assert!( + rollback_count.load(Ordering::Relaxed) > 0, + "mixed mt stress should exercise rollback" + ); + assert!( + requeue_count.load(Ordering::Relaxed) > 0, + "mixed mt stress should exercise requeue_for_other_hotkeys" + ); + + let state = state.lock().expect("mixed mt state lock"); + assert_eq!(state.produced.len(), TASKS); + assert!( + queue + .inner + .bucket_for_model("404-3dgs") + .snapshot_entries(usize::MAX) + .is_empty(), + "mixed mt stress should leave no bucket entries after draining" + ); + let recorded_commits = state.committed_counts.values().copied().sum::(); + assert_eq!( + committed_count.load(Ordering::Relaxed), + recorded_commits, + "atomic commit counter should match recorded per-task commits" + ); + assert_eq!( + recorded_commits, + TASKS * DUP, + "every enqueued slot must be committed exactly once: a shortfall means a \ + requeued/rolled-back slot was stranded or dropped without delivery" + ); + assert!( + recorded_commits <= TASKS * DUP, + "mixed stress must not commit more slots than were enqueued" + ); + assert_eq!( + state.committed_pairs.len(), + recorded_commits, + "mixed stress must not commit the same task more than once per worker" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_polls_from_same_hotkey_do_not_duplicate_task() { + let queue = Arc::new(build_queue_with_config(1, 300, 1)); + let task = create_task("single", None); + queue.push(task.clone()); + + let hotkey = Hotkey::from_bytes(&[23u8; 32]); + let barrier = Arc::new(ThreadBarrier::new(5)); + let mut handles = Vec::new(); + for _ in 0..5 { + let q = Arc::clone(&queue); + let barrier = Arc::clone(&barrier); + let hotkey = hotkey.clone(); + handles.push(thread::spawn(move || { + barrier.wait(); + q.pop(1, &hotkey) + })); + } + + let mut delivered = Vec::new(); + for handle in handles { + delivered.extend(handle.join().expect("same-hotkey consumer should complete")); + } + + assert_eq!(delivered.len(), 1); + assert_eq!(delivered[0].0, task); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_polls_racing_final_slot_leave_no_loser_sent_state() { + let queue = Arc::new(build_queue_with_config(1, 300, 1)); + let task = create_task("final-slot", Some("404-3dgs")); + queue.push(task.clone()); + + let racers = 64; + let barrier = Arc::new(ThreadBarrier::new(racers)); + let mut handles = Vec::with_capacity(racers); + for racer_idx in 0..racers { + let q = Arc::clone(&queue); + let barrier = Arc::clone(&barrier); + handles.push(thread::spawn(move || { + let hotkey = Hotkey::from_bytes(&[90u8 + racer_idx as u8; 32]); + barrier.wait(); + q.reserve_for_models(1, &hotkey, &["404-3dgs"]) + })); + } + + let mut winners = Vec::new(); + for handle in handles { + winners.extend(handle.join().expect("final-slot racer should complete")); + } + + assert_eq!( + winners.len(), + 1, + "exactly one worker should acquire the final available slot" + ); + assert_eq!(winners[0].task(), &task); + let generation = winners[0] + .entry + .as_ref() + .map(|entry| entry.generation) + .expect("winner should hold the queue entry generation"); + let key = (task.id, generation); + assert_eq!( + queue.inner.sent.read_sync(&key, |_, hotkeys| hotkeys.len()), + Some(1), + "losing racers must not leave stale sent-hotkey state" + ); + + let (committed, _) = winners.pop().expect("winner").commit(); + assert_eq!(committed, task); + assert_eq!(queue.len(), 0); + assert!( + queue.inner.sent.read_sync(&key, |_, _| ()).is_none(), + "committing the final slot should clear sent bookkeeping" + ); + assert!( + queue + .inner + .active_ids + .read_sync(&task.id, |_, _| ()) + .is_none(), + "committing the final slot should clear active-id bookkeeping" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn moderate_contention_drains_all_tasks() { + let queue = Arc::new(build_queue_with_config(1, 10, 1)); + let producers = 2; + let tasks_per_producer = 500; + let consumers = 4; + let total_tasks = producers * tasks_per_producer; + let barrier = Arc::new(ThreadBarrier::new(producers + consumers)); + + let mut producer_handles = Vec::with_capacity(producers); + for producer_idx in 0..producers { + let q = Arc::clone(&queue); + let barrier = Arc::clone(&barrier); + producer_handles.push(thread::spawn(move || { + barrier.wait(); + for i in 0..tasks_per_producer { + q.push(create_task(&format!("p{producer_idx}-{i}"), None)); + } + })); + } + + let counter = Arc::new(AtomicUsize::new(0)); + let mut handles = Vec::with_capacity(consumers); + for consumer_idx in 0..consumers { + let q = Arc::clone(&queue); + let barrier = Arc::clone(&barrier); + let counter = Arc::clone(&counter); + handles.push(thread::spawn(move || { + let hotkey = Hotkey::from_bytes(&[30u8 + consumer_idx as u8; 32]); + let deadline = Instant::now() + Duration::from_secs(10); + barrier.wait(); + + while counter.load(Ordering::Acquire) < total_tasks { + let items = q.pop(7, &hotkey); + if !items.is_empty() { + counter.fetch_add(items.len(), Ordering::Release); + continue; + } + + assert!( + Instant::now() < deadline, + "contention consumer stalled with {} delivered tasks", + counter.load(Ordering::Relaxed) + ); + thread::yield_now(); + } + })); + } + + for handle in producer_handles { + handle.join().expect("contention producer should complete"); + } + for handle in handles { + handle.join().expect("contention consumer should complete"); + } + + assert_eq!(queue.len(), 0); + assert_eq!(counter.load(Ordering::Relaxed), total_tasks); +} diff --git a/src/common/queue/tests/expiration.rs b/src/common/queue/tests/expiration.rs new file mode 100644 index 0000000..aea24bc --- /dev/null +++ b/src/common/queue/tests/expiration.rs @@ -0,0 +1,222 @@ +use super::*; + +#[tokio::test] +async fn expired_reserved_delivery_does_not_requeue_on_drop() { + let queue = build_queue_with_config(1, 1, 60); + let task = create_task("expired-rollback", Some("404-3dgs")); + queue.push(task); + + let hk = Hotkey::from_bytes(&[62u8; 32]); + let mut deliveries = queue.reserve_for_models(1, &hk, &["404-3dgs"]); + let delivery = deliveries.pop().expect("expected reserved task"); + + tokio::time::sleep(Duration::from_secs(2)).await; + drop(delivery); + + assert!(queue.pop_for_models(1, &hk, &["404-3dgs"]).is_empty()); + assert_eq!(queue.len(), 0); +} + +#[tokio::test] +async fn final_commit_retires_task_after_expired_duplicate_is_observed() { + let queue = build_queue_with_config(2, 1, 60); + let task = create_task("expired-commit", Some("404-3dgs")); + queue.push(task.clone()); + + let hk_a = Hotkey::from_bytes(&[64u8; 32]); + let hk_b = Hotkey::from_bytes(&[65u8; 32]); + let mut deliveries = queue.reserve_for_models(1, &hk_a, &["404-3dgs"]); + let delivery = deliveries.pop().expect("expected reserved task"); + let generation = delivery + .entry + .as_ref() + .map(|entry| entry.generation) + .expect("delivery should hold generation"); + + tokio::time::sleep(Duration::from_secs(2)).await; + + assert!( + queue.pop_for_models(1, &hk_b, &["404-3dgs"]).is_empty(), + "expired duplicate should be reaped instead of delivered" + ); + assert_eq!( + queue.len(), + 1, + "outstanding lease should keep bookkeeping alive until final commit" + ); + + let (committed, duration) = delivery.commit(); + assert_eq!(committed, task); + assert!( + duration.is_some(), + "final commit should report queue duration once it retires the task" + ); + assert_eq!(queue.len(), 0); + assert!( + queue + .inner + .active_ids + .read_sync(&task.id, |_, _| ()) + .is_none() + ); + assert!( + queue + .inner + .sent + .read_sync(&(task.id, generation), |_, _| ()) + .is_none() + ); +} + +#[tokio::test] +async fn final_expired_rollback_retires_task_after_duplicate_is_observed() { + let queue = build_queue_with_config(2, 1, 60); + let task = create_task("expired-final-rollback", Some("404-3dgs")); + queue.push(task.clone()); + + let hk_a = Hotkey::from_bytes(&[66u8; 32]); + let hk_b = Hotkey::from_bytes(&[67u8; 32]); + let mut deliveries = queue.reserve_for_models(1, &hk_a, &["404-3dgs"]); + let delivery = deliveries.pop().expect("expected reserved task"); + let generation = delivery + .entry + .as_ref() + .map(|entry| entry.generation) + .expect("delivery should hold generation"); + + tokio::time::sleep(Duration::from_secs(2)).await; + + assert!( + queue.pop_for_models(1, &hk_b, &["404-3dgs"]).is_empty(), + "expired duplicate should be reaped instead of delivered" + ); + assert_eq!( + queue.len(), + 1, + "outstanding expired lease should keep bookkeeping alive until rollback" + ); + + delivery.rollback(); + + assert_eq!(queue.len(), 0); + assert!( + queue + .inner + .active_ids + .read_sync(&task.id, |_, _| ()) + .is_none() + ); + assert!( + queue + .inner + .sent + .read_sync(&(task.id, generation), |_, _| ()) + .is_none() + ); + let hk_c = Hotkey::from_bytes(&[68u8; 32]); + assert!(queue.pop_for_models(1, &hk_c, &["404-3dgs"]).is_empty()); +} + +#[tokio::test] +async fn lazy_expiration_reaps_before_background_cleanup_runs() { + let queue = build_queue_with_config(1, 1, 60); + queue.push(create_task("lazy-expired", None)); + tokio::time::sleep(Duration::from_secs(2)).await; + + let hk = Hotkey::from_bytes(&[63u8; 32]); + assert!(queue.pop(1, &hk).is_empty()); + assert_eq!(queue.len(), 0); +} + +#[tokio::test] +async fn cleanup_removes_old_entries() { + let queue = build_queue_with_config(1, 1, 1); + let task = create_task("Old Task", None); + queue.push(task.clone()); + tokio::time::sleep(Duration::from_millis(2500)).await; + + assert_eq!(queue.len(), 0); + assert!( + queue + .inner + .active_ids + .read_sync(&task.id, |_, _| ()) + .is_none(), + "background cleanup should clear active bookkeeping for expired entries" + ); +} + +#[tokio::test] +async fn cleanup_removes_idle_zero_capacity_entries_before_ttl() { + let queue = build_queue_with_config(1, 300, 1); + let task = create_task("idle-zero", Some("404-3dgs")); + let generation = queue.inner.acquire_task_generation(task.id); + let entry = raw_entry(task.clone(), 0, 0, generation); + queue.inner.increment_len(); + queue.inner.bucket_for_model("404-3dgs").enqueue_new(entry); + + tokio::time::sleep(Duration::from_millis(1200)).await; + + assert_eq!(queue.len(), 0); + assert!( + queue + .inner + .active_ids + .read_sync(&task.id, |_, _| ()) + .is_none(), + "cleanup should clear bookkeeping for idle zero-capacity entries" + ); +} + +#[tokio::test] +async fn dropping_queue_clone_does_not_stop_cleanup() { + let queue = build_queue_with_config(1, 1, 1); + let clone = queue.clone(); + drop(clone); + + let task = create_task("Old Task", None); + queue.push(task.clone()); + tokio::time::sleep(Duration::from_millis(2500)).await; + + assert_eq!(queue.len(), 0); + assert!( + queue + .inner + .active_ids + .read_sync(&task.id, |_, _| ()) + .is_none(), + "background cleanup should continue after dropping a queue clone" + ); +} + +#[tokio::test] +async fn cleanup_removes_expired_entries_behind_fresh_front_entry() { + let queue = build_queue_with_config(2, 3, 1); + let old_task = create_task("Old Task", None); + queue.push(old_task.clone()); + + tokio::time::sleep(Duration::from_millis(2100)).await; + + let fresh_task = create_task("Fresh Task", None); + queue.push(fresh_task.clone()); + + let hk_a: Hotkey = Hotkey::from_bytes(&[26u8; 32]); + let first_pop = queue.pop(1, &hk_a); + assert_eq!(first_pop.len(), 1); + assert_eq!(first_pop[0].0, old_task); + assert!(first_pop[0].1.is_none()); + + tokio::time::sleep(Duration::from_millis(1200)).await; + + assert_eq!( + queue.len(), + 1, + "cleanup should remove the expired remaining duplicate while preserving fresh work" + ); + + let hk_b: Hotkey = Hotkey::from_bytes(&[27u8; 32]); + let second_pop = queue.pop(2, &hk_b); + assert_eq!(second_pop.len(), 1); + assert_eq!(second_pop[0].0, fresh_task); + assert!(second_pop[0].1.is_none()); +} diff --git a/src/common/queue/tests.rs b/src/common/queue/tests/lifecycle.rs similarity index 50% rename from src/common/queue/tests.rs rename to src/common/queue/tests/lifecycle.rs index ef0e33e..1f9d727 100644 --- a/src/common/queue/tests.rs +++ b/src/common/queue/tests/lifecycle.rs @@ -1,100 +1,4 @@ -use super::{ModelBucket, QueueEntry, TaskQueue}; -use crate::api::Task; -use crate::crypto::hotkey::Hotkey; -use std::collections::HashSet; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; -use std::sync::mpsc; -use std::time::{Duration, Instant}; -use tokio::sync::Barrier; -use tokio::task; -use uuid::Uuid; - -fn build_queue() -> TaskQueue { - build_queue_with_config(2, 300, 1) -} - -fn build_queue_with_config(dup: usize, ttl: u64, cleanup_interval: u64) -> TaskQueue { - TaskQueue::builder() - .dup(dup) - .ttl(ttl) - .cleanup_interval(cleanup_interval) - .default_model("404-3dgs") - .models(["404-3dgs", "404-mesh"]) - .build() -} - -fn create_task(prompt: &str, model: Option<&str>) -> Task { - Task { - id: Uuid::new_v4(), - prompt: Some(Arc::new(prompt.to_string())), - image: None, - model: model.map(ToOwned::to_owned), - seed: 0, - model_params: None, - } -} - -fn raw_entry( - task: Task, - available: usize, - leased: usize, - queued: bool, - generation: usize, -) -> Arc { - Arc::new(QueueEntry { - item: Arc::new(task), - state: AtomicU64::new(QueueEntry::pack_state(available, leased, queued, false)), - timestamp: Instant::now(), - generation, - }) -} - -#[tokio::test] -async fn normalizes_default_model_on_enqueue() { - let queue = build_queue(); - let hotkey = Hotkey::from_bytes(&[1u8; 32]); - queue.push(create_task("robot", None)); - - let items = queue.pop_for_models(1, &hotkey, &["404-3dgs"]); - assert_eq!(items.len(), 1); - assert_eq!(items[0].0.model.as_deref(), Some("404-3dgs")); -} - -#[tokio::test] -async fn pop_for_models_only_scans_requested_buckets() { - let queue = build_queue(); - let hotkey = Hotkey::from_bytes(&[2u8; 32]); - queue.push(create_task("robot", Some("404-3dgs"))); - queue.push(create_task("car", Some("404-mesh"))); - - let items = queue.pop_for_models(2, &hotkey, &["404-3dgs"]); - assert_eq!(items.len(), 1); - assert_eq!(items[0].0.model.as_deref(), Some("404-3dgs")); - assert!(queue.pop_for_models(1, &hotkey, &["404-3dgs"]).is_empty()); - let other_hotkey = Hotkey::from_bytes(&[9u8; 32]); - let other_items = queue.pop_for_models(1, &other_hotkey, &["404-mesh"]); - assert_eq!(other_items.len(), 1); - assert_eq!(other_items[0].0.model.as_deref(), Some("404-mesh")); -} - -#[tokio::test] -async fn pop_for_models_round_robins_across_requested_buckets() { - let queue = build_queue(); - let hotkey = Hotkey::from_bytes(&[3u8; 32]); - queue.push(create_task("robot", Some("404-3dgs"))); - queue.push(create_task("car", Some("404-mesh"))); - - let items = queue.pop_for_models(2, &hotkey, &["404-3dgs", "404-mesh"]); - assert_eq!(items.len(), 2); - let models: HashSet<_> = items - .into_iter() - .filter_map(|(task, _)| task.model) - .collect(); - assert_eq!(models.len(), 2); - assert!(models.contains("404-3dgs")); - assert!(models.contains("404-mesh")); -} +use super::*; #[tokio::test] async fn single_push_pop_honors_dup_across_hotkeys() { @@ -221,6 +125,47 @@ async fn final_commit_reports_duration_at_commit_time() { ); } +#[tokio::test] +async fn final_dup_slot_with_outstanding_lease_reports_duration_only_on_commit() { + let queue = build_queue_with_config(2, 300, 1); + let task = create_task("dup-final-duration", None); + queue.push(task.clone()); + + let hk_a = Hotkey::from_bytes(&[72u8; 32]); + let hk_b = Hotkey::from_bytes(&[73u8; 32]); + let first = queue + .reserve(1, &hk_a) + .pop() + .expect("expected first duplicate lease"); + let second = queue + .reserve(1, &hk_b) + .pop() + .expect("expected second duplicate lease"); + + assert!( + first.duration().is_none(), + "first dup lease is not the final candidate" + ); + assert!( + second.duration().is_none(), + "last slot is not final while another lease is still outstanding" + ); + + let (first_task, first_duration) = first.commit(); + assert_eq!(first_task, task); + assert!( + first_duration.is_none(), + "non-retiring commit should not report queue duration" + ); + + let (second_task, second_duration) = second.commit(); + assert_eq!(second_task, task); + assert!( + second_duration.is_some(), + "final retiring commit should report queue duration" + ); +} + #[tokio::test] async fn completed_tasks_clear_internal_dedupe_bookkeeping() { let queue = build_queue_with_config(1, 300, 1); @@ -265,38 +210,6 @@ async fn completed_tasks_clear_internal_dedupe_bookkeeping() { } } -#[tokio::test] -async fn unknown_model_uses_default_bucket() { - let queue = build_queue(); - let hotkey = Hotkey::from_bytes(&[40u8; 32]); - queue.push(create_task("robot", Some("missing-model"))); - - let items = queue.pop_for_models(1, &hotkey, &["404-3dgs"]); - assert_eq!(items.len(), 1); - assert_eq!(items[0].0.model.as_deref(), Some("missing-model")); -} - -#[tokio::test] -async fn model_filtered_pop_and_reserve_ignore_empty_or_unknown_models() { - let queue = build_queue(); - let task = create_task("robot", Some("404-3dgs")); - queue.push(task.clone()); - - let hotkey = Hotkey::from_bytes(&[41u8; 32]); - let empty: [&str; 0] = []; - assert!(queue.reserve_for_models(1, &hotkey, &empty).is_empty()); - assert!( - queue - .pop_for_models(1, &hotkey, &["missing-model"]) - .is_empty() - ); - assert_eq!(queue.len(), 1); - - let items = queue.pop_for_models(1, &hotkey, &["404-3dgs"]); - assert_eq!(items.len(), 1); - assert_eq!(items[0].0, task); -} - #[tokio::test] async fn dropped_delivery_restores_hotkey_and_dup_budget() { let queue = build_queue_with_config(2, 300, 1); @@ -322,127 +235,74 @@ async fn dropped_delivery_restores_hotkey_and_dup_budget() { } #[tokio::test] -async fn duplicate_reservations_can_be_held_concurrently() { - let queue = build_queue_with_config(2, 300, 60); - let task = create_task("parallel-dup", Some("404-3dgs")); +async fn retire_single_delivery_removes_task_and_bookkeeping() { + let queue = build_queue_with_config(1, 300, 60); + let task = create_task("retire-single", Some("404-3dgs")); queue.push(task.clone()); - let hk_a = Hotkey::from_bytes(&[60u8; 32]); - let hk_b = Hotkey::from_bytes(&[61u8; 32]); - - let mut first = queue.reserve_for_models(1, &hk_a, &["404-3dgs"]); - let mut second = queue.reserve_for_models(1, &hk_b, &["404-3dgs"]); - - assert_eq!(first.len(), 1); - assert_eq!(second.len(), 1); - assert_eq!(first[0].task(), &task); - assert_eq!(second[0].task(), &task); - assert!(queue.pop_for_models(1, &hk_a, &["404-3dgs"]).is_empty()); - assert!(queue.pop_for_models(1, &hk_b, &["404-3dgs"]).is_empty()); - - drop(first.pop().expect("first delivery")); - drop(second.pop().expect("second delivery")); -} - -#[tokio::test] -async fn expired_reserved_delivery_does_not_requeue_on_drop() { - let queue = build_queue_with_config(1, 1, 60); - let task = create_task("expired-rollback", Some("404-3dgs")); - queue.push(task); - - let hk = Hotkey::from_bytes(&[62u8; 32]); - let mut deliveries = queue.reserve_for_models(1, &hk, &["404-3dgs"]); - let delivery = deliveries.pop().expect("expected reserved task"); - - tokio::time::sleep(Duration::from_secs(2)).await; - drop(delivery); - - assert!(queue.pop_for_models(1, &hk, &["404-3dgs"]).is_empty()); - assert_eq!(queue.len(), 0); -} - -#[tokio::test] -async fn final_commit_retires_task_after_expired_duplicate_is_observed() { - let queue = build_queue_with_config(2, 1, 60); - let task = create_task("expired-commit", Some("404-3dgs")); - queue.push(task.clone()); - - let hk_a = Hotkey::from_bytes(&[64u8; 32]); - let hk_b = Hotkey::from_bytes(&[65u8; 32]); - let mut deliveries = queue.reserve_for_models(1, &hk_a, &["404-3dgs"]); - let delivery = deliveries.pop().expect("expected reserved task"); + let hotkey = Hotkey::from_bytes(&[75u8; 32]); + let mut deliveries = queue.reserve_for_models(1, &hotkey, &["404-3dgs"]); + let delivery = deliveries.pop().expect("expected delivery to retire"); let generation = delivery .entry .as_ref() .map(|entry| entry.generation) .expect("delivery should hold generation"); - tokio::time::sleep(Duration::from_secs(2)).await; + delivery.retire(); - assert!( - queue.pop_for_models(1, &hk_b, &["404-3dgs"]).is_empty(), - "expired duplicate should be reaped instead of delivered" - ); - assert_eq!( - queue.len(), - 1, - "outstanding lease should keep bookkeeping alive until final commit" - ); - - let (committed, duration) = delivery.commit(); - assert_eq!(committed, task); - assert!( - duration.is_some(), - "final commit should report queue duration once it retires the task" - ); assert_eq!(queue.len(), 0); + assert!(queue.pop_for_models(1, &hotkey, &["404-3dgs"]).is_empty()); assert!( queue .inner .active_ids .read_sync(&task.id, |_, _| ()) - .is_none() + .is_none(), + "retired task should clear active bookkeeping" ); assert!( queue .inner .sent .read_sync(&(task.id, generation), |_, _| ()) - .is_none() + .is_none(), + "retired task should clear sent bookkeeping" ); } #[tokio::test] -async fn final_expired_rollback_retires_task_after_duplicate_is_observed() { - let queue = build_queue_with_config(2, 1, 60); - let task = create_task("expired-final-rollback", Some("404-3dgs")); +async fn retire_with_sibling_lease_cannot_be_resurrected_by_rollback() { + let queue = build_queue_with_config(2, 300, 60); + let task = create_task("retire-dup", Some("404-3dgs")); queue.push(task.clone()); - let hk_a = Hotkey::from_bytes(&[66u8; 32]); - let hk_b = Hotkey::from_bytes(&[67u8; 32]); - let mut deliveries = queue.reserve_for_models(1, &hk_a, &["404-3dgs"]); - let delivery = deliveries.pop().expect("expected reserved task"); - let generation = delivery - .entry - .as_ref() - .map(|entry| entry.generation) - .expect("delivery should hold generation"); - - tokio::time::sleep(Duration::from_secs(2)).await; - - assert!( - queue.pop_for_models(1, &hk_b, &["404-3dgs"]).is_empty(), - "expired duplicate should be reaped instead of delivered" - ); + let hk_a = Hotkey::from_bytes(&[76u8; 32]); + let hk_b = Hotkey::from_bytes(&[77u8; 32]); + let first = queue + .reserve_for_models(1, &hk_a, &["404-3dgs"]) + .pop() + .expect("expected first lease"); + let second = queue + .reserve_for_models(1, &hk_b, &["404-3dgs"]) + .pop() + .expect("expected sibling lease"); + + first.retire(); assert_eq!( queue.len(), 1, - "outstanding expired lease should keep bookkeeping alive until rollback" + "outstanding sibling lease should keep bookkeeping until it closes" ); - delivery.rollback(); + second.rollback(); assert_eq!(queue.len(), 0); + let hk_c = Hotkey::from_bytes(&[78u8; 32]); + assert!( + queue.pop_for_models(1, &hk_c, &["404-3dgs"]).is_empty(), + "rollback of a sibling lease must not resurrect a retired task" + ); assert!( queue .inner @@ -450,15 +310,86 @@ async fn final_expired_rollback_retires_task_after_duplicate_is_observed() { .read_sync(&task.id, |_, _| ()) .is_none() ); - assert!( - queue - .inner - .sent - .read_sync(&(task.id, generation), |_, _| ()) - .is_none() +} + +#[tokio::test] +async fn rollback_rewinds_cursor_for_same_routing_key() { + let queue = build_queue_with_config(1, 300, 60); + let first = create_task("first", Some("404-3dgs")); + let second = create_task("second", Some("404-3dgs")); + queue.push(first.clone()); + queue.push(second); + + let hotkey = Hotkey::from_bytes(&[74u8; 32]); + let delivery = queue + .reserve_for_models(1, &hotkey, &["404-3dgs"]) + .pop() + .expect("expected first delivery"); + assert_eq!(delivery.task(), &first); + + delivery.rollback(); + + let retry = queue.reserve_for_models(1, &hotkey, &["404-3dgs"]); + assert_eq!(retry.len(), 1); + assert_eq!( + retry[0].task(), + &first, + "rollback should make the restored lower sequence visible before later work" + ); +} + +#[tokio::test] +async fn rollback_from_other_hotkey_can_reoffer_earlier_sequence_after_later_delivery() { + let queue = build_queue_with_config(1, 300, 60); + let first = create_task("first", Some("404-3dgs")); + let second = create_task("second", Some("404-3dgs")); + queue.push(first.clone()); + queue.push(second.clone()); + + let holder_hotkey = Hotkey::from_bytes(&[79u8; 32]); + let worker_hotkey = Hotkey::from_bytes(&[80u8; 32]); + let held_first = queue + .reserve_for_models(1, &holder_hotkey, &["404-3dgs"]) + .pop() + .expect("expected first task to be held by another hotkey"); + + let later = queue.reserve_for_models(1, &worker_hotkey, &["404-3dgs"]); + assert_eq!(later.len(), 1); + assert_eq!(later[0].task(), &second); + later.into_iter().next().expect("later delivery").commit(); + + held_first.rollback(); + + let reoffered = queue.reserve_for_models(1, &worker_hotkey, &["404-3dgs"]); + assert_eq!(reoffered.len(), 1); + assert_eq!( + reoffered[0].task(), + &first, + "rollback/requeue paths are best-effort ordered, not strictly ordered" ); - let hk_c = Hotkey::from_bytes(&[68u8; 32]); - assert!(queue.pop_for_models(1, &hk_c, &["404-3dgs"]).is_empty()); +} + +#[tokio::test] +async fn duplicate_reservations_can_be_held_concurrently() { + let queue = build_queue_with_config(2, 300, 60); + let task = create_task("parallel-dup", Some("404-3dgs")); + queue.push(task.clone()); + + let hk_a = Hotkey::from_bytes(&[60u8; 32]); + let hk_b = Hotkey::from_bytes(&[61u8; 32]); + + let mut first = queue.reserve_for_models(1, &hk_a, &["404-3dgs"]); + let mut second = queue.reserve_for_models(1, &hk_b, &["404-3dgs"]); + + assert_eq!(first.len(), 1); + assert_eq!(second.len(), 1); + assert_eq!(first[0].task(), &task); + assert_eq!(second[0].task(), &task); + assert!(queue.pop_for_models(1, &hk_a, &["404-3dgs"]).is_empty()); + assert!(queue.pop_for_models(1, &hk_b, &["404-3dgs"]).is_empty()); + + drop(first.pop().expect("first delivery")); + drop(second.pop().expect("second delivery")); } #[tokio::test] @@ -466,13 +397,10 @@ async fn zero_remaining_entry_is_reaped_if_encountered() { let queue = build_queue_with_config(1, 300, 1); let task = create_task("stale", Some("404-3dgs")); let generation = queue.inner.acquire_task_generation(task.id); - let entry = raw_entry(task.clone(), 0, 0, true, generation); - queue.inner.len.fetch_add(1, Ordering::Relaxed); + let entry = raw_entry(task.clone(), 0, 0, generation); + queue.inner.increment_len(); let bucket = queue.inner.bucket_for_model("404-3dgs"); - bucket.q.push(entry); - bucket.live.fetch_add(1, Ordering::Relaxed); - bucket.activity.fetch_add(1, Ordering::AcqRel); - bucket.generation.fetch_add(1, Ordering::AcqRel); + bucket.enqueue_new(entry); let hotkey = Hotkey::from_bytes(&[42u8; 32]); assert!(queue.pop_for_models(1, &hotkey, &["404-3dgs"]).is_empty()); @@ -486,6 +414,29 @@ async fn zero_remaining_entry_is_reaped_if_encountered() { ); } +#[tokio::test] +async fn retire_if_idle_marks_leased_entry_retired_until_lease_closes() { + let task = create_task("leased", Some("404-3dgs")); + let entry = raw_entry(task, 1, 0, 1); + + assert!( + entry.try_acquire_lease().is_some(), + "test setup should acquire the only available lease" + ); + assert!( + !entry.retire_if_idle(), + "entry with an outstanding lease cannot be removed immediately" + ); + assert!( + QueueEntry::is_retired(entry.load_state()), + "retired marker should prevent sibling rollback from restoring capacity" + ); + assert!( + entry.finish_rollback(false).retire_now, + "closing the last lease should remove a previously retired entry" + ); +} + #[tokio::test] async fn pop_more_than_exists_returns_all_available_tasks() { for &dup in &[1usize, 4usize] { @@ -594,8 +545,9 @@ async fn rollback_clears_exhausted_cache_for_same_hotkey() { "same hotkey should be marked exhausted after scanning only already-sent work" ); let generation = bucket.generation.load(Ordering::Acquire); + let key = exhausted_key(&hk, &[]); assert_eq!( - bucket.exhausted_generation(&hk), + bucket.exhausted_generation(&key), Some(generation), "empty scan should cache the hotkey as exhausted for the current generation" ); @@ -603,7 +555,7 @@ async fn rollback_clears_exhausted_cache_for_same_hotkey() { delivery.rollback(); assert!( - bucket.exhausted_generation(&hk).is_none(), + bucket.exhausted_generation(&key).is_none(), "rollback should invalidate exhausted-cache state for the same hotkey" ); let retry = queue.pop_for_models(1, &hk, &["404-3dgs"]); @@ -634,288 +586,62 @@ async fn requeue_for_other_hotkeys_keeps_same_hotkey_blocked() { } #[tokio::test] -async fn in_flight_pop_does_not_mark_bucket_exhausted() { - let bucket = Arc::new(ModelBucket::default()); - let entry = raw_entry(create_task("robot", Some("404-3dgs")), 2, 0, false, 7); - bucket.enqueue_new(entry); +async fn requeue_with_remaining_dup_capacity_keeps_already_exhausted_hotkey_blocked() { + let queue = build_queue_with_config(3, 300, 60); + let task = create_task("dup-requeue-exhausted", Some("404-3dgs")); + queue.push(task.clone()); - let hotkey = Hotkey::from_bytes(&[26u8; 32]); - let generation = bucket.generation.load(Ordering::Acquire); - let start_activity = bucket.activity.load(Ordering::Acquire); - let start_in_flight = bucket.in_flight.load(Ordering::Acquire); - assert_eq!(start_in_flight, 0); - - let (ready_tx, ready_rx) = mpsc::channel(); - let (release_tx, release_rx) = mpsc::channel(); - let worker_bucket = Arc::clone(&bucket); - let handle = std::thread::spawn(move || { - worker_bucket - .pop_visible_entry_with_hook(|| { - ready_tx.send(()).expect("signal pop readiness"); - release_rx.recv().expect("wait for release"); - }) - .expect("expected queued entry") - }); - - ready_rx.recv().expect("wait for worker pop"); - assert_eq!(bucket.in_flight.load(Ordering::Acquire), 1); - assert!( - bucket.q.pop().is_none(), - "worker should hold the only entry" - ); + let exhausted_hotkey = Hotkey::from_bytes(&[81u8; 32]); + let mut exhausted_deliveries = queue.reserve_for_models(1, &exhausted_hotkey, &["404-3dgs"]); + let exhausted_delivery = exhausted_deliveries + .pop() + .expect("expected first lease for hotkey that will become exhausted"); - bucket.mark_exhausted_if_stable(&hotkey, generation, start_activity, 0, start_in_flight); assert!( - bucket.exhausted_generation(&hotkey).is_none(), - "in-flight work must block exhausted-cache writes" + queue + .reserve_for_models(1, &exhausted_hotkey, &["404-3dgs"]) + .is_empty(), + "same hotkey should exhaust after scanning only already-sent work" ); - release_tx.send(()).expect("resume worker"); - let popped = handle.join().expect("worker should join"); - bucket.finish_entry(popped, true); -} - -#[tokio::test] -async fn concurrent_push_pop_returns_unique_tasks_per_hotkey() { - let queue = Arc::new(build_queue_with_config(3, 300, 1)); - let q_producer = Arc::clone(&queue); - let producer = task::spawn(async move { - for i in 0..10 { - q_producer.push(create_task(&format!("Task {}", i), None)); - } - }); - producer.await.expect("producer task should complete"); - - let mut handles = Vec::new(); - for idx in 0..3 { - let q = Arc::clone(&queue); - let hotkey: Hotkey = match idx { - 0 => Hotkey::from_bytes(&[20u8; 32]), - 1 => Hotkey::from_bytes(&[21u8; 32]), - _ => Hotkey::from_bytes(&[22u8; 32]), - }; - handles.push(task::spawn(async move { - let res = q.pop(5, &hotkey); - let set: HashSet<_> = res.iter().map(|(task, _)| task.id).collect(); - assert_eq!(set.len(), res.len()); - for (_, duration) in &res { - assert!(duration.is_none()); - } - res - })); - } - - for handle in handles { - handle.await.expect("consumer task should complete"); - } -} - -#[tokio::test] -async fn concurrent_polls_from_same_hotkey_do_not_duplicate_task() { - let queue = Arc::new(build_queue_with_config(1, 300, 1)); - let task = create_task("single", None); - queue.push(task.clone()); - - let hotkey = Hotkey::from_bytes(&[23u8; 32]); - let barrier = Arc::new(Barrier::new(5)); - let mut handles = Vec::new(); - for _ in 0..5 { - let q = Arc::clone(&queue); - let barrier = Arc::clone(&barrier); - let hotkey = hotkey.clone(); - handles.push(task::spawn(async move { - barrier.wait().await; - q.pop(1, &hotkey) - })); - } - - let mut delivered = Vec::new(); - for handle in handles { - delivered.extend(handle.await.expect("same-hotkey consumer should complete")); - } - - assert_eq!(delivered.len(), 1); - assert_eq!(delivered[0].0, task); -} - -#[tokio::test] -async fn lazy_expiration_reaps_before_background_cleanup_runs() { - let queue = build_queue_with_config(1, 1, 60); - queue.push(create_task("lazy-expired", None)); - tokio::time::sleep(Duration::from_secs(2)).await; - - let hk = Hotkey::from_bytes(&[63u8; 32]); - assert!(queue.pop(1, &hk).is_empty()); - assert_eq!(queue.len(), 0); -} - -#[tokio::test] -async fn cleanup_removes_old_entries() { - let queue = build_queue_with_config(1, 1, 1); - queue.push(create_task("Old Task", None)); - tokio::time::sleep(Duration::from_secs(2)).await; - - let hk: Hotkey = Hotkey::from_bytes(&[23u8; 32]); - assert!(queue.pop(1, &hk).is_empty()); -} - -#[tokio::test] -async fn dropping_queue_clone_does_not_stop_cleanup() { - let queue = build_queue_with_config(1, 1, 1); - let clone = queue.clone(); - drop(clone); - - queue.push(create_task("Old Task", None)); - tokio::time::sleep(Duration::from_secs(2)).await; - - let hk: Hotkey = Hotkey::from_bytes(&[37u8; 32]); - assert!(queue.pop(1, &hk).is_empty()); -} - -#[tokio::test] -async fn cleanup_removes_expired_entries_behind_fresh_front_entry() { - let queue = build_queue_with_config(2, 3, 1); - let old_task = create_task("Old Task", None); - queue.push(old_task.clone()); - - tokio::time::sleep(Duration::from_millis(2100)).await; - - let fresh_task = create_task("Fresh Task", None); - queue.push(fresh_task.clone()); - - let hk_a: Hotkey = Hotkey::from_bytes(&[26u8; 32]); - let first_pop = queue.pop(1, &hk_a); - assert_eq!(first_pop.len(), 1); - assert_eq!(first_pop[0].0, old_task); - assert!(first_pop[0].1.is_none()); - - tokio::time::sleep(Duration::from_millis(1200)).await; - + let bucket = queue.inner.bucket_for_model("404-3dgs"); + let exhausted_key = exhausted_key(&exhausted_hotkey, &[]); + let generation_before_requeue = bucket.generation.load(Ordering::Acquire); assert_eq!( - queue.len(), - 1, - "cleanup should remove the expired requeued task while preserving fresh work" + bucket.exhausted_generation(&exhausted_key), + Some(generation_before_requeue), + "test setup should cache the already-sent hotkey as exhausted" ); - let hk_b: Hotkey = Hotkey::from_bytes(&[27u8; 32]); - let second_pop = queue.pop(2, &hk_b); - assert_eq!(second_pop.len(), 1); - assert_eq!(second_pop[0].0, fresh_task); - assert!(second_pop[0].1.is_none()); -} - -#[tokio::test] -async fn try_reserve_respects_capacity() { - let queue = build_queue_with_config(1, 300, 1); - let task = create_task("Task A", None); + let requeue_hotkey = Hotkey::from_bytes(&[82u8; 32]); + let requeued_delivery = queue + .reserve_for_models(1, &requeue_hotkey, &["404-3dgs"]) + .pop() + .expect("expected sibling lease to requeue"); + requeued_delivery.requeue_for_other_hotkeys(); - let first_slot = queue.try_reserve(1).expect("first reservation should fit"); - assert_eq!(queue.len(), 1); + assert_eq!( + bucket.generation.load(Ordering::Acquire), + generation_before_requeue, + "requeue with remaining available dup capacity should not need a generation bump" + ); assert!( - queue.try_reserve(1).is_none(), - "second reservation should exceed the bound" + queue + .reserve_for_models(1, &exhausted_hotkey, &["404-3dgs"]) + .is_empty(), + "requeue_for_other_hotkeys must not unblock a hotkey that already received this task" ); - first_slot.push(task.clone()); - assert_eq!(queue.len(), 1); - - let hk: Hotkey = Hotkey::from_bytes(&[24u8; 32]); - let popped = queue.pop(1, &hk); - assert_eq!(popped.len(), 1); - assert_eq!(popped[0].0, task); -} - -#[tokio::test] -async fn concurrent_try_reserve_respects_capacity() { - let queue = Arc::new(build_queue_with_config(1, 300, 1)); - let max_len = 4; - let contenders = 16; - let barrier = Arc::new(Barrier::new(contenders)); - let mut handles = Vec::with_capacity(contenders); - - for _ in 0..contenders { - let q = Arc::clone(&queue); - let barrier = Arc::clone(&barrier); - handles.push(task::spawn(async move { - barrier.wait().await; - q.try_reserve(max_len) - })); - } - - let mut reservations = Vec::new(); - for handle in handles { - if let Some(reservation) = handle.await.expect("reservation racer should complete") { - reservations.push(reservation); - } - } - - assert_eq!(reservations.len(), max_len); - assert_eq!(queue.len(), max_len); - - drop(reservations); - assert_eq!(queue.len(), 0); -} - -#[tokio::test] -async fn moderate_contention_drains_all_tasks() { - let queue = build_queue_with_config(1, 10, 1); - let producers = 2; - let tasks_per_producer = 500; - let consumers = 4; - let barrier = Arc::new(Barrier::new(producers + consumers)); - - for producer_idx in 0..producers { - let q = queue.clone(); - let barrier = Arc::clone(&barrier); - task::spawn(async move { - barrier.wait().await; - for i in 0..tasks_per_producer { - q.push(create_task(&format!("p{producer_idx}-{i}"), None)); - } - }); - } - - let counter = Arc::new(AtomicUsize::new(0)); - let mut handles = Vec::with_capacity(consumers); - for consumer_idx in 0..consumers { - let q = queue.clone(); - let barrier = Arc::clone(&barrier); - let counter = Arc::clone(&counter); - handles.push(task::spawn(async move { - let hotkey = Hotkey::from_bytes(&[30u8 + consumer_idx as u8; 32]); - barrier.wait().await; - let mut idle_start: Option = None; - loop { - if counter.load(Ordering::Relaxed) == producers * tasks_per_producer { - break; - } - - let items = q.pop(7, &hotkey); - if !items.is_empty() { - counter.fetch_add(items.len(), Ordering::Relaxed); - idle_start = None; - } else { - idle_start.get_or_insert_with(Instant::now); - if idle_start.expect("idle start should be set").elapsed() - >= Duration::from_secs(5) - { - break; - } - tokio::time::sleep(Duration::from_millis(50)).await; - } - } - })); - } - - for handle in handles { - handle.await.expect("contention consumer should complete"); - } - - assert_eq!(queue.len(), 0); + let fresh_hotkey = Hotkey::from_bytes(&[83u8; 32]); + let fresh_delivery = queue.reserve_for_models(1, &fresh_hotkey, &["404-3dgs"]); + assert_eq!(fresh_delivery.len(), 1); assert_eq!( - counter.load(Ordering::Relaxed), - producers * tasks_per_producer + fresh_delivery[0].task(), + &task, + "a fresh hotkey should still see the requeued available slot" ); + + drop(exhausted_delivery); } #[tokio::test] diff --git a/src/common/queue/tests/mod.rs b/src/common/queue/tests/mod.rs new file mode 100644 index 0000000..393df92 --- /dev/null +++ b/src/common/queue/tests/mod.rs @@ -0,0 +1,190 @@ +use super::{ + DEFAULT_RESERVATION_SCAN_CAP, ExhaustedKey, ModelBucket, QueueEntry, TaskQueue, TaskRouting, + WorkerRouting, +}; +use crate::api::Task; +use crate::crypto::hotkey::Hotkey; +use prometheus::{IntGauge, opts}; +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::{Arc, Barrier as ThreadBarrier, Mutex}; +use std::thread; +use std::time::{Duration, Instant}; +use uuid::Uuid; + +fn build_queue() -> TaskQueue { + build_queue_with_config(2, 300, 1) +} + +fn build_queue_with_config(dup: usize, ttl: u64, cleanup_interval: u64) -> TaskQueue { + TaskQueue::builder() + .dup(dup) + .ttl(ttl) + .cleanup_interval(cleanup_interval) + .default_model("404-3dgs") + .models(["404-3dgs", "404-mesh"]) + .build() +} + +fn build_queue_with_len_gauge() -> (TaskQueue, IntGauge) { + let gauge = IntGauge::with_opts(opts!("test_queue_len", "Observed test queue length")) + .expect("queue len gauge"); + let queue = TaskQueue::builder() + .dup(1) + .ttl(300) + .cleanup_interval(1) + .default_model("404-3dgs") + .models(["404-3dgs"]) + .queue_len_gauge(gauge.clone()) + .build(); + (queue, gauge) +} + +fn create_task(prompt: &str, model: Option<&str>) -> Task { + Task { + id: Uuid::new_v4(), + prompt: Some(Arc::new(prompt.to_string())), + image: None, + model: model.map(ToOwned::to_owned), + seed: 0, + model_params: None, + } +} + +fn push_task_with_required_tags(queue: &TaskQueue, task: Task, tags: &[&str]) { + queue + .try_reserve(usize::MAX) + .expect("reserve queue slot") + .push_with_routing( + task, + TaskRouting::with_required_worker_tags( + tags.iter().map(|tag| tag.to_string()).collect(), + ), + ); +} + +fn bucket_task_ids(queue: &TaskQueue, model: &str) -> Vec { + queue + .inner + .bucket_for_model(model) + .snapshot_entries(usize::MAX) + .into_iter() + .map(|(_, entry)| entry.item.id) + .collect() +} + +#[derive(Default)] +struct MtDeliveryState { + produced: HashMap>, + delivered_pairs: HashSet<(Uuid, usize)>, + delivered_counts: HashMap, +} + +struct MtOrderingState { + produced_sequences: HashMap, + delivered_counts: HashMap, + last_sequence_by_worker: Vec>, +} + +fn mt_required_tag(sequence: usize) -> Option<&'static str> { + match sequence % 3 { + 0 => Some("acme"), + 1 => Some("enterprise"), + _ => None, + } +} + +fn mt_worker_tag(worker_idx: usize) -> Option<&'static str> { + match worker_idx % 3 { + 0 => Some("acme"), + 1 => Some("enterprise"), + _ => None, + } +} + +fn mt_worker_tags(worker_idx: usize) -> Vec { + mt_worker_tag(worker_idx) + .into_iter() + .map(str::to_string) + .collect() +} + +fn exhausted_key(hotkey: &Hotkey, tags: &[String]) -> ExhaustedKey { + ExhaustedKey::new(hotkey, WorkerRouting::from_tags(tags).cache_key()) +} + +fn record_mt_delivery(state: &Mutex, worker_idx: usize, task: &Task, dup: usize) { + let worker_tag = mt_worker_tag(worker_idx); + let mut state = state.lock().expect("mt delivery state lock"); + let required_tag = *state + .produced + .get(&task.id) + .expect("delivered task must have been produced by the stress test"); + + assert_eq!( + worker_tag, required_tag, + "worker received a task outside its requested tag routing" + ); + + assert!( + state.delivered_pairs.insert((task.id, worker_idx)), + "task was delivered more than once to the same worker" + ); + + let delivered_count = state.delivered_counts.entry(task.id).or_default(); + *delivered_count += 1; + assert!( + *delivered_count <= dup, + "task was delivered more times than the configured dup count" + ); +} + +fn record_mt_ordered_delivery( + state: &Mutex, + worker_idx: usize, + task: &Task, + dup: usize, +) { + let mut state = state.lock().expect("mt ordering state lock"); + let sequence = *state + .produced_sequences + .get(&task.id) + .expect("delivered task must have been produced by the ordering stress test"); + + assert_eq!( + mt_worker_tag(worker_idx), + mt_required_tag(sequence), + "worker received a task outside its requested tag routing" + ); + + if let Some(previous_sequence) = state.last_sequence_by_worker[worker_idx] { + assert!( + previous_sequence < sequence, + "worker observed non-increasing task order: previous={previous_sequence}, current={sequence}" + ); + } + state.last_sequence_by_worker[worker_idx] = Some(sequence); + + let delivered_count = state.delivered_counts.entry(task.id).or_default(); + *delivered_count += 1; + assert!( + *delivered_count <= dup, + "task was delivered more times than the configured dup count" + ); +} + +fn raw_entry(task: Task, available: usize, leased: usize, generation: usize) -> Arc { + Arc::new(QueueEntry { + item: Arc::new(task), + routing: TaskRouting::default(), + state: AtomicU64::new(QueueEntry::pack_state(available, leased, false)), + timestamp: Instant::now(), + generation, + }) +} + +mod capacity; +mod concurrency; +mod expiration; +mod lifecycle; +mod routing; diff --git a/src/common/queue/tests/routing.rs b/src/common/queue/tests/routing.rs new file mode 100644 index 0000000..8f2b753 --- /dev/null +++ b/src/common/queue/tests/routing.rs @@ -0,0 +1,501 @@ +use super::*; + +#[tokio::test] +async fn normalizes_default_model_on_enqueue() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[1u8; 32]); + queue.push(create_task("robot", None)); + + let items = queue.pop_for_models(1, &hotkey, &["404-3dgs"]); + assert_eq!(items.len(), 1); + assert_eq!(items[0].0.model.as_deref(), Some("404-3dgs")); +} + +#[tokio::test] +async fn pop_for_models_only_scans_requested_buckets() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[2u8; 32]); + queue.push(create_task("robot", Some("404-3dgs"))); + queue.push(create_task("car", Some("404-mesh"))); + + let items = queue.pop_for_models(2, &hotkey, &["404-3dgs"]); + assert_eq!(items.len(), 1); + assert_eq!(items[0].0.model.as_deref(), Some("404-3dgs")); + assert!(queue.pop_for_models(1, &hotkey, &["404-3dgs"]).is_empty()); + let other_hotkey = Hotkey::from_bytes(&[9u8; 32]); + let other_items = queue.pop_for_models(1, &other_hotkey, &["404-mesh"]); + assert_eq!(other_items.len(), 1); + assert_eq!(other_items[0].0.model.as_deref(), Some("404-mesh")); +} + +#[tokio::test] +async fn reserve_for_models_returns_tagged_task_for_matching_worker_tags() { + let queue = build_queue(); + let task = create_task("matching", Some("404-3dgs")); + push_task_with_required_tags(&queue, task.clone(), &["acme"]); + + let worker_tags = vec!["acme".to_string()]; + let items = queue.reserve_for_models_with_routing( + 1, + &Hotkey::from_bytes(&[20u8; 32]), + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + + assert_eq!(items.len(), 1); + assert_eq!(items[0].task(), &task); +} + +#[tokio::test] +async fn worker_tag_filter_keeps_tagged_task_for_later_matching_worker() { + let queue = build_queue(); + let task = create_task("restricted", Some("404-3dgs")); + push_task_with_required_tags(&queue, task.clone(), &["acme"]); + + let non_matching_tags = vec!["other".to_string()]; + let non_matching_items = queue.reserve_for_models_with_routing( + 1, + &Hotkey::from_bytes(&[24u8; 32]), + &["404-3dgs"], + WorkerRouting::from_tags(non_matching_tags.as_slice()), + ); + assert!( + non_matching_items.is_empty(), + "non-matching worker must not receive tagged task" + ); + + let matching_tags = vec!["acme".to_string()]; + let matching_items = queue.reserve_for_models_with_routing( + 1, + &Hotkey::from_bytes(&[25u8; 32]), + &["404-3dgs"], + WorkerRouting::from_tags(matching_tags.as_slice()), + ); + assert_eq!(matching_items.len(), 1); + assert_eq!(matching_items[0].task(), &task); +} + +#[tokio::test] +async fn worker_tag_filter_keeps_public_task_for_untagged_worker() { + let queue = build_queue(); + let task = create_task("public", Some("404-3dgs")); + queue.push(task.clone()); + + let tagged_worker_tags = vec!["acme".to_string()]; + let tagged_items = queue.reserve_for_models_with_routing( + 1, + &Hotkey::from_bytes(&[26u8; 32]), + &["404-3dgs"], + WorkerRouting::from_tags(tagged_worker_tags.as_slice()), + ); + assert!( + tagged_items.is_empty(), + "worker requesting tagged work must not receive public task" + ); + + let untagged_items = + queue.reserve_for_models(1, &Hotkey::from_bytes(&[37u8; 32]), &["404-3dgs"]); + assert_eq!(untagged_items.len(), 1); + assert_eq!(untagged_items[0].task(), &task); +} + +#[tokio::test] +async fn worker_tag_filter_does_not_reshuffle_skipped_tasks() { + let queue = build_queue(); + let first = create_task("first restricted", Some("404-3dgs")); + let second = create_task("second restricted", Some("404-3dgs")); + push_task_with_required_tags(&queue, first.clone(), &["acme"]); + push_task_with_required_tags(&queue, second.clone(), &["enterprise"]); + + assert_eq!( + bucket_task_ids(&queue, "404-3dgs"), + vec![first.id, second.id] + ); + + let non_matching_tags = vec!["other".to_string()]; + let non_matching_items = queue.reserve_for_models_with_routing( + 1, + &Hotkey::from_bytes(&[27u8; 32]), + &["404-3dgs"], + WorkerRouting::from_tags(non_matching_tags.as_slice()), + ); + assert!(non_matching_items.is_empty()); + + assert_eq!( + bucket_task_ids(&queue, "404-3dgs"), + vec![first.id, second.id] + ); +} + +#[tokio::test] +async fn exhausted_cache_is_worker_routing_aware() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[28u8; 32]); + let task = create_task("restricted", Some("404-3dgs")); + push_task_with_required_tags(&queue, task.clone(), &["acme"]); + + let non_matching_tags = vec!["other".to_string()]; + let non_matching_items = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(non_matching_tags.as_slice()), + ); + assert!(non_matching_items.is_empty()); + + let bucket = queue.inner.bucket_for_model("404-3dgs"); + let non_matching_key = exhausted_key(&hotkey, non_matching_tags.as_slice()); + assert_eq!( + bucket.exhausted_generation(&non_matching_key), + Some(bucket.generation.load(Ordering::Acquire)), + "non-matching routing should cache its own exhausted scan" + ); + + let matching_tags = vec!["acme".to_string()]; + let matching_items = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(matching_tags.as_slice()), + ); + assert_eq!(matching_items.len(), 1); + assert_eq!(matching_items[0].task(), &task); +} + +#[tokio::test] +async fn bounded_scan_cursor_reaches_work_after_ineligible_prefix() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[29u8; 32]); + let worker_tags = Vec::::new(); + + for idx in 0..=DEFAULT_RESERVATION_SCAN_CAP { + push_task_with_required_tags( + &queue, + create_task(&format!("restricted-{idx}"), Some("404-3dgs")), + &["acme"], + ); + } + let public = create_task("public-after-prefix", Some("404-3dgs")); + queue.push(public.clone()); + + let first = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert!( + first.is_empty(), + "first bounded scan should stop before reaching the public task" + ); + + let second = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert_eq!(second.len(), 1); + assert_eq!(second[0].task(), &public); +} + +#[tokio::test] +async fn reservation_scan_cap_can_track_queue_limit_above_default() { + let queue = TaskQueue::builder() + .dup(2) + .ttl(300) + .cleanup_interval(1) + .default_model("404-3dgs") + .models(["404-3dgs"]) + .reservation_scan_cap(DEFAULT_RESERVATION_SCAN_CAP + 64) + .build(); + let hotkey = Hotkey::from_bytes(&[32u8; 32]); + let worker_tags = Vec::::new(); + + for idx in 0..=DEFAULT_RESERVATION_SCAN_CAP { + push_task_with_required_tags( + &queue, + create_task(&format!("restricted-{idx}"), Some("404-3dgs")), + &["acme"], + ); + } + let public = create_task("public-above-default-cap", Some("404-3dgs")); + queue.push(public.clone()); + + let items = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert_eq!(items.len(), 1); + assert_eq!(items[0].task(), &public); +} + +#[tokio::test] +async fn reservation_scan_cap_can_be_built_below_default() { + let queue = TaskQueue::builder() + .dup(2) + .ttl(300) + .cleanup_interval(1) + .default_model("404-3dgs") + .models(["404-3dgs"]) + .reservation_scan_cap(2) + .build(); + let hotkey = Hotkey::from_bytes(&[35u8; 32]); + let worker_tags = Vec::::new(); + + for idx in 0..2 { + push_task_with_required_tags( + &queue, + create_task(&format!("restricted-below-default-{idx}"), Some("404-3dgs")), + &["acme"], + ); + } + let public = create_task("public-after-below-default-cap", Some("404-3dgs")); + queue.push(public.clone()); + + let first = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert!( + first.is_empty(), + "custom scan cap below default should stop before later eligible work" + ); + + let second = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert_eq!(second.len(), 1); + assert_eq!(second[0].task(), &public); +} + +#[tokio::test] +async fn reservation_scan_cap_can_be_raised_after_build() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[33u8; 32]); + let worker_tags = Vec::::new(); + + for idx in 0..=DEFAULT_RESERVATION_SCAN_CAP { + push_task_with_required_tags( + &queue, + create_task(&format!("restricted-{idx}"), Some("404-3dgs")), + &["acme"], + ); + } + let public = create_task("public-after-scan-cap-update", Some("404-3dgs")); + queue.push(public.clone()); + + let first = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert!(first.is_empty()); + + let other_hotkey = Hotkey::from_bytes(&[34u8; 32]); + queue.set_reservation_scan_cap(DEFAULT_RESERVATION_SCAN_CAP + 64); + + let second = queue.reserve_for_models_with_routing( + 1, + &other_hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert_eq!(second.len(), 1); + assert_eq!(second[0].task(), &public); +} + +#[tokio::test] +async fn reservation_scan_cap_can_be_lowered_after_build() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[36u8; 32]); + let worker_tags = Vec::::new(); + + for idx in 0..2 { + push_task_with_required_tags( + &queue, + create_task(&format!("restricted-after-lower-{idx}"), Some("404-3dgs")), + &["acme"], + ); + } + let public = create_task("public-after-runtime-lower", Some("404-3dgs")); + queue.push(public.clone()); + + queue.set_reservation_scan_cap(2); + let first = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert!( + first.is_empty(), + "runtime scan cap below default should stop before later eligible work" + ); + + let second = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert_eq!(second.len(), 1); + assert_eq!(second[0].task(), &public); +} + +#[tokio::test] +async fn scan_cursor_wraps_when_start_seq_is_past_tail() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[30u8; 32]); + let task = create_task("wrap-target", Some("404-3dgs")); + queue.push(task.clone()); + + let bucket = queue.inner.bucket_for_model("404-3dgs"); + let key = exhausted_key(&hotkey, &[]); + bucket.update_scan_cursor(&key, u64::MAX); + + let items = queue.reserve_for_models(1, &hotkey, &["404-3dgs"]); + assert_eq!(items.len(), 1); + assert_eq!(items[0].task(), &task); +} + +#[tokio::test] +async fn scan_cursor_isolated_by_worker_routing_for_same_hotkey() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[31u8; 32]); + let b_task = create_task("b-target", Some("404-3dgs")); + let a_task = create_task("a-target", Some("404-3dgs")); + push_task_with_required_tags(&queue, b_task.clone(), &["b"]); + push_task_with_required_tags(&queue, a_task.clone(), &["a"]); + for idx in 0..DEFAULT_RESERVATION_SCAN_CAP { + queue.push(create_task(&format!("public-{idx}"), Some("404-3dgs"))); + } + + let a_tags = vec!["a".to_string()]; + let a_items = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(a_tags.as_slice()), + ); + assert_eq!(a_items.len(), 1); + assert_eq!(a_items[0].task(), &a_task); + + let b_tags = vec!["b".to_string()]; + let b_items = queue.reserve_for_models_with_routing( + 1, + &hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(b_tags.as_slice()), + ); + assert_eq!(b_items.len(), 1); + assert_eq!(b_items[0].task(), &b_task); +} + +#[tokio::test] +async fn empty_default_and_explicit_empty_worker_routing_share_cache_key() { + let explicit: Vec = Vec::new(); + assert!( + WorkerRouting::default().cache_key() == WorkerRouting::from_tags(&explicit).cache_key() + ); +} + +#[tokio::test] +async fn reserve_for_models_filters_by_worker_tags() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[21u8; 32]); + let restricted = create_task("restricted", Some("404-3dgs")); + let public = create_task("public", Some("404-3dgs")); + push_task_with_required_tags(&queue, restricted.clone(), &["acme"]); + queue.push(public.clone()); + + let no_tag_items = queue.pop_for_models(2, &hotkey, &["404-3dgs"]); + assert_eq!(no_tag_items.len(), 1); + assert_eq!(no_tag_items[0].0, public); + + let tagged_hotkey = Hotkey::from_bytes(&[22u8; 32]); + let worker_tags = vec!["acme".to_string()]; + let tagged_items = queue.reserve_for_models_with_routing( + 1, + &tagged_hotkey, + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert_eq!(tagged_items.len(), 1); + assert_eq!(tagged_items[0].task(), &restricted); +} + +#[tokio::test] +async fn worker_tag_filter_uses_any_matching_tag() { + let queue = build_queue(); + let task = create_task("enterprise", Some("404-3dgs")); + push_task_with_required_tags(&queue, task.clone(), &["acme", "premium"]); + + let worker_tags = vec!["premium".to_string()]; + let items = queue.reserve_for_models_with_routing( + 1, + &Hotkey::from_bytes(&[23u8; 32]), + &["404-3dgs"], + WorkerRouting::from_tags(worker_tags.as_slice()), + ); + assert_eq!(items.len(), 1); + assert_eq!(items[0].task(), &task); +} + +#[tokio::test] +async fn pop_for_models_round_robins_across_requested_buckets() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[3u8; 32]); + queue.push(create_task("robot", Some("404-3dgs"))); + queue.push(create_task("car", Some("404-mesh"))); + + let items = queue.pop_for_models(2, &hotkey, &["404-3dgs", "404-mesh"]); + assert_eq!(items.len(), 2); + let models: HashSet<_> = items + .into_iter() + .filter_map(|(task, _)| task.model) + .collect(); + assert_eq!(models.len(), 2); + assert!(models.contains("404-3dgs")); + assert!(models.contains("404-mesh")); +} + +#[tokio::test] +async fn unknown_model_uses_default_bucket() { + let queue = build_queue(); + let hotkey = Hotkey::from_bytes(&[40u8; 32]); + queue.push(create_task("robot", Some("missing-model"))); + + let items = queue.pop_for_models(1, &hotkey, &["404-3dgs"]); + assert_eq!(items.len(), 1); + assert_eq!(items[0].0.model.as_deref(), Some("missing-model")); +} + +#[tokio::test] +async fn model_filtered_pop_and_reserve_ignore_empty_or_unknown_models() { + let queue = build_queue(); + let task = create_task("robot", Some("404-3dgs")); + queue.push(task.clone()); + + let hotkey = Hotkey::from_bytes(&[41u8; 32]); + let empty: [&str; 0] = []; + assert!(queue.reserve_for_models(1, &hotkey, &empty).is_empty()); + assert!( + queue + .pop_for_models(1, &hotkey, &["missing-model"]) + .is_empty() + ); + assert_eq!(queue.len(), 1); + + let items = queue.pop_for_models(1, &hotkey, &["404-3dgs"]); + assert_eq!(items.len(), 1); + assert_eq!(items[0].0, task); +} diff --git a/src/db/data_access.rs b/src/db/data_access.rs index 486b289..9d916c8 100644 --- a/src/db/data_access.rs +++ b/src/db/data_access.rs @@ -7,6 +7,7 @@ use tracing::info; use super::connection::StmtKey; use super::task_lifecycle::GenerationBillingOwner; use super::{ActivityEventRow, Database, WorkerEventRow}; +use crate::api::request::normalize_worker_tags; #[derive(Debug, Clone)] pub struct ApiKeyBillingOwnerRow { @@ -33,10 +34,17 @@ pub struct GatewaySettingsRow { pub registered_window_ms: u64, } +#[derive(Debug, Clone)] +pub struct CompanyMeta { + pub name: String, + pub concurrent_limit: u64, + pub daily_limit: u64, + pub worker_tags: Vec, +} + type UserKeyHashesRow = (i64, String, u64, u64, Vec>); type UserMetaRow = (String, u64, u64); -type TaskLimitsRow = (String, u64, u64); -type CompanyMetaRow = (uuid::Uuid, TaskLimitsRow); +type CompanyMetaRow = (uuid::Uuid, CompanyMeta); type CompanyKeyHashesRow = (uuid::Uuid, Vec>); fn nonnegative_i32_to_u64(value: i32, field_name: &str) -> Result { @@ -58,12 +66,20 @@ fn positive_i64_to_u64(value: i64, field_name: &str) -> Result { Ok(value as u64) } -fn decode_task_limits(name: String, concurrent: i32, daily: i32) -> Result { - Ok(( +fn decode_company_meta_fields( + name: String, + concurrent: i32, + daily: i32, + worker_tags: Vec, +) -> Result { + let worker_tags = normalize_worker_tags(&worker_tags) + .map_err(|message| anyhow!("Invalid company worker_tags for company {name}: {message}"))?; + Ok(CompanyMeta { name, - nonnegative_i32_to_u64(concurrent, "task_limit_concurrent")?, - nonnegative_i32_to_u64(daily, "task_limit_daily")?, - )) + concurrent_limit: nonnegative_i32_to_u64(concurrent, "task_limit_concurrent")?, + daily_limit: nonnegative_i32_to_u64(daily, "task_limit_daily")?, + worker_tags, + }) } fn decode_user_key_hashes_row(row: Row) -> Result { @@ -86,14 +102,19 @@ fn decode_company_meta_row(row: Row) -> Result { let name: String = row.get("name"); let concurrent: i32 = row.get("task_limit_concurrent"); let daily: i32 = row.get("task_limit_daily"); - Ok((id, decode_task_limits(name, concurrent, daily)?)) + let worker_tags: Vec = row.get("worker_tags"); + Ok(( + id, + decode_company_meta_fields(name, concurrent, daily, worker_tags)?, + )) } -fn decode_company_limits_row(row: Row) -> Result { +fn decode_company_limits_row(row: Row) -> Result { let name: String = row.get("name"); let concurrent: i32 = row.get("task_limit_concurrent"); let daily: i32 = row.get("task_limit_daily"); - decode_task_limits(name, concurrent, daily) + let worker_tags: Vec = row.get("worker_tags"); + decode_company_meta_fields(name, concurrent, daily, worker_tags) } fn decode_user_meta_row(row: Row) -> Result { @@ -402,18 +423,18 @@ WHERE id = $1; "#; pub(super) const Q_FULL_COMPANIES_META: &'static str = r#" -SELECT id, name, task_limit_concurrent, task_limit_daily FROM companies; +SELECT id, name, task_limit_concurrent, task_limit_daily, worker_tags FROM companies; "#; pub(super) const Q_DELTA_COMPANIES_META: &'static str = r#" -SELECT id, name, task_limit_concurrent, task_limit_daily +SELECT id, name, task_limit_concurrent, task_limit_daily, worker_tags FROM companies WHERE (updated_at > $1 AND updated_at <= $2) OR (created_at > $1 AND created_at <= $2); "#; pub(super) const Q_COMPANY_META_BY_ID: &'static str = r#" -SELECT id, name, task_limit_concurrent, task_limit_daily +SELECT id, name, task_limit_concurrent, task_limit_daily, worker_tags FROM companies WHERE id = $1 LIMIT 1; @@ -566,7 +587,7 @@ created_at\ .collect()) } - pub async fn fetch_full_companies_meta(&self) -> Result> { + pub async fn fetch_full_companies_meta(&self) -> Result> { let rows = self.query_prepared(StmtKey::FullCompaniesMeta, &[]).await?; rows.into_iter() .map(decode_company_meta_row) @@ -577,7 +598,7 @@ created_at\ &self, since: i64, until: i64, - ) -> Result> { + ) -> Result> { let params: [&(dyn ToSql + Sync); 2] = [&since, &until]; let rows = self .query_prepared(StmtKey::DeltaCompaniesMeta, ¶ms) @@ -587,10 +608,7 @@ created_at\ .collect::>>() } - pub async fn fetch_company_meta( - &self, - company_id: uuid::Uuid, - ) -> Result> { + pub async fn fetch_company_meta(&self, company_id: uuid::Uuid) -> Result> { let params: [&(dyn ToSql + Sync); 1] = [&company_id]; let rows = self .query_prepared(StmtKey::CompanyMetaById, ¶ms) @@ -752,3 +770,37 @@ fn append_copy_field(buf: &mut Vec, value: Option<&str>) { } } } + +#[cfg(test)] +mod tests { + use super::decode_company_meta_fields; + + #[test] + fn decode_company_meta_fields_normalizes_and_dedupes_worker_tags() { + let meta = decode_company_meta_fields( + "Acme".to_string(), + 1, + 10, + vec![ + " Premium ".to_string(), + "premium".to_string(), + "ACME".to_string(), + ], + ) + .expect("company meta should decode"); + + assert_eq!(meta.worker_tags, vec!["premium", "acme"]); + } + + #[test] + fn decode_company_meta_fields_rejects_invalid_worker_tags() { + let error = + decode_company_meta_fields("Acme".to_string(), 1, 10, vec!["bad tag".to_string()]) + .expect_err("invalid company worker tag should be rejected"); + + assert!( + error.to_string().contains("Invalid company worker_tags"), + "unexpected error: {error}" + ); + } +} diff --git a/src/db/key_validator.rs b/src/db/key_validator.rs index a95cd26..2daf11b 100644 --- a/src/db/key_validator.rs +++ b/src/db/key_validator.rs @@ -13,7 +13,7 @@ use tracing::{error, info, warn}; use uuid::Uuid; use super::Database; -use super::data_access::ApiKeyBillingOwnerRow; +use super::data_access::{ApiKeyBillingOwnerRow, CompanyMeta}; use super::gateway_settings::{GatewayRuntimeSettingsConfig, GatewayRuntimeSettingsStore}; use super::task_lifecycle::GenerationBillingOwner; use crate::crypto::crypto_provider::ApiKeyHasher; @@ -38,6 +38,14 @@ struct CachedUserMeta { daily_limit: u64, } +#[derive(Debug, Clone)] +pub struct CachedCompanyMeta { + pub name: Arc, + pub concurrent_limit: u64, + pub daily_limit: u64, + pub worker_tags: Vec, +} + #[derive(Clone)] struct UnknownKeyIpGuard { misses: Cache, Arc, RandomState>, @@ -92,7 +100,7 @@ pub struct ApiKeyValidator { // mapping from user_id -> cached metadata used by the gateway auth/rate-limit path users_meta: scc::HashMap, // mapping from company_id -> cached metadata used by the gateway auth/rate-limit path - companies: scc::HashMap, u64, u64), RandomState>, + companies: scc::HashMap, // forward mapping: company_id -> list of api_key_hashes company_keys: scc::HashMap, RandomState>, // reverse mapping: api_key_hash -> key record for auth and billing @@ -118,7 +126,7 @@ pub struct ApiKeyLookup { pub user_email: Option>, pub user_limits: Option<(u64, u64)>, pub company_id: Option, - pub company_info: Option<(Arc, u64, u64)>, + pub company_info: Option, pub billing_owner: Option, pub auth_lookup_blocked: bool, } @@ -279,6 +287,15 @@ impl ApiKeyValidator { Arc::clone(&self.gateway_settings) } + fn cached_company_meta(meta: CompanyMeta) -> CachedCompanyMeta { + CachedCompanyMeta { + name: Arc::::from(meta.name), + concurrent_limit: meta.concurrent_limit, + daily_limit: meta.daily_limit, + worker_tags: meta.worker_tags, + } + } + fn convert_hash(bytes: &[u8]) -> Option<[u8; 32]> { (bytes.len() == 32).then(|| { let mut array = [0u8; 32]; @@ -510,9 +527,8 @@ impl ApiKeyValidator { if let Some(since) = Self::delta_since(last_sync_ms, now, "companies_meta") { let deltas = self.db.fetch_delta_companies_meta(since, now).await?; - for (cid, limits) in deltas { - let (name, concurrent, daily) = limits; - let value = (Arc::::from(name), concurrent, daily); + for (cid, meta) in deltas { + let value = Self::cached_company_meta(meta); match self.companies.entry_async(cid).await { scc::hash_map::Entry::Occupied(mut entry) => { *entry.get_mut() = value; @@ -525,11 +541,10 @@ impl ApiKeyValidator { } else { let all = self.db.fetch_full_companies_meta().await?; self.companies.clear_async().await; - for (cid, limits) in all { - let (name, concurrent, daily) = limits; + for (cid, meta) in all { let _ = self .companies - .insert_async(cid, (Arc::::from(name), concurrent, daily)) + .insert_async(cid, Self::cached_company_meta(meta)) .await; } } @@ -625,8 +640,8 @@ impl ApiKeyValidator { Some(entry.get().clone()) } else { match self.db.fetch_company_meta(cid).await { - Ok(Some((name, concurrent, daily))) => { - let company_info = (Arc::::from(name), concurrent, daily); + Ok(Some(meta)) => { + let company_info = Self::cached_company_meta(meta); let _ = self.companies.insert_async(cid, company_info.clone()).await; Some(company_info) } @@ -796,17 +811,40 @@ impl ApiKeyValidator { company_name: &str, concurrent_limit: u64, daily_limit: u64, + ) { + self.seed_company_key_with_worker_tags( + api_key, + company_id, + company_name, + concurrent_limit, + daily_limit, + Vec::new(), + ) + .await; + } + + #[cfg(any(test, feature = "test-support"))] + #[allow(dead_code)] + pub async fn seed_company_key_with_worker_tags( + &self, + api_key: &str, + company_id: Uuid, + company_name: &str, + concurrent_limit: u64, + daily_limit: u64, + worker_tags: Vec, ) { let key_hash = self.hasher.compute_hash_array(api_key); let _ = self .companies .insert_async( company_id, - ( - Arc::::from(company_name), + CachedCompanyMeta { + name: Arc::::from(company_name), concurrent_limit, daily_limit, - ), + worker_tags, + }, ) .await; let _ = self diff --git a/src/http3/handlers/result/add_result.rs b/src/http3/handlers/result/add_result.rs index 05d17d0..9d07dc9 100644 --- a/src/http3/handlers/result/add_result.rs +++ b/src/http3/handlers/result/add_result.rs @@ -226,18 +226,17 @@ pub async fn add_result_handler( .inc_task_completed(outcome.worker_hotkey.as_ref()) .await; metrics.inc_task_completed_kind(task_kind); + if let Some(elapsed) = elapsed_secs { + metrics + .record_completion_time(outcome.worker_hotkey.as_ref(), elapsed) + .await; + } } else { metrics .inc_task_failed(outcome.worker_hotkey.as_ref()) .await; } - if let Some(elapsed) = elapsed_secs { - metrics - .record_completion_time(outcome.worker_hotkey.as_ref(), elapsed) - .await; - } - if outcome.completed { manager.finalize_task(task_id).await; } diff --git a/src/http3/handlers/result/read.rs b/src/http3/handlers/result/read.rs index 3237d39..0d63031 100644 --- a/src/http3/handlers/result/read.rs +++ b/src/http3/handlers/result/read.rs @@ -193,7 +193,6 @@ pub async fn get_result_handler( let state = depot.require::()?.clone(); let cfg = state.config(); let http_cfg = cfg.http(); - let metrics = state.metrics().clone(); let record_origin = normalize_origin(req, http_cfg); let gateway_state = state.gateway_state().clone(); let task_manager = gateway_state.task_manager(); @@ -293,13 +292,6 @@ pub async fn get_result_handler( let successful_results = results_vec; if get_task.all { - let best_worker = successful_results - .first() - .ok_or_else(|| ServerError::Internal("No TaskResult after filtering".into()))? - .worker_hotkey - .clone(); - metrics.inc_best_task(&best_worker).await; - let content_disposition = "attachment; filename=\"results.zip\""; set_download_headers(res, "application/zip", content_disposition)?; @@ -366,14 +358,11 @@ pub async fn get_result_handler( .next() .ok_or_else(|| ServerError::Internal("Failed to select best TaskResult".into()))?; - let best_worker = best.worker_hotkey.clone(); let asset = best .asset .ok_or_else(|| ServerError::Internal("Missing asset on best TaskResult".into()))?; let data = process_asset(asset, decompress_spz).await?; - metrics.inc_best_task(&best_worker).await; - let content_disposition = format!("attachment; filename=\"result.{}\"", extension); set_download_headers(res, content_type, &content_disposition)?; set_content_length_header(res, data.len())?; diff --git a/src/http3/handlers/task/add_task.rs b/src/http3/handlers/task/add_task.rs index fc62026..a57ed59 100644 --- a/src/http3/handlers/task/add_task.rs +++ b/src/http3/handlers/task/add_task.rs @@ -6,6 +6,7 @@ use tracing::warn; use uuid::Uuid; use crate::api::Task; +use crate::common::queue::TaskRouting; use crate::db::{CreateGenerationTaskInput, CreateGenerationTaskOutcome}; use crate::http3::depot_ext::DepotExt; use crate::http3::error::ServerError; @@ -111,6 +112,19 @@ pub async fn add_task_handler( } else { None }; + let required_worker_tags = if billing_owner + .as_ref() + .and_then(|owner| owner.company_id) + .is_some() + { + rate_ctx + .company + .as_ref() + .map(|company| company.worker_tags.clone()) + .unwrap_or_default() + } else { + Vec::new() + }; let billing_request_json = json!({ "seed": seed, "model": &model_name, @@ -265,8 +279,10 @@ pub async fn add_task_handler( .task_manager() .add_task_with_rate_limit_reservation(&task, reservation.clone()) .await; - queue_slot.push(task); - metrics.set_queue_len(queue.len()); + queue_slot.push_with_routing( + task, + TaskRouting::with_required_worker_tags(required_worker_tags), + ); if let Some(task_log) = task_log.as_ref() { task_log.queued(); diff --git a/src/http3/handlers/task/get_tasks.rs b/src/http3/handlers/task/get_tasks.rs index 7de79ed..c08effc 100644 --- a/src/http3/handlers/task/get_tasks.rs +++ b/src/http3/handlers/task/get_tasks.rs @@ -5,8 +5,9 @@ use serde_json::json; use tracing::{debug, error, info, warn}; use uuid::Uuid; -use crate::api::request::GetTasksRequest; +use crate::api::request::{GetTasksRequest, normalize_worker_tags}; use crate::api::response::{AssignedTask, GetTasksResponse}; +use crate::common::queue::WorkerRouting; use crate::db::{RecordedGenerationTaskAssignment, RecordedGenerationTaskAssignmentAction}; use crate::http3::depot_ext::DepotExt; use crate::http3::error::ServerError; @@ -69,6 +70,13 @@ pub async fn get_tasks_handler( models }; + let worker_tags = normalize_worker_tags(&get_tasks.worker_tags).map_err(|message| { + ServerError::BadRequestJson(serde_json::json!({ + "error": "invalid_field", + "field": "worker_tags", + "message": message, + })) + })?; info!( worker_hotkey = %get_tasks.worker_hotkey, @@ -76,6 +84,7 @@ pub async fn get_tasks_handler( requested_count = get_tasks.requested_task_count, model_filter_count = model_filter.len(), models = ?model_filter, + worker_tag_count = worker_tags.len(), "Worker requested tasks" ); @@ -85,20 +94,20 @@ pub async fn get_tasks_handler( .map_err(|e| ServerError::Internal(format!("Failed to obtain gateways: {:?}", e)))?; let gateway_count = gateways.len(); + let max_task_queue_len = state.gateway_state().max_task_queue_len(); let requested_task_count = get_tasks .requested_task_count - .min(state.gateway_state().max_task_queue_len().max(1)); + .min(max_task_queue_len.max(1)); + queue.set_reservation_scan_cap(max_task_queue_len); let mut task_ids = Vec::with_capacity(requested_task_count); let task_manager = gateway_state.task_manager(); - let mut deliveries = queue.reserve_for_models( + let mut deliveries = queue.reserve_for_models_with_routing( requested_task_count, &get_tasks.worker_hotkey, model_filter.as_slice(), + WorkerRouting::from_tags(worker_tags.as_slice()), ); for delivery in &deliveries { - if let Some(dur) = delivery.duration() { - metrics.record_queue_time(dur.as_secs_f64()); - } task_ids.push(delivery.task().id); } let reserved_count = task_ids.len(); @@ -177,9 +186,12 @@ pub async fn get_tasks_handler( let assignment_token = assignment .assignment_token .expect("assigned task is missing assignment token"); - let task = delivery.commit().0; + let (task, queue_duration) = delivery.commit(); + if let Some(duration) = queue_duration { + metrics.record_queue_time(duration.as_secs_f64()); + } tasks.push(AssignedTask { - task: task.clone(), + task, assignment_token, }); task_manager @@ -216,8 +228,6 @@ pub async fn get_tasks_handler( }); } - metrics.set_queue_len(queue.len()); - gateway_state.update_task_acquisition().map_err(|e| { ServerError::Internal(format!( "Failed to execute update_task_acquisition: {:?}", diff --git a/src/http3/rate_limits.rs b/src/http3/rate_limits.rs index 858eeef..5eba589 100644 --- a/src/http3/rate_limits.rs +++ b/src/http3/rate_limits.rs @@ -46,6 +46,7 @@ pub struct CompanyRateLimit { pub name: Arc, pub concurrent_limit: u64, pub daily_limit: u64, + pub worker_tags: Vec, } #[derive(Clone)] @@ -407,14 +408,13 @@ async fn populate_api_key_context( context.user_limits = lookup.user_limits; context.billing_owner = lookup.billing_owner; context.auth_lookup_blocked = lookup.auth_lookup_blocked; - if let (Some(cid), Some((name, concurrent, daily))) = - (lookup.company_id, lookup.company_info) - { + if let (Some(cid), Some(company_info)) = (lookup.company_id, lookup.company_info) { context.company = Some(CompanyRateLimit { id: cid, - name, - concurrent_limit: concurrent, - daily_limit: daily, + name: company_info.name, + concurrent_limit: company_info.concurrent_limit, + daily_limit: company_info.daily_limit, + worker_tags: company_info.worker_tags, }); } } diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 39e6aed..021b204 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -18,7 +18,6 @@ pub struct MetricsEntry { pub failed_tasks: Counter, pub timeout_failed_tasks: Counter, pub tasks_received: Counter, - pub best_results_total: Gauge, pub tasks_in_progress: IntGauge, last_touched_ms: AtomicU64, } @@ -49,7 +48,7 @@ struct MetricsInner { registry: Registry, - queue_len: Gauge, + queue_len: IntGauge, queue_time_avg: Gauge, queue_time_max: Gauge, @@ -63,7 +62,6 @@ struct MetricsInner { failed_tasks: CounterVec, timeout_failed_tasks: CounterVec, tasks_received: CounterVec, - best_results_total: GaugeVec, tasks_in_progress: IntGaugeVec, map: HashMap, RandomState>, @@ -74,7 +72,7 @@ impl Metrics { pub fn new(alpha: f64) -> Result { let registry = Registry::new(); - let queue_len = Gauge::with_opts(opts!( + let queue_len = IntGauge::with_opts(opts!( "queue_len", "Number of tasks currently waiting in the queue" ))?; @@ -147,13 +145,6 @@ impl Metrics { ), &["worker"], )?; - let best_completed_tasks = GaugeVec::new( - Opts::new( - "best_completed_tasks", - "Per-worker count of wins where this worker's result was selected as best", - ), - &["worker"], - )?; let tasks_in_progress = IntGaugeVec::new( Opts::new( "tasks_in_progress", @@ -170,7 +161,6 @@ impl Metrics { registry.register(Box::new(failed_tasks.clone()))?; registry.register(Box::new(timeout_failed_tasks.clone()))?; registry.register(Box::new(tasks_received.clone()))?; - registry.register(Box::new(best_completed_tasks.clone()))?; registry.register(Box::new(tasks_in_progress.clone()))?; let inner = MetricsInner { @@ -186,7 +176,6 @@ impl Metrics { failed_tasks, timeout_failed_tasks, tasks_received, - best_results_total: best_completed_tasks, tasks_in_progress, requests_by_origin, map: HashMap::with_capacity_and_hasher(16, RandomState::default()), @@ -209,8 +198,8 @@ impl Metrics { &self.inner.registry } - pub fn set_queue_len(&self, len: usize) { - self.inner.queue_len.set(len as f64); + pub fn queue_len_gauge(&self) -> IntGauge { + self.inner.queue_len.clone() } pub fn inc_task_completed_kind(&self, kind: TaskKind) { @@ -233,7 +222,6 @@ impl Metrics { failed_tasks: self.inner.failed_tasks.with_label_values(&[key]), timeout_failed_tasks: self.inner.timeout_failed_tasks.with_label_values(&[key]), tasks_received: self.inner.tasks_received.with_label_values(&[key]), - best_results_total: self.inner.best_results_total.with_label_values(&[key]), tasks_in_progress: self.inner.tasks_in_progress.with_label_values(&[key]), last_touched_ms: AtomicU64::new(now_ms), }); @@ -274,7 +262,6 @@ impl Metrics { let failed_tasks = self.inner.failed_tasks.clone(); let timeout_failed_tasks = self.inner.timeout_failed_tasks.clone(); let tasks_received = self.inner.tasks_received.clone(); - let best_results_total = self.inner.best_results_total.clone(); let tasks_in_progress = self.inner.tasks_in_progress.clone(); self.inner @@ -293,7 +280,6 @@ impl Metrics { let _ = failed_tasks.remove_label_values(&labels); let _ = timeout_failed_tasks.remove_label_values(&labels); let _ = tasks_received.remove_label_values(&labels); - let _ = best_results_total.remove_label_values(&labels); let _ = tasks_in_progress.remove_label_values(&labels); false }) @@ -339,10 +325,6 @@ impl Metrics { entry.timeout_failed_tasks.inc(); } - pub async fn inc_best_task(&self, key: &str) { - self.get_entry(key).await.best_results_total.inc(); - } - pub async fn start_task(&self, key: &str) -> TaskInProgressGuard { let entry = self.get_entry(key).await; TaskInProgressGuard::new(entry.tasks_in_progress.clone()) diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 36af42d..96d4704 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -71,7 +71,6 @@ impl_try_from!(VoteResponse, VoteResponse); pub struct Protocol { _connection: Connection, max_message_size: usize, - max_recv_buffer_size: usize, receive_message_timeout_ms: Duration, } @@ -79,13 +78,11 @@ impl Protocol { pub fn new( conn: Connection, max_message_size: usize, - max_recv_buffer_size: usize, receive_message_timeout_ms: Duration, ) -> Self { Self { _connection: conn, max_message_size, - max_recv_buffer_size, receive_message_timeout_ms, } } @@ -107,22 +104,21 @@ impl Protocol { } pub async fn receive_message(&self, mut recv_stream: RecvStream) -> Result { - let mut message_data = Vec::with_capacity(self.max_message_size); - let mut buffer = vec![0; self.max_recv_buffer_size]; - - let data = tokio::time::timeout(self.receive_message_timeout_ms, async { - while let Some(n) = recv_stream.read(&mut buffer).await? { - message_data.extend_from_slice(&buffer[..n]); - if message_data.len() > self.max_message_size { - return Err(anyhow!( - "Message exceeded maximum size of {} bytes", - self.max_message_size - )); - } - } - Ok(message_data) - }) - .await??; + // `read_to_end` enforces the size cap while reading directly into a + // single right-sized buffer, avoiding a per-message pre-allocation of + // `max_message_size` plus a separate zero-initialized scratch buffer. + let data = tokio::time::timeout( + self.receive_message_timeout_ms, + recv_stream.read_to_end(self.max_message_size), + ) + .await? + .map_err(|e| { + anyhow!( + "Failed to read message (max {} bytes): {:?}", + self.max_message_size, + e + ) + })?; if data.is_empty() { return Err(anyhow!("Received empty message")); diff --git a/src/raft/client/mod.rs b/src/raft/client/mod.rs index 7fe3418..5f8e901 100644 --- a/src/raft/client/mod.rs +++ b/src/raft/client/mod.rs @@ -319,7 +319,6 @@ impl RClient { let proto = Protocol::new( conn.clone(), self.protocol_cfg.max_message_size, - self.protocol_cfg.max_recv_buffer_size, Duration::from_millis(self.protocol_cfg.receive_message_timeout_ms), ); let msg: RaftMessageType = data.into(); diff --git a/src/raft/mod.rs b/src/raft/mod.rs index 6a5d005..0f53d1e 100644 --- a/src/raft/mod.rs +++ b/src/raft/mod.rs @@ -617,13 +617,15 @@ pub async fn get_node_ids( Ok(ids) } -fn build_task_queue(cfg: &NodeConfig) -> TaskQueue { +fn build_task_queue(cfg: &NodeConfig, metrics: &Metrics) -> TaskQueue { TaskQueue::builder() .dup(cfg.basic.unique_workers_per_task) .ttl(cfg.basic.taskqueue_task_ttl) .cleanup_interval(cfg.basic.taskqueue_cleanup_interval) .default_model(cfg.model_config.default_model.clone()) .models(cfg.model_config.models.keys().cloned()) + .reservation_scan_cap(cfg.http.max_task_queue_len) + .queue_len_gauge(metrics.queue_len_gauge()) .build() } @@ -832,7 +834,8 @@ pub async fn start_gateway( RandomState::default(), )); - let task_queue = build_task_queue(cfg); + let metrics = Metrics::new(0.05).map_err(|e| anyhow::anyhow!(e))?; + let task_queue = build_task_queue(cfg, &metrics); let raft_config = build_raft_config(cfg)?; let gateway_shutdown = shutdown.child_token(); @@ -941,8 +944,6 @@ pub async fn start_gateway( Arc::clone(&api_key_validator), gateway_shutdown.clone(), )); - let metrics = Metrics::new(0.05).map_err(|e| anyhow::anyhow!(e))?; - let rate_limit_queue = RateLimitMutationBuffer::default(); let task_manager = TaskManager::new_with_rate_limit_mutation_queue(TaskManagerInit { initial_capacity: cfg.basic.taskmanager_initial_capacity, diff --git a/src/raft/server/mod.rs b/src/raft/server/mod.rs index 0b0a775..4ef645e 100644 --- a/src/raft/server/mod.rs +++ b/src/raft/server/mod.rs @@ -124,7 +124,6 @@ impl RServer { let protocol = Protocol::new( connection.clone(), rserver_cfg.max_message_size, - rserver_cfg.max_recv_buffer_size, Duration::from_millis(rserver_cfg.receive_message_timeout_ms), ); if let Err(e) = Self::handle_request(protocol, send, recv, raft.clone()).await { diff --git a/src/raft/tests/cross_gateway_add_task.rs b/src/raft/tests/cross_gateway_add_task.rs index 2dc4957..c63a37b 100644 --- a/src/raft/tests/cross_gateway_add_task.rs +++ b/src/raft/tests/cross_gateway_add_task.rs @@ -41,6 +41,8 @@ const LOCALHOST: &str = "127.0.0.1"; const TEST_DOMAIN: &str = "localhost"; const PROMPT_JSON: &str = r#"{"prompt":"mechanic robot"}"#; const GENERIC_WINDOW_MS: u64 = 86_400_000; +const TEST_HTTP_IDLE_TIMEOUT_SEC: u64 = 10; +const TEST_HTTP_KEEP_ALIVE_SEC: u64 = 2; #[derive(Default)] struct BlockingCreateStore { @@ -192,6 +194,7 @@ fn build_task_queue(config: &NodeConfig) -> TaskQueue { .cleanup_interval(config.basic.taskqueue_cleanup_interval) .default_model(config.model_config.default_model.clone()) .models(config.model_config.models.keys().cloned()) + .reservation_scan_cap(config.http.max_task_queue_len) .build() } @@ -213,6 +216,8 @@ async fn build_http_client(http_port: u16) -> Result { match Http3ClientBuilder::new() .server_domain(TEST_DOMAIN) .server_ip(server_ip.clone()) + .max_idle_timeout_sec(TEST_HTTP_IDLE_TIMEOUT_SEC) + .keep_alive_interval(TEST_HTTP_KEEP_ALIVE_SEC) .dangerous_skip_verification(true) .build() .await @@ -236,6 +241,7 @@ async fn build_gateway_node( node_id: u64, raft_addr: &str, http_addr: &str, + http_reservation: std::net::UdpSocket, raft: Raft, state: Arc, blocking_store: Arc, @@ -257,6 +263,8 @@ async fn build_gateway_node( config.cert.dangerous_skip_verification = true; config.cert.cert_file_path.clear(); config.cert.key_file_path.clear(); + config.http.max_idle_timeout_sec = TEST_HTTP_IDLE_TIMEOUT_SEC; + config.http.keep_alive_interval_sec = TEST_HTTP_KEEP_ALIVE_SEC; let config_file = tempfile::Builder::new() .prefix(&format!("gateway-http-node-{node_id}-")) @@ -321,6 +329,7 @@ async fn build_gateway_node( event_recorder, }); + drop(http_reservation); let http_server = Http3Server::run( runtime_config.clone(), Some(RustlsConfig::new(generate_and_create_keycert(vec![ @@ -387,21 +396,27 @@ async fn setup_cross_gateway_harness_inner( wait_for_consistent_leader_index(&raft_nodes, &node_configs, Duration::from_secs(10)) .await?; + // Keep the reservation sockets alive; each is handed to `build_gateway_node` + // and only dropped immediately before that node binds its server, which + // closes the reserve->bind race window per node. let (http_addrs, http_reservations) = reserve_udp_addresses(node_configs.len())?; - drop(http_reservations); let blocking_store = Arc::new(BlockingCreateStore::default()); let mut nodes = Vec::with_capacity(node_configs.len()); - for ((node_id, raft_addr), (state, http_addr)) in node_configs - .iter() - .zip(state_machines.iter().cloned().zip(http_addrs.iter())) - { + for ((node_id, raft_addr), ((state, http_addr), http_reservation)) in node_configs.iter().zip( + state_machines + .iter() + .cloned() + .zip(http_addrs.iter()) + .zip(http_reservations), + ) { nodes.push( build_gateway_node( &base_config, *node_id, raft_addr, http_addr.as_str(), + http_reservation, raft_nodes[node_index(&node_configs, *node_id)].clone(), state, blocking_store.clone(), @@ -646,47 +661,47 @@ async fn repeated_random_admin_key_writes_are_rejected_before_body_parse() -> Re let cfg = leader._runtime_config.snapshot(); let random_admin_key = random_non_admin_key(cfg.http().admin_key).to_string(); - let (first_status, first_body) = post_gateway_write( - &leader.client, - leader.http_port, + let invalid_bodies = [ Bytes::from_static(b"not-msgpack"), - random_admin_key.as_str(), - ) - .await?; - assert_eq!( - first_status, - StatusCode::UNAUTHORIZED, - "bad-key write should fail on admin key before body parsing; body: {}", - String::from_utf8_lossy(first_body.as_ref()) - ); - - let (second_status, second_body) = post_gateway_write( - &leader.client, - leader.http_port, Bytes::from_static(b"still-not-msgpack"), - random_admin_key.as_str(), - ) - .await?; - assert_eq!( - second_status, - StatusCode::TOO_MANY_REQUESTS, - "body: {}", - String::from_utf8_lossy(second_body.as_ref()) - ); - let payload: serde_json::Value = serde_json::from_slice(second_body.as_ref())?; - assert_eq!( - payload.get("error").and_then(|value| value.as_str()), - Some("invalid_admin_key_rate_limit") - ); - - let (third_status, _third_body) = post_gateway_write( - &leader.client, - leader.http_port, Bytes::from_static(b"still-not-msgpack-again"), - random_admin_key.as_str(), - ) - .await?; - assert_eq!(third_status, StatusCode::TOO_MANY_REQUESTS); + ]; + let mut saw_rate_limit = false; + for body in invalid_bodies { + let (status, response_body) = post_gateway_write( + &leader.client, + leader.http_port, + body, + random_admin_key.as_str(), + ) + .await?; + + match status { + StatusCode::UNAUTHORIZED => { + assert!( + !saw_rate_limit, + "bad-key write became unblocked after rate limit; body: {}", + String::from_utf8_lossy(response_body.as_ref()) + ); + } + StatusCode::TOO_MANY_REQUESTS => { + saw_rate_limit = true; + let payload: serde_json::Value = serde_json::from_slice(response_body.as_ref())?; + assert_eq!( + payload.get("error").and_then(|value| value.as_str()), + Some("invalid_admin_key_rate_limit") + ); + } + other => panic!( + "bad-key write should fail on admin key before body parsing; status: {other}, body: {}", + String::from_utf8_lossy(response_body.as_ref()) + ), + } + } + assert!( + saw_rate_limit, + "repeated bad-key writes should eventually hit invalid admin key rate limit" + ); Ok(()) } diff --git a/src/task/mod.rs b/src/task/mod.rs index 844a2da..46927cb 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -24,7 +24,10 @@ use crate::http3::rate_limits::RateLimitReservation; use crate::metrics::{Metrics, TaskInProgressGuard, TaskKind}; const TASK_TIMED_OUT_REASON: &str = "Task timed out"; +#[cfg(not(test))] const PENDING_RESULT_COMMIT_GRACE: Duration = Duration::from_secs(5); +#[cfg(test)] +const PENDING_RESULT_COMMIT_GRACE: Duration = Duration::from_millis(150); const RATE_LIMIT_COMPLETION_REQUEST_ID_XOR: u128 = 0x7b2f_4d91_a6c3_e805_19fe_42ac_55d0_3b71; pub(crate) fn rate_limit_completion_request_id(task_id: Uuid) -> u128 { @@ -82,6 +85,7 @@ struct TaskState { assigned_assignment_tokens: FoldHashMap, in_progress: FoldHashMap, pending_results: FoldHashMap, + pending_results_grace_until: Option, seed: Option, finished_results_count: usize, rate_limit_reservation: Option, @@ -191,6 +195,7 @@ impl TaskState { assigned_assignment_tokens: FoldHashMap::default(), in_progress: FoldHashMap::default(), pending_results: FoldHashMap::default(), + pending_results_grace_until: None, seed: Some(task.seed), finished_results_count: 0, rate_limit_reservation: None, @@ -480,16 +485,23 @@ impl TaskManager { continue; } if !state.pending_results.is_empty() { - let deferred_expires_at = - now + PENDING_RESULT_COMMIT_GRACE; - state.task_expires_at = - Some(deferred_expires_at); - heap.push(Reverse(( - deferred_expires_at, - ExpirationKind::Task as u8, - task_id.as_u128(), - ))); - continue; + let grace_until = state + .pending_results_grace_until + .get_or_insert_with(|| { + now + PENDING_RESULT_COMMIT_GRACE + }); + if *grace_until > now { + state.task_expires_at = + Some(*grace_until); + heap.push(Reverse(( + *grace_until, + ExpirationKind::Task as u8, + task_id.as_u128(), + ))); + continue; + } + state.pending_results.clear(); + state.pending_results_grace_until = None; } let task_kind = state.task_kind.label(); let assigned_workers = std::mem::take(&mut state.assigned_workers); @@ -685,6 +697,7 @@ impl TaskManager { return Err(AddResultError::AlreadyStaged); } state.pending_results.insert(worker, result); + state.pending_results_grace_until = None; Ok(()) } Entry::Vacant(_) => Err(AddResultError::NotAssigned), @@ -716,6 +729,9 @@ impl TaskManager { let Some(result) = state.pending_results.remove(worker) else { return Err(AddResultError::PendingResultMissing); }; + if state.pending_results.is_empty() { + state.pending_results_grace_until = None; + } let (outcome, results_expires_at, reservation) = Self::apply_committed_result( state, result, @@ -737,7 +753,11 @@ impl TaskManager { pub async fn rollback_staged_result(&self, task_id: Uuid, worker: &Hotkey) { if let Entry::Occupied(mut entry) = self.inner.tasks.entry_async(task_id).await { - entry.get_mut().pending_results.remove(worker); + let state = entry.get_mut(); + state.pending_results.remove(worker); + if state.pending_results.is_empty() { + state.pending_results_grace_until = None; + } } } @@ -899,6 +919,9 @@ impl TaskManager { .is_some_and(|stored_token| *stored_token != assignment_token) { state.pending_results.remove(&worker); + if state.pending_results.is_empty() { + state.pending_results_grace_until = None; + } } state .assigned_assignment_tokens diff --git a/src/task/tests.rs b/src/task/tests.rs index 66dc1ac..2d0b227 100644 --- a/src/task/tests.rs +++ b/src/task/tests.rs @@ -365,8 +365,8 @@ async fn stage_result_rejects_double_stage_for_same_worker() { #[tokio::test] async fn staged_result_defers_timeout_until_commit() { - const CLEANUP_INTERVAL: Duration = Duration::from_millis(40); - const TASK_AND_RESULT_LIFETIME: Duration = Duration::from_millis(120); + const CLEANUP_INTERVAL: Duration = Duration::from_millis(20); + const TASK_AND_RESULT_LIFETIME: Duration = Duration::from_millis(40); let metrics = Metrics::new(0.05).unwrap(); let mutation_queue = RateLimitMutationBuffer::default(); @@ -376,7 +376,7 @@ async fn staged_result_defers_timeout_until_commit() { expected_results: 1, cleanup_interval: CLEANUP_INTERVAL, task_lifetime: TASK_AND_RESULT_LIFETIME, - result_lifetime: TASK_AND_RESULT_LIFETIME, + result_lifetime: Duration::from_secs(1), rate_limit_mutation_queue: mutation_queue.clone(), metrics, worker_event_recorder: None, @@ -412,10 +412,7 @@ async fn staged_result_defers_timeout_until_commit() { .await .unwrap(); - tokio::time::sleep(Duration::from_millis( - TASK_AND_RESULT_LIFETIME.as_millis() as u64 + 200, - )) - .await; + tokio::time::sleep(Duration::from_millis(80)).await; assert!( task_manager.is_assigned(task_id, &worker).await, @@ -445,6 +442,77 @@ async fn staged_result_defers_timeout_until_commit() { ); } +#[tokio::test] +async fn stale_staged_result_is_dropped_after_bounded_grace() { + const CLEANUP_INTERVAL: Duration = Duration::from_millis(20); + const TASK_AND_RESULT_LIFETIME: Duration = Duration::from_millis(40); + + let metrics = Metrics::new(0.05).unwrap(); + let mutation_queue = RateLimitMutationBuffer::default(); + let limiter = DistributedRateLimiter::new(64); + let task_manager = TaskManager::new_with_rate_limit_mutation_queue(TaskManagerInit { + initial_capacity: 4, + expected_results: 1, + cleanup_interval: CLEANUP_INTERVAL, + task_lifetime: TASK_AND_RESULT_LIFETIME, + result_lifetime: Duration::from_secs(1), + rate_limit_mutation_queue: mutation_queue.clone(), + metrics, + worker_event_recorder: None, + }) + .await; + let task_id = Uuid::new_v4(); + task_manager + .add_task_with_rate_limit_reservation( + &sample_task(task_id), + Some(RateLimitReservation::new( + limiter, + vec![RateLimitDelta { + subject: Subject::User, + id: 4_243u128, + day_epoch: 99, + add_active: 1, + add_day: 0, + }], + )), + ) + .await; + + let worker: Hotkey = Hotkey::from_bytes(&[62u8; 32]); + let worker_str = worker.to_string(); + task_manager + .record_assignment(task_id, worker.clone(), worker_str.clone().into()) + .await; + task_manager + .stage_result( + task_id, + make_result(worker_str.as_str(), worker_str.as_str(), Instant::now()), + ) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(260)).await; + + assert!( + !task_manager.is_assigned(task_id, &worker).await, + "stale staged result should not keep the task assigned forever" + ); + assert!(matches!( + task_manager.commit_staged_result(task_id, &worker).await, + Err(AddResultError::PendingResultMissing) + )); + assert_eq!( + task_manager.get_status(task_id).await, + TaskStatus::Failure { + reason: "Task timed out".into(), + } + ); + assert!( + mutation_queue.drain_batch(1).await.is_some(), + "timeout cleanup should still emit completion side effects after dropping stale staged data" + ); +} + #[tokio::test] async fn mixed_outcome_reports_success_when_expected_results_are_exhausted() { let task_manager = TaskManager::new( diff --git a/src/test_support.rs b/src/test_support.rs index 58508e2..b680c21 100644 --- a/src/test_support.rs +++ b/src/test_support.rs @@ -104,6 +104,8 @@ pub async fn build_shared_harness_core( .cleanup_interval(config.basic.taskqueue_cleanup_interval) .default_model(config.model_config.default_model.clone()) .models(config.model_config.models.keys().cloned()) + .reservation_scan_cap(config.http.max_task_queue_len) + .queue_len_gauge(metrics.queue_len_gauge()) .build(); let db = Arc::new(Database::new_mock()); diff --git a/tests/client_http_api/add_task.rs b/tests/client_http_api/add_task.rs index 2f8a34e..f7b94b6 100644 --- a/tests/client_http_api/add_task.rs +++ b/tests/client_http_api/add_task.rs @@ -843,8 +843,8 @@ async fn add_task_company_key_limit_updates_take_effect_without_manual_updated_a let lookup = lookup_api_key(&h, &h.company_api_key).await; assert_eq!(lookup.company_id, Some(h.company_id)); let company = lookup.company_info.expect("company lookup info"); - assert_eq!(company.1, 100); - assert_eq!(company.2, 1); + assert_eq!(company.concurrent_limit, 100); + assert_eq!(company.daily_limit, 1); } #[tokio::test] diff --git a/tests/client_http_api/get_tasks.rs b/tests/client_http_api/get_tasks.rs index b082009..aeec1a6 100644 --- a/tests/client_http_api/get_tasks.rs +++ b/tests/client_http_api/get_tasks.rs @@ -3,13 +3,57 @@ use gateway::crypto::hotkey::Hotkey; use http::StatusCode; use crate::support::{ - TestClient, add_task_prompt, add_task_prompt_with_api_key, build_harness, + TestClient, TestHarness, add_task_prompt, add_task_prompt_with_api_key, build_harness, build_harness_with_worker_whitelist, create_personal_api_key, multipart_body, purge_terminal_generation_task_in_db, read_response, sign_worker, timeout_generation_task_in_db, tiny_png_bytes, top_up_personal_api_key_balance, - update_personal_api_key_limits_without_timestamp, + update_company_worker_tags, update_personal_api_key_limits_without_timestamp, }; +async fn signed_get_tasks_json( + h: &TestHarness, + seed: u8, + worker_id: &str, + requested_task_count: usize, + worker_tags: Option, +) -> (StatusCode, serde_json::Value) { + let (worker_hotkey, timestamp, signature) = sign_worker([seed; 32]); + let mut payload = serde_json::json!({ + "worker_hotkey": worker_hotkey.to_string(), + "worker_id": worker_id, + "signature": signature, + "timestamp": timestamp, + "requested_task_count": requested_task_count, + "model": "404-3dgs" + }); + if let Some(worker_tags) = worker_tags { + payload["worker_tags"] = worker_tags; + } + + let res = TestClient::post("http://localhost/get_tasks") + .json(&payload) + .send(&h.service) + .await; + let (status, _headers, body) = read_response(res).await; + let payload: serde_json::Value = serde_json::from_slice(&body).expect("json response"); + (status, payload) +} + +fn task_ids_from_get_tasks_payload(payload: &serde_json::Value) -> Vec { + payload + .get("tasks") + .and_then(|value| value.as_array()) + .expect("tasks") + .iter() + .map(|task| { + task.get("id") + .and_then(|value| value.as_str()) + .and_then(|value| uuid::Uuid::parse_str(value).ok()) + .expect("task id") + }) + .collect() +} + #[tokio::test] async fn get_tasks_invalid_signature_returns_unauthorized() { let h = build_harness().await; @@ -48,6 +92,44 @@ async fn get_tasks_invalid_model_returns_bad_request() { assert_eq!(status, StatusCode::BAD_REQUEST); } +#[tokio::test] +async fn get_tasks_rejects_invalid_worker_tags() { + let h = build_harness().await; + let invalid_worker_tags = [ + serde_json::json!([""]), + serde_json::json!(["bad tag"]), + serde_json::json!([format!("{}x", "a".repeat(32))]), + ]; + + for (idx, worker_tags) in invalid_worker_tags.into_iter().enumerate() { + let (worker_hotkey, timestamp, signature) = sign_worker([40u8 + idx as u8; 32]); + let res = TestClient::post("http://localhost/get_tasks") + .json(&serde_json::json!({ + "worker_hotkey": worker_hotkey.to_string(), + "worker_id": format!("worker-invalid-tags-{idx}"), + "signature": signature, + "timestamp": timestamp, + "requested_task_count": 1, + "model": "404-3dgs", + "worker_tags": worker_tags + })) + .send(&h.service) + .await; + + let (status, _headers, body) = read_response(res).await; + assert_eq!(status, StatusCode::BAD_REQUEST); + let payload: serde_json::Value = serde_json::from_slice(&body).expect("json response"); + assert_eq!( + payload.get("error").and_then(|value| value.as_str()), + Some("invalid_field") + ); + assert_eq!( + payload.get("field").and_then(|value| value.as_str()), + Some("worker_tags") + ); + } +} + #[tokio::test] async fn get_tasks_whitelist_rejects_worker() { let mut whitelist = HashSet::new(); @@ -329,6 +411,269 @@ async fn get_tasks_multiple_models_returns_all_matches() { assert!(models.contains("404-mesh")); } +#[tokio::test] +async fn get_tasks_requires_matching_worker_tag_for_tagged_company_task() { + let h = build_harness().await; + update_company_worker_tags(&h, &["acme", "premium"]).await; + let task_id = add_task_prompt_with_api_key(&h, &h.company_api_key, "company robot", None).await; + + let (worker_hotkey, timestamp, signature) = sign_worker([5u8; 32]); + let no_match = TestClient::post("http://localhost/get_tasks") + .json(&serde_json::json!({ + "worker_hotkey": worker_hotkey.to_string(), + "worker_id": "worker-no-match", + "signature": signature, + "timestamp": timestamp, + "requested_task_count": 1, + "model": "404-3dgs", + "worker_tags": ["other"] + })) + .send(&h.service) + .await; + let (status, _headers, body) = read_response(no_match).await; + assert_eq!(status, StatusCode::OK); + let payload: serde_json::Value = serde_json::from_slice(&body).expect("json response"); + assert_eq!( + payload + .get("tasks") + .and_then(|value| value.as_array()) + .expect("tasks") + .len(), + 0 + ); + + let (worker_hotkey, timestamp, signature) = sign_worker([6u8; 32]); + let matching = TestClient::post("http://localhost/get_tasks") + .json(&serde_json::json!({ + "worker_hotkey": worker_hotkey.to_string(), + "worker_id": "worker-match", + "signature": signature, + "timestamp": timestamp, + "requested_task_count": 1, + "model": "404-3dgs", + "worker_tags": ["premium"] + })) + .send(&h.service) + .await; + let (status, _headers, body) = read_response(matching).await; + assert_eq!(status, StatusCode::OK); + let payload: serde_json::Value = serde_json::from_slice(&body).expect("json response"); + let tasks = payload + .get("tasks") + .and_then(|value| value.as_array()) + .expect("tasks"); + assert_eq!(tasks.len(), 1); + let returned_task_id = tasks[0] + .get("id") + .and_then(|value| value.as_str()) + .and_then(|value| uuid::Uuid::parse_str(value).ok()) + .expect("task id"); + assert_eq!(returned_task_id, task_id); +} + +#[tokio::test] +async fn get_tasks_queued_company_task_keeps_original_worker_tags_after_company_update() { + let h = build_harness().await; + update_company_worker_tags(&h, &["acme"]).await; + let task_id = + add_task_prompt_with_api_key(&h, &h.company_api_key, "company snapshot robot", None).await; + + update_company_worker_tags(&h, &["other"]).await; + + let (status, payload) = signed_get_tasks_json( + &h, + 52, + "worker-new-company-tag", + 1, + Some(serde_json::json!(["other"])), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert!( + task_ids_from_get_tasks_payload(&payload).is_empty(), + "task queued with acme must not dynamically switch to the updated company tag" + ); + + let (status, payload) = signed_get_tasks_json( + &h, + 53, + "worker-original-company-tag", + 1, + Some(serde_json::json!(["acme"])), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_eq!(task_ids_from_get_tasks_payload(&payload), vec![task_id]); +} + +#[tokio::test] +async fn get_tasks_normalizes_worker_tags_before_matching_company_task() { + let h = build_harness().await; + update_company_worker_tags(&h, &["premium"]).await; + let task_id = + add_task_prompt_with_api_key(&h, &h.company_api_key, "normalized tag robot", None).await; + + let (status, payload) = signed_get_tasks_json( + &h, + 54, + "worker-normalized-tags", + 1, + Some(serde_json::json!([" Premium ", "premium"])), + ) + .await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(task_ids_from_get_tasks_payload(&payload), vec![task_id]); +} + +#[tokio::test] +async fn get_tasks_normalizes_company_worker_tags_before_matching_worker() { + let h = build_harness().await; + update_company_worker_tags(&h, &[" PREMIUM "]).await; + let task_id = + add_task_prompt_with_api_key(&h, &h.company_api_key, "normalized company tag robot", None) + .await; + + let (status, payload) = signed_get_tasks_json( + &h, + 58, + "worker-matches-normalized-company-tag", + 1, + Some(serde_json::json!(["premium"])), + ) + .await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(task_ids_from_get_tasks_payload(&payload), vec![task_id]); +} + +#[tokio::test] +async fn get_tasks_batch_filters_public_and_restricted_company_tasks_by_worker_tags() { + let h = build_harness().await; + update_company_worker_tags(&h, &["acme"]).await; + let restricted_task_id = + add_task_prompt_with_api_key(&h, &h.company_api_key, "restricted company robot", None) + .await; + let public_task_id = add_task_prompt(&h, "public robot", None).await; + + let (status, payload) = signed_get_tasks_json(&h, 55, "worker-without-tags", 10, None).await; + assert_eq!(status, StatusCode::OK); + assert_eq!( + task_ids_from_get_tasks_payload(&payload), + vec![public_task_id], + "worker without tags should only receive public tasks from a mixed batch" + ); + + let (status, payload) = signed_get_tasks_json( + &h, + 55, + "worker-with-restricted-tag", + 10, + Some(serde_json::json!(["acme"])), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_eq!( + task_ids_from_get_tasks_payload(&payload), + vec![restricted_task_id] + ); +} + +#[tokio::test] +async fn get_tasks_worker_with_multiple_tags_matches_company_task_by_any_tag() { + let h = build_harness().await; + update_company_worker_tags(&h, &["acme"]).await; + let task_id = + add_task_prompt_with_api_key(&h, &h.company_api_key, "multi-tag worker robot", None).await; + + let (status, payload) = signed_get_tasks_json( + &h, + 57, + "worker-multiple-tags", + 1, + Some(serde_json::json!(["other", "acme"])), + ) + .await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(task_ids_from_get_tasks_payload(&payload), vec![task_id]); +} + +#[tokio::test] +async fn get_tasks_tagged_worker_does_not_receive_untagged_company_task() { + let h = build_harness().await; + let task_id = + add_task_prompt_with_api_key(&h, &h.company_api_key, "untagged company robot", None).await; + + let (status, payload) = signed_get_tasks_json( + &h, + 59, + "tagged-worker-public-company-task", + 1, + Some(serde_json::json!(["acme"])), + ) + .await; + + assert_eq!(status, StatusCode::OK); + assert!( + task_ids_from_get_tasks_payload(&payload).is_empty(), + "worker requesting tagged work must not receive an untagged company task" + ); + + let (status, payload) = + signed_get_tasks_json(&h, 60, "untagged-worker-company-task", 1, None).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(task_ids_from_get_tasks_payload(&payload), vec![task_id]); +} + +#[tokio::test] +async fn get_tasks_allows_untagged_company_task_without_worker_tags() { + let h = build_harness().await; + let task_id = add_task_prompt_with_api_key(&h, &h.company_api_key, "company robot", None).await; + + let (worker_hotkey, timestamp, signature) = sign_worker([7u8; 32]); + let res = TestClient::post("http://localhost/get_tasks") + .json(&serde_json::json!({ + "worker_hotkey": worker_hotkey.to_string(), + "worker_id": "worker-public", + "signature": signature, + "timestamp": timestamp, + "requested_task_count": 1, + "model": "404-3dgs" + })) + .send(&h.service) + .await; + let (status, _headers, body) = read_response(res).await; + assert_eq!(status, StatusCode::OK); + let payload: serde_json::Value = serde_json::from_slice(&body).expect("json response"); + let tasks = payload + .get("tasks") + .and_then(|value| value.as_array()) + .expect("tasks"); + assert_eq!(tasks.len(), 1); + let returned_task_id = tasks[0] + .get("id") + .and_then(|value| value.as_str()) + .and_then(|value| uuid::Uuid::parse_str(value).ok()) + .expect("task id"); + assert_eq!(returned_task_id, task_id); +} + +#[tokio::test] +async fn get_tasks_allows_untagged_company_task_with_empty_worker_tags() { + let h = build_harness().await; + let task_id = + add_task_prompt_with_api_key(&h, &h.company_api_key, "empty tags company robot", None) + .await; + + let (status, payload) = + signed_get_tasks_json(&h, 61, "worker-empty-tags", 1, Some(serde_json::json!([]))).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(task_ids_from_get_tasks_payload(&payload), vec![task_id]); +} + #[tokio::test] async fn get_tasks_retires_db_timed_out_task_without_reoffering_it() { let h = build_harness().await; diff --git a/tests/client_http_api/support.rs b/tests/client_http_api/support.rs index 40b208b..50c3793 100644 --- a/tests/client_http_api/support.rs +++ b/tests/client_http_api/support.rs @@ -543,6 +543,24 @@ pub(crate) async fn update_company_api_key_limits_without_timestamp( sync_db_caches(h).await; } +pub(crate) async fn update_company_worker_tags(h: &TestHarness, worker_tags: &[&str]) { + let tags: Vec = worker_tags.iter().map(|tag| tag.to_string()).collect(); + let updated = h + .core + .db_client + .execute( + "UPDATE companies + SET worker_tags = $1, + updated_at = FLOOR(EXTRACT(EPOCH FROM clock_timestamp()) * 1000)::BIGINT + WHERE id = $2", + &[&tags, &h.company_id], + ) + .await + .expect("update company worker tags"); + assert_eq!(updated, 1, "expected exactly one company row to update"); + sync_db_caches(h).await; +} + pub(crate) async fn revoke_api_key_without_timestamp_no_sync(h: &TestHarness, api_key: &str) { let key_hash = api_key_hash_bytes(h, api_key); let updated = h diff --git a/tests/event_tracker/activity.rs b/tests/event_tracker/activity.rs index b61d826..fef8a58 100644 --- a/tests/event_tracker/activity.rs +++ b/tests/event_tracker/activity.rs @@ -101,6 +101,7 @@ async fn records_company_activity_with_image_task() { name: Arc::from("Acme"), concurrent_limit: 100, daily_limit: 1000, + worker_tags: Vec::new(), }), ..RateLimitContext::default() };