From 860d9dcf359994605f2c4f9aba49d6c1cb6676a2 Mon Sep 17 00:00:00 2001 From: Andrew Paprotsky Date: Thu, 7 May 2026 11:55:12 +0000 Subject: [PATCH] feat: add OCI artifact model provider - Add OCI as a model source alongside HuggingFace, NGC, and GCS. - Support tag and digest refs, registry auth, raw file layers, and tar/tar+zstd layers. - Add staged OCI cache publishing and provider-specific cache layout. - Wire OCI through proto, CLI, provider dispatch, Redis/Kubernetes registry state, and CRDs. - Document OCI artifact format, auth, cache behavior, and usage. Signed-off-by: Andrew Paprotsky --- Cargo.lock | 345 ++++++++++++ Cargo.toml | 3 + README.md | 11 +- docs/ARCHITECTURE.md | 19 +- docs/CLI.md | 8 + docs/DEPLOYMENT.md | 13 + docs/OCI_PROVIDER.md | 83 +++ docs/metadata.md | 2 +- examples/crds.yaml | 2 + modelexpress-cli-completion.bash | 7 +- modelexpress_client/src/bin/cli.rs | 3 + modelexpress_common/Cargo.toml | 5 + modelexpress_common/proto/model.proto | 1 + modelexpress_common/src/cache.rs | 4 + modelexpress_common/src/download.rs | 16 +- modelexpress_common/src/lib.rs | 4 + modelexpress_common/src/models.rs | 8 +- modelexpress_common/src/providers.rs | 2 + modelexpress_common/src/providers/oci.rs | 161 ++++++ .../src/providers/oci/archive_format.rs | 325 +++++++++++ .../src/providers/oci/cache_entry.rs | 352 ++++++++++++ .../src/providers/oci/downloader.rs | 533 ++++++++++++++++++ .../src/providers/oci/layer_download.rs | 262 +++++++++ modelexpress_common/src/providers/oci/path.rs | 110 ++++ .../src/providers/oci/provider_cache.rs | 260 +++++++++ .../src/providers/oci/reference.rs | 142 +++++ .../src/providers/oci/registry_auth.rs | 74 +++ .../src/registry/backend/kubernetes.rs | 19 +- .../src/registry/backend/redis.rs | 3 + modelexpress_server/src/registry/k8s_types.rs | 2 +- 30 files changed, 2758 insertions(+), 21 deletions(-) create mode 100644 docs/OCI_PROVIDER.md create mode 100644 modelexpress_common/src/providers/oci.rs create mode 100644 modelexpress_common/src/providers/oci/archive_format.rs create mode 100644 modelexpress_common/src/providers/oci/cache_entry.rs create mode 100644 modelexpress_common/src/providers/oci/downloader.rs create mode 100644 modelexpress_common/src/providers/oci/layer_download.rs create mode 100644 modelexpress_common/src/providers/oci/path.rs create mode 100644 modelexpress_common/src/providers/oci/provider_cache.rs create mode 100644 modelexpress_common/src/providers/oci/reference.rs create mode 100644 modelexpress_common/src/providers/oci/registry_auth.rs diff --git a/Cargo.lock b/Cargo.lock index 00670d1a..6f08b2c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -201,6 +201,28 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ec6fb3fe69024a75fa7e1bfb48aa6cf59706a101658ea01bfd33b2b248a038f" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.40.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f50037ee5e1e41e7b8f9d161680a725bd1626cb6f8c7e901f91f942850852fe7" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.8.4" @@ -339,6 +361,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -452,6 +476,15 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +[[package]] +name = "cmake" +version = "0.1.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0f78a02292a74a88ac736019ab962ece0bc380e3f977bf72e376c5d78ff0678" +dependencies = [ + "cc", +] + [[package]] name = "colorchoice" version = "1.0.4" @@ -543,6 +576,27 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "const_format" +version = "0.2.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4481a617ad9a412be3b97c5d403fef8ed023103368908b9c50af598ff467cc1e" +dependencies = [ + "const_format_proc_macros", + "konst", +] + +[[package]] +name = "const_format_proc_macros" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d57c2eccfb16dbac1f4e61e206105db5820c9d26c3c472bc17c774259ef7744" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + [[package]] name = "convert_case" version = "0.6.0" @@ -782,6 +836,37 @@ dependencies = [ "serde_core", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.11", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn", +] + [[package]] name = "digest" version = "0.10.7" @@ -840,6 +925,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dyn-clone" version = "1.0.20" @@ -963,6 +1054,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "filetime" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" +dependencies = [ + "cfg-if", + "libc", + "libredox", +] + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -1002,6 +1104,12 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futures" version = "0.3.31" @@ -1142,6 +1250,18 @@ dependencies = [ "wasip3", ] +[[package]] +name = "getset" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf0fc11e47561d47397154977bc219f4cf809b2974facc3ccb3b89e2436f912" +dependencies = [ + "proc-macro-error2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "gimli" version = "0.31.1" @@ -1532,6 +1652,15 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-auth" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "150fa4a9462ef926824cf4519c84ed652ca8f4fbae34cb8af045b5cbcaf98822" +dependencies = [ + "memchr", +] + [[package]] name = "http-body" version = "1.0.1" @@ -2006,6 +2135,16 @@ dependencies = [ "syn", ] +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.3", + "libc", +] + [[package]] name = "js-sys" version = "0.3.91" @@ -2062,6 +2201,20 @@ dependencies = [ "serde_json", ] +[[package]] +name = "jsonwebtoken" +version = "10.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0529410abe238729a60b108898784df8984c87f6054c9c4fcacc47e4803c1ce1" +dependencies = [ + "base64 0.22.1", + "getrandom 0.2.16", + "js-sys", + "serde", + "serde_json", + "signature", +] + [[package]] name = "k8s-openapi" version = "0.24.0" @@ -2075,6 +2228,21 @@ dependencies = [ "serde_json", ] +[[package]] +name = "konst" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "128133ed7824fcd73d6e7b17957c5eb7bacb885649bd8c69708b2331a10bcefb" +dependencies = [ + "konst_macro_rules", +] + +[[package]] +name = "konst_macro_rules" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4933f3f57a8e9d9da04db23fb153356ecaf00cbd14aee46279c33dc80925c37" + [[package]] name = "kube" version = "0.98.0" @@ -2214,6 +2382,7 @@ checksum = "391290121bad3d37fbddad76d8f5d1c1c314cfc646d143d7e07a3086ddff0ce3" dependencies = [ "bitflags", "libc", + "redox_syscall", ] [[package]] @@ -2426,12 +2595,15 @@ dependencies = [ "hf-hub", "jiff", "mockall 0.13.1", + "oci-client", "prost 0.13.5", "reqwest 0.12.23", "rustls", "serde", "serde_json", "serde_yaml", + "sha2", + "tar", "tempfile", "thiserror 2.0.16", "tokio", @@ -2439,7 +2611,9 @@ dependencies = [ "tonic 0.13.1", "tonic-build", "tracing", + "uuid", "wiremock", + "zstd", ] [[package]] @@ -2554,6 +2728,60 @@ dependencies = [ "memchr", ] +[[package]] +name = "oci-client" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b7f8deaffcd3b0e3baf93dddcab3d18b91d46dc37d38a8b170089b234de5bb3" +dependencies = [ + "bytes", + "chrono", + "futures-util", + "http", + "http-auth", + "jsonwebtoken", + "lazy_static", + "oci-spec", + "olpc-cjson", + "regex", + "reqwest 0.13.2", + "serde", + "serde_json", + "sha2", + "thiserror 2.0.16", + "tokio", + "tracing", + "unicase", +] + +[[package]] +name = "oci-spec" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8445a2631507cec628a15fdd6154b54a3ab3f20ed4fe9d73a3b8b7a4e1ba03a" +dependencies = [ + "const_format", + "derive_builder", + "getset", + "regex", + "serde", + "serde_json", + "strum", + "strum_macros", + "thiserror 2.0.16", +] + +[[package]] +name = "olpc-cjson" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "696183c9b5fe81a7715d074fd632e8bd46f4ccc0231a3ed7fc580a80de5f7083" +dependencies = [ + "serde", + "serde_json", + "unicode-normalization", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -2762,6 +2990,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" + [[package]] name = "plotters" version = "0.3.7" @@ -2865,6 +3099,28 @@ dependencies = [ "syn", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -2984,6 +3240,7 @@ version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ + "aws-lc-rs", "bytes", "getrandom 0.3.3", "lru-slab", @@ -3290,6 +3547,7 @@ dependencies = [ "mime_guess", "percent-encoding", "pin-project-lite", + "quinn", "rustls", "rustls-pki-types", "rustls-platform-verifier", @@ -3386,6 +3644,7 @@ version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", @@ -3472,6 +3731,7 @@ version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -3804,6 +4064,15 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "rand_core 0.6.4", +] + [[package]] name = "slab" version = "0.4.11" @@ -3848,6 +4117,24 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "subtle" version = "2.6.1" @@ -3906,6 +4193,17 @@ dependencies = [ "libc", ] +[[package]] +name = "tar" +version = "0.4.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "tempfile" version = "3.22.0" @@ -4440,6 +4738,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-normalization" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -5181,6 +5488,16 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + [[package]] name = "yaml-rust2" version = "0.10.3" @@ -5301,3 +5618,31 @@ name = "zmij" version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index fa1dd9f2..d1435040 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ jiff = { version = "0.2.15", features = ["serde"] } modelexpress-common = { path = "modelexpress_common", version = "0.3.0" } modelexpress-client = { path = "modelexpress_client", version = "0.3.0" } modelexpress-server = { path = "modelexpress_server", version = "0.3.0" } +oci-client = { version = "0.16.1", default-features = false, features = ["rustls-tls"] } once_cell = "1.21.3" prost = "0.13" rustls = { version = "0.23.37", default-features = false, features = ["ring", "std"] } @@ -48,6 +49,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" mockall = "0.14.0" tempfile = "3.20" +tar = "0.4" tokio = { version = "1.46", features = ["full"] } tokio-stream = "0.1" tonic = "0.13" @@ -57,6 +59,7 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } futures = "0.3" uuid = { version = "1.17", features = ["v4", "serde"] } +zstd = "0.13" thiserror = "2.0" redis = { version = "0.27", features = ["tokio-comp", "connection-manager"] } reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] } diff --git a/README.md b/README.md index aa5c84ea..2dcc572b 100644 --- a/README.md +++ b/README.md @@ -40,9 +40,9 @@ ModelExpress is a Rust-based service that manages the complete model weight life ### How ModelExpress manages weights in the cluster -ModelExpress orchestrates the full flow—from download to GPU memory. It ensures only one node downloads a model from external sources (e.g., HuggingFace); other nodes receive weights via P2P or shared storage—eliminating duplicate downloads and reducing cluster ingress. +ModelExpress orchestrates the full flow—from download to GPU memory. It ensures only one node downloads a model from external sources (e.g., HuggingFace, NGC, GCS, or OCI registries); other nodes receive weights via P2P or shared storage—eliminating duplicate downloads and reducing cluster ingress. -1. **Download from HuggingFace** — One node pulls the model; ModelExpress coordinates so no other node duplicates this download, reducing external ingress. In air-gapped mode, serve from cache only (`HF_HUB_OFFLINE=1`). +1. **Download from a model source** — One node pulls the model from HuggingFace, NGC, GCS, or a file/archive OCI artifact; ModelExpress coordinates so no other node duplicates this download, reducing external ingress. In air-gapped HuggingFace mode, serve from cache only (`HF_HUB_OFFLINE=1`). 2. **Persist to disk** — Store in a cache backed by disk: - **Host-attached disk** — Local disk on the node (single-node or per-node cache). - **PVC** — RWO (ReadWriteOnce) for single-node; RWX (ReadWriteMany) for shared access across nodes. @@ -54,7 +54,7 @@ ModelExpress orchestrates the full flow—from download to GPU memory. It ensure ## Features - **Cold start reduction** — GPU-to-GPU P2P transfer over InfiniBand instead of disk load -- **HuggingFace caching** — PVC-backed cache, `HF_HUB_OFFLINE`, `ignore_weights`, `get_model_path` for Dynamo +- **Model source caching** — HuggingFace, NGC, GCS, and OCI artifact providers with PVC-backed cache support, `ignore_weights`, and `get_model_path` for Dynamo - **P2P GPU transfer** — vLLM `mx` loader and TRT-LLM `PRESHARDED` loader with NVIDIA NIXL over RDMA - **Metadata backends** — In-memory, Redis, or Kubernetes CRD (layered write-through for HA) - **Kubernetes** — Helm chart, CRDs/Redis for P2P, no-shared-storage support @@ -98,9 +98,9 @@ ModelExpress orchestrates the full flow—from download to GPU memory. It ensure - **modelexpress_server**: gRPC server with configurable metadata backends (Redis, Kubernetes CRD). - **modelexpress_client**: Rust CLI for cache management; Python package with vLLM loaders and `MxClient` for gRPC. -- **modelexpress_common**: Protobuf definitions, provider trait (HuggingFace), shared configuration. +- **modelexpress_common**: Protobuf definitions, provider trait (HuggingFace, NGC, GCS, OCI), shared configuration. -See [Architecture](docs/ARCHITECTURE.md). +See [Architecture](docs/ARCHITECTURE.md), [GCS provider](docs/GCS_PROVIDER.md), and [OCI provider](docs/OCI_PROVIDER.md). --- @@ -241,7 +241,6 @@ cargo bench - **DRAM and NVMe-resident shard streaming**: Stream shards across workers while keeping weights in DRAM and host local high-speed NVMe. - **RL workloads**: Explore fast P2P transfers to optimize RL refit phase and support for weight resharding. - **Earlier weight availability**: Bring weights to prefill earlier; identify prefill workers that can act as strong source nodes. -- **Expanded model pull providers**: Support NGC in addition to Hugging Face. - **GDS (GPUDirect Storage) integration**: Load model weights directly from NVMe into GPU memory, bypassing the CPU/DRAM copy path. - **Multi-tier cache hierarchy**: Promote and demote models across DRAM, NVMe, and PVC tiers based on access patterns. - **Distributed sharded cache**: Shard large models across nodes using consistent hashing and parallel shard assembly. diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index baee52f5..8fd0a32d 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -11,7 +11,7 @@ Detailed reference document for the ModelExpress codebase. For deployment and co ModelExpress is a Rust-based model cache management service and GPU-to-GPU model weight transfer system. It serves two roles: -- **Model Cache Service** - A sidecar alongside inference solutions (vLLM, SGLang, NVIDIA Dynamo) that accelerates model downloads from HuggingFace, NGC, and GCS. Model lifecycle state lives in a distributed registry — Redis or Kubernetes CRDs (`ModelCacheEntry`), selected via `MX_METADATA_BACKEND` — so multiple server replicas can coordinate without a shared-filesystem database. LRU cache eviction runs off the same registry. +- **Model Cache Service** - A sidecar alongside inference solutions (vLLM, SGLang, NVIDIA Dynamo) that accelerates model downloads from HuggingFace, NGC, GCS, and file/archive OCI artifacts. Model lifecycle state lives in a distributed registry — Redis or Kubernetes CRDs (`ModelCacheEntry`), selected via `MX_METADATA_BACKEND` — so multiple server replicas can coordinate without a shared-filesystem database. LRU cache eviction runs off the same registry. - **P2P Weight Transfer** - GPU-to-GPU model weight transfers between vLLM instances using NVIDIA NIXL over RDMA/InfiniBand, enabling ~15-second transfers for 681GB models. ### Current Status @@ -31,6 +31,7 @@ graph TD S1 --> HF[HuggingFace Hub] S1 --> NGC[NVIDIA NGC] S1 --> GCS[Google Cloud Storage] + S1 --> OCI[OCI Registry] S1 --> Cache[Model Cache Dir] end @@ -177,7 +178,8 @@ ModelExpress/ │ ├── gcs.rs # GcsProvider implementation │ ├── gcs/ # GCS manifest, cache layout, locking, download helpers │ ├── huggingface.rs # HuggingFaceProvider implementation -│ └── ngc.rs # NgcProvider implementation +│ ├── ngc.rs # NgcProvider implementation +│ └── oci.rs # OciProvider implementation │ ├── workspace-tests/ │ ├── Cargo.toml @@ -283,7 +285,7 @@ Four proto files define four services, all compiled via `tonic-build` in `modele | `StreamModelFiles` | `ModelFilesRequest` | stream `FileChunk` | Stream model file contents (1MB chunks) | | `ListModelFiles` | `ModelFilesRequest` | `ModelFileList` | List files with sizes | -Key message types: `ModelProvider` (HuggingFace, NGC, GCS), `ModelStatus` (Downloading, Downloaded, Error), `ModelStatusUpdate`, `FileChunk`. +Key message types: `ModelProvider` (HuggingFace, NGC, GCS, OCI), `ModelStatus` (Downloading, Downloaded, Error), `ModelStatusUpdate`, `FileChunk`. ### p2p.proto - P2pService @@ -465,7 +467,7 @@ Output formats: `--format human` (default), `--format json`, `--format json-pret | `config` | Config trait utilities | | `download` | Download orchestration with strategy pattern | | `models` | `Status`, `ModelProvider`, `ModelStatus`, `ModelStatusResponse` | -| `providers` | `ModelProviderTrait` + `HuggingFaceProvider` + `NgcProvider` + `GcsProvider` | +| `providers` | `ModelProviderTrait` + `HuggingFaceProvider` + `NgcProvider` + `GcsProvider` + `OciProvider` | | `grpc` | Generated tonic stubs for all 4 services | | `constants` | `DEFAULT_GRPC_PORT` (8001), `DEFAULT_TIMEOUT_SECS` (30), `DEFAULT_TRANSFER_CHUNK_SIZE` (32KB) | @@ -484,10 +486,11 @@ pub trait ModelProviderTrait: Send + Sync { } ``` -Three implementations: -- `HuggingFaceProvider` - uses the `hf-hub` crate with high-CPU download mode. -- `NgcProvider` - downloads from NVIDIA NGC via the V2 artifact API (Bearer-authenticated `/files/{path}` for team artifacts; presigned S3 URLs for org-level artifacts). Falls back to `checksums.blake3` manifest enumeration when bulk file listing returns 400. Resolves the NGC API key from `NGC_API_KEY`, `NGC_CLI_API_KEY`, or `~/.ngc/config`. -- `GcsProvider` - downloads objects under a full `gs:///` URL using Google Application Default Credentials. It writes a `.mx/manifest.json` cache manifest, verifies downloaded files with GCS CRC32C checksums, skips dotfiles, README, and images, and stores models under `/gcs//`. See [`GCS_PROVIDER.md`](GCS_PROVIDER.md) for the detailed design. +Provider implementations: +- `HuggingFaceProvider` — uses the `hf-hub` crate with high-CPU download mode. +- `NgcProvider` — downloads from NVIDIA NGC via the V2 artifact API (Bearer-authenticated `/files/{path}` for team artifacts; presigned S3 URLs for org-level artifacts). Falls back to `checksums.blake3` manifest enumeration when bulk file listing returns 400. Resolves the NGC API key from `NGC_API_KEY`, `NGC_CLI_API_KEY`, or `~/.ngc/config`. +- `GcsProvider` — downloads objects under a full `gs:///` URL using Google Application Default Credentials. It writes a `.mx/manifest.json` cache manifest, verifies downloaded files with GCS CRC32C checksums, skips dotfiles, README, and images, and stores models under `/gcs//`. See [`GCS_PROVIDER.md`](GCS_PROVIDER.md) for the detailed design. +- `OciProvider` — downloads OCI model artifacts via `oci-client`. Raw layers use `org.opencontainers.image.title` or `org.cncf.model.filepath` as the output file path; simple `tar` and `tar+zstd` layers are safely extracted. ModelExpress atomically publishes the completed `files` directory. Container image unpacking remains out of scope: no whiteouts or rootfs layer merging. See [`OCI_PROVIDER.md`](OCI_PROVIDER.md). ### ClientConfig / ClientArgs diff --git a/docs/CLI.md b/docs/CLI.md index 819c0d10..1c8640e5 100644 --- a/docs/CLI.md +++ b/docs/CLI.md @@ -97,6 +97,11 @@ modelexpress-cli model download gs://my-bucket/models/qwen/rev-1 \ modelexpress-cli model download microsoft/DialoGPT-medium \ --strategy direct +# Download an OCI artifact from a registry +modelexpress-cli model download registry.example.com/team/model:v1 \ + --provider oci \ + --strategy direct + # Download with file transfer when no shared storage exists # Note: Global options must come before the subcommand modelexpress-cli --no-shared-storage --transfer-chunk-size 65536 \ @@ -153,9 +158,12 @@ modelexpress-cli model stats --detailed - `hugging-face`: Hugging Face model hub (default) - `ngc`: NVIDIA NGC catalog - `gcs`: Google Cloud Storage object prefix. The model name must be a full `gs:///` URL. See [`GCS_PROVIDER.md`](GCS_PROVIDER.md) for cache layout and provider behavior. +- `oci`: OCI model artifact with raw file blobs or simple `tar`/`tar+zstd` archive layers. References must be registry-qualified and include a tag or digest, for example `oci://registry.example.com/team/model:v1` or `registry.example.com/team/model@sha256:...`. See [`OCI_PROVIDER.md`](OCI_PROVIDER.md) for artifact format, cache layout, and publish behavior. For GCS downloads, configure Google Application Default Credentials on the process that performs the download: the server for `server-only`, the client for `direct`, and either process for `smart-fallback`. Common options are `GOOGLE_APPLICATION_CREDENTIALS`, `gcloud auth application-default login`, or Workload Identity on GKE. +For OCI downloads, set `MODEL_EXPRESS_OCI_*` credentials on the process that performs the download when anonymous registry access is not enough. See [`OCI_PROVIDER.md`](OCI_PROVIDER.md) for the exact auth precedence. + **Model Commands:** - `download`: Download model with automatic storage (use `--strategy` and `--provider` for options) - `init`: Initialize model storage configuration diff --git a/docs/DEPLOYMENT.md b/docs/DEPLOYMENT.md index fe7ff346..e42d5e2b 100644 --- a/docs/DEPLOYMENT.md +++ b/docs/DEPLOYMENT.md @@ -171,6 +171,8 @@ Cache directory resolution for NGC: `MODEL_EXPRESS_CACHE_DIRECTORY` -> `~/.cache GCS uses the configured/default ModelExpress cache root; `MODEL_EXPRESS_CACHE_DIRECTORY` overrides it. Cached GCS models are stored under `/gcs//`. See [`GCS_PROVIDER.md`](GCS_PROVIDER.md) for provider internals. +OCI uses the configured/default ModelExpress cache root; `MODEL_EXPRESS_CACHE_DIRECTORY` overrides it. Cached OCI artifacts are stored under `/oci///tags//files` or `/oci///digests/-/files`. See [`OCI_PROVIDER.md`](OCI_PROVIDER.md) for provider internals. + See [`CLI.md`](CLI.md) for full CLI usage documentation. ## Docker @@ -259,6 +261,17 @@ kubectl create secret generic gcs-service-account-key \ Mount the secret into the server or client pod and set `GOOGLE_APPLICATION_CREDENTIALS` to the mounted file path. When using Workload Identity, no key secret is needed. For cache layout, manifest behavior, and failure modes, see [`GCS_PROVIDER.md`](GCS_PROVIDER.md). +### OCI Registry Credentials + +OCI artifact downloads use registry-qualified refs such as `oci://registry.example.com/team/model:v1` or `registry.example.com/team/model@sha256:...`. Auth is selected in this order: + +1. `MODEL_EXPRESS_OCI_BEARER_TOKEN` +2. `MODEL_EXPRESS_OCI_USERNAME` plus `MODEL_EXPRESS_OCI_PASSWORD` +3. `MODEL_EXPRESS_OCI_USERNAME` plus `MODEL_EXPRESS_OCI_TOKEN` +4. Anonymous access + +For artifact format, archive support, cache layout, and failure behavior, see [`OCI_PROVIDER.md`](OCI_PROVIDER.md). + ### Helm Chart The `helm/` directory provides a full Helm chart with configurable replicas, PVC, ingress, and resource limits. diff --git a/docs/OCI_PROVIDER.md b/docs/OCI_PROVIDER.md new file mode 100644 index 00000000..56bc56a1 --- /dev/null +++ b/docs/OCI_PROVIDER.md @@ -0,0 +1,83 @@ + + +# OCI Provider + +ModelExpress can download file-oriented OCI model artifacts. The provider supports raw file blobs and simple archive layers. It uses the Rust `oci-client` crate for registry reference parsing, authentication, manifest fetches, and blob streaming. + +OCI support is a materializer, not a container image unpacker. It does not apply whiteouts, root filesystem merges, symlinks, hardlinks, or special files. + +## References + +Use `--provider oci` with a registry-qualified reference that includes a tag or digest: + +```bash +modelexpress-cli model download registry.example.com/team/model:v1 --provider oci +modelexpress-cli model download oci://registry.example.com/team/model:v1 --provider oci +modelexpress-cli model download registry.example.com/team/model@sha256: --provider oci +``` + +The optional `oci://` prefix is stripped before parsing and cache key generation. + +## Artifact Format + +Raw file layers must include `org.opencontainers.image.title` or `org.cncf.model.filepath`. ModelExpress uses that annotation as the output path relative to the model directory. + +Archive layers are supported when their media type is `tar` or `tar+zstd`, including `application/vnd.oci.image.layer.v1.tar+zstd` and model-specific media types ending in `.tar`. Tar member paths are materialized relative to the model directory. Layer titles are labels only; include any desired directory prefixes in the tar member names. + +The provider rejects empty paths, absolute paths, `.` and `..` components, backslashes, non-UTF-8 path data, duplicate output paths, symlinks, hardlinks, and special archive entries. README files, dotfiles, and images are skipped. When `ignore_weights=true`, raw weight-file layers are skipped before download and archive-like layers are skipped as whole blobs. + +Example artifact layout: + +```bash +oras push registry.example.com/team/model:v1 \ + config.json:application/json \ + tokenizer.json:application/json \ + model.safetensors:application/octet-stream +``` + +Example archive artifact layout: + +```text +layer media type: application/vnd.oci.image.layer.v1.tar+zstd +tar members: + tokenizer/tokenizer.json + part-0/program.0.gas + part-1/program.8.gas +``` + +This materializes those same tar member paths under the cache entry. + +## Authentication + +Authentication uses this precedence: + +1. `MODEL_EXPRESS_OCI_BEARER_TOKEN` +2. `MODEL_EXPRESS_OCI_USERNAME` plus `MODEL_EXPRESS_OCI_PASSWORD` +3. `MODEL_EXPRESS_OCI_USERNAME` plus `MODEL_EXPRESS_OCI_TOKEN` +4. Anonymous access + +## Cache Layout + +OCI artifacts are cached under the ModelExpress cache root: + +```text +/oci///tags//files +/oci///digests/-/files +``` + +The provider follows NGC-like cache reuse semantics: `ignore_weights` affects which files are materialized during the download, but it is not part of the cache identity. An existing non-empty `files` directory for the same OCI reference is reused. + +## Publish Behavior + +Downloads materialize into a staging directory: + +```text +/oci/.tmp//files +``` + +Raw blobs stream directly into files. Archive blobs stream to a temporary blob file under the staging entry, extract into `files`, and are removed before publish. + +After all selected blobs are written, the staging entry is atomically renamed into the final cache path. If the final cache entry already exists and has a non-empty `files` directory, ModelExpress removes the staging entry and reuses the existing cache. If the final cache entry exists but is incomplete or corrupt, publish fails with a cache-corruption error and removes the staging entry; clear the corrupt cache entry before retrying. diff --git a/docs/metadata.md b/docs/metadata.md index aa895fec..b630222a 100644 --- a/docs/metadata.md +++ b/docs/metadata.md @@ -173,7 +173,7 @@ Three types of Redis keys are relevant: | Field | Value | Purpose | |-------|-------|---------| -| `provider` | `HuggingFace`, `Ngc`, or `Gcs` | Provider associated with the cached model | +| `provider` | `HuggingFace`, `Ngc`, `Gcs`, or `Oci` | Provider associated with the cached model | | `status` | `DOWNLOADING`, `DOWNLOADED`, or `ERROR` | Download lifecycle state | | `created_at` | RFC3339 timestamp | First write time, preserved across status updates | | `last_used_at` | RFC3339 timestamp | Last status write or cache hit time for LRU eviction | diff --git a/examples/crds.yaml b/examples/crds.yaml index d266e460..dac5c298 100644 --- a/examples/crds.yaml +++ b/examples/crds.yaml @@ -183,6 +183,8 @@ spec: enum: - HuggingFace - Ngc + - Gcs + - Oci status: type: object properties: diff --git a/modelexpress-cli-completion.bash b/modelexpress-cli-completion.bash index 0d146a1e..223f1e7d 100644 --- a/modelexpress-cli-completion.bash +++ b/modelexpress-cli-completion.bash @@ -73,7 +73,7 @@ _model_express_cli_completions() { elif [[ "${words[i+1]}" == "download" ]]; then case "${prev}" in --provider|-p) - COMPREPLY=($(compgen -W "hugging-face" -- "$cur")) + COMPREPLY=($(compgen -W "hugging-face ngc gcs oci" -- "$cur")) ;; --strategy|-s) COMPREPLY=($(compgen -W "smart-fallback server-only direct" -- "$cur")) @@ -108,13 +108,16 @@ _model_express_cli_completions() { fi elif [[ "${words[i+1]}" == "clear" ]]; then case "${prev}" in + --provider|-p) + COMPREPLY=($(compgen -W "hugging-face ngc gcs oci" -- "$cur")) + ;; clear) # Could potentially list actual downloaded models here COMPREPLY=($(compgen -W "google-t5/t5-small microsoft/DialoGPT-small" -- "$cur")) ;; *) if [[ "$cur" == -* ]]; then - COMPREPLY=($(compgen -W "--help" -- "$cur")) + COMPREPLY=($(compgen -W "--provider --help" -- "$cur")) fi ;; esac diff --git a/modelexpress_client/src/bin/cli.rs b/modelexpress_client/src/bin/cli.rs index e94bde0a..f2fb4545 100644 --- a/modelexpress_client/src/bin/cli.rs +++ b/modelexpress_client/src/bin/cli.rs @@ -147,6 +147,9 @@ mod tests { let parsed = ModelProvider::from_str("gcs", false).expect("Failed to parse gcs provider"); assert_eq!(parsed, ModelProvider::Gcs); + + let parsed = ModelProvider::from_str("oci", false).expect("Failed to parse oci provider"); + assert_eq!(parsed, ModelProvider::Oci); } #[test] diff --git a/modelexpress_common/Cargo.toml b/modelexpress_common/Cargo.toml index f2467a83..c7d4cdaf 100644 --- a/modelexpress_common/Cargo.toml +++ b/modelexpress_common/Cargo.toml @@ -37,6 +37,10 @@ google-cloud-storage = { workspace = true } rustls = { workspace = true } crc32c = { workspace = true } fd-lock = { workspace = true } +oci-client = { workspace = true } +tar = { workspace = true } +uuid = { workspace = true } +zstd = { workspace = true } [dev-dependencies] google-cloud-auth = { version = "1.7.0", default-features = false } @@ -45,6 +49,7 @@ mockall = "0.13" tempfile = { workspace = true } tokio-test = "0.4" wiremock = "0.6.5" +sha2 = { workspace = true } [build-dependencies] tonic-build = { workspace = true } diff --git a/modelexpress_common/proto/model.proto b/modelexpress_common/proto/model.proto index ec3f2bc7..06e988d4 100644 --- a/modelexpress_common/proto/model.proto +++ b/modelexpress_common/proto/model.proto @@ -45,6 +45,7 @@ enum ModelProvider { HUGGING_FACE = 0; NGC = 1; GCS = 2; + OCI = 3; } // Request for streaming model files diff --git a/modelexpress_common/src/cache.rs b/modelexpress_common/src/cache.rs index 366d8fc3..a8f43c55 100644 --- a/modelexpress_common/src/cache.rs +++ b/modelexpress_common/src/cache.rs @@ -6,6 +6,7 @@ use crate::{ models::ModelProvider, providers::{ gcs::GcsProviderCache, huggingface::HuggingFaceProviderCache, ngc::NgcProviderCache, + oci::OciProviderCache, }, }; use anyhow::{Context, Result}; @@ -236,6 +237,7 @@ impl CacheConfig { ModelProvider::HuggingFace, ModelProvider::Ngc, ModelProvider::Gcs, + ModelProvider::Oci, ] { models.extend(cache_for_provider(provider).list_models(&self.local_path)?); } @@ -345,6 +347,7 @@ pub(crate) fn cache_for_provider(provider: ModelProvider) -> &'static dyn Provid ModelProvider::HuggingFace => &HuggingFaceProviderCache, ModelProvider::Ngc => &NgcProviderCache, ModelProvider::Gcs => &GcsProviderCache, + ModelProvider::Oci => &OciProviderCache, } } @@ -379,6 +382,7 @@ fn provider_sort_key(provider: ModelProvider) -> u8 { ModelProvider::HuggingFace => 0, ModelProvider::Ngc => 1, ModelProvider::Gcs => 2, + ModelProvider::Oci => 3, } } diff --git a/modelexpress_common/src/download.rs b/modelexpress_common/src/download.rs index f9b72513..7cc3e284 100644 --- a/modelexpress_common/src/download.rs +++ b/modelexpress_common/src/download.rs @@ -2,7 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use crate::models::ModelProvider; -use crate::providers::{GcsProvider, HuggingFaceProvider, ModelProviderTrait, NgcProvider}; +use crate::providers::{ + GcsProvider, HuggingFaceProvider, ModelProviderTrait, NgcProvider, OciProvider, +}; use anyhow::Result; use std::path::PathBuf; use tracing::{info, warn}; @@ -14,6 +16,7 @@ pub fn get_provider(provider: ModelProvider) -> Box { ModelProvider::HuggingFace => Box::new(HuggingFaceProvider), ModelProvider::Ngc => Box::new(NgcProvider), ModelProvider::Gcs => Box::new(GcsProvider), + ModelProvider::Oci => Box::new(OciProvider), } } @@ -112,6 +115,9 @@ mod tests { let provider = get_provider(ModelProvider::Gcs); assert_eq!(provider.provider_name(), "GCS"); + + let provider = get_provider(ModelProvider::Oci); + assert_eq!(provider.provider_name(), "OCI"); } #[test] @@ -126,6 +132,14 @@ mod tests { .expect("Expected canonical model name"), "gs://test-bucket/org/model/rev-1" ); + assert_eq!( + canonical_model_name( + "oci://registry.example.com/team/model:v1", + ModelProvider::Oci + ) + .expect("Expected canonical OCI model name"), + "registry.example.com/team/model:v1" + ); } #[tokio::test] diff --git a/modelexpress_common/src/lib.rs b/modelexpress_common/src/lib.rs index 3cbbf43b..76beda4a 100644 --- a/modelexpress_common/src/lib.rs +++ b/modelexpress_common/src/lib.rs @@ -173,6 +173,7 @@ impl From for grpc::model::ModelProvider { models::ModelProvider::HuggingFace => grpc::model::ModelProvider::HuggingFace, models::ModelProvider::Ngc => grpc::model::ModelProvider::Ngc, models::ModelProvider::Gcs => grpc::model::ModelProvider::Gcs, + models::ModelProvider::Oci => grpc::model::ModelProvider::Oci, } } } @@ -183,6 +184,7 @@ impl From for models::ModelProvider { grpc::model::ModelProvider::HuggingFace => models::ModelProvider::HuggingFace, grpc::model::ModelProvider::Ngc => models::ModelProvider::Ngc, grpc::model::ModelProvider::Gcs => models::ModelProvider::Gcs, + grpc::model::ModelProvider::Oci => models::ModelProvider::Oci, } } } @@ -297,9 +299,11 @@ mod tests { models::ModelProvider::HuggingFace, models::ModelProvider::Ngc, models::ModelProvider::Gcs, + models::ModelProvider::Oci, ] { let grpc_provider: grpc::model::ModelProvider = model_provider.into(); let back_to_model: models::ModelProvider = grpc_provider.into(); + assert_eq!(model_provider, back_to_model); } } diff --git a/modelexpress_common/src/models.rs b/modelexpress_common/src/models.rs index f0e9f9d7..528d8c66 100644 --- a/modelexpress_common/src/models.rs +++ b/modelexpress_common/src/models.rs @@ -34,6 +34,8 @@ pub enum ModelProvider { Ngc, /// Google Cloud Storage Gcs, + /// File or archive model artifact in an OCI registry + Oci, } impl ModelProvider { @@ -43,6 +45,7 @@ impl ModelProvider { Self::HuggingFace => "hugging-face", Self::Ngc => "ngc", Self::Gcs => "gcs", + Self::Oci => "oci", } } } @@ -55,7 +58,7 @@ impl Display for ModelProvider { impl ValueEnum for ModelProvider { fn value_variants<'a>() -> &'a [Self] { - &[Self::HuggingFace, Self::Ngc, Self::Gcs] + &[Self::HuggingFace, Self::Ngc, Self::Gcs, Self::Oci] } fn to_possible_value(&self) -> Option { @@ -91,6 +94,7 @@ mod tests { ModelProvider::HuggingFace, ModelProvider::Ngc, ModelProvider::Gcs, + ModelProvider::Oci, ] { let serialized = serde_json::to_string(&provider).expect("Failed to serialize ModelProvider"); @@ -111,6 +115,7 @@ mod tests { assert_eq!(ModelProvider::HuggingFace.to_string(), "hugging-face"); assert_eq!(ModelProvider::Ngc.to_string(), "ngc"); assert_eq!(ModelProvider::Gcs.to_string(), "gcs"); + assert_eq!(ModelProvider::Oci.to_string(), "oci"); } #[test] @@ -119,6 +124,7 @@ mod tests { ModelProvider::HuggingFace, ModelProvider::Ngc, ModelProvider::Gcs, + ModelProvider::Oci, ] { let parsed = ModelProvider::from_str(provider.as_str(), false) .expect("Failed to parse ModelProvider from clap value"); diff --git a/modelexpress_common/src/providers.rs b/modelexpress_common/src/providers.rs index b649c5f6..1de54422 100644 --- a/modelexpress_common/src/providers.rs +++ b/modelexpress_common/src/providers.rs @@ -86,10 +86,12 @@ pub trait ModelProviderTrait: Send + Sync { pub mod gcs; pub mod huggingface; pub mod ngc; +pub mod oci; pub use gcs::GcsProvider; pub use huggingface::HuggingFaceProvider; pub use ngc::NgcProvider; +pub use oci::OciProvider; #[cfg(test)] mod tests { diff --git a/modelexpress_common/src/providers/oci.rs b/modelexpress_common/src/providers/oci.rs new file mode 100644 index 00000000..bdd1fc98 --- /dev/null +++ b/modelexpress_common/src/providers/oci.rs @@ -0,0 +1,161 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{Utils, cache::ProviderCache, constants, providers::ModelProviderTrait}; +use anyhow::Result; +use std::{env, path::PathBuf}; +use tracing::info; + +mod archive_format; +mod cache_entry; +mod downloader; +mod layer_download; +mod path; +mod provider_cache; +mod reference; +mod registry_auth; + +use cache_entry::{CacheEntry, StagingCacheEntry}; +use downloader::Downloader; +use reference::OciReference; + +pub(crate) use provider_cache::OciProviderCache; + +const MODEL_EXPRESS_CACHE_ENV_VAR: &str = "MODEL_EXPRESS_CACHE_DIRECTORY"; + +/// File-oriented OCI artifact provider implementation. +pub struct OciProvider; + +impl OciProvider { + fn cache_root(cache_dir: Option) -> PathBuf { + if let Some(dir) = cache_dir { + return dir; + } + + if let Ok(cache_path) = env::var(MODEL_EXPRESS_CACHE_ENV_VAR) { + return PathBuf::from(cache_path); + } + + let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string()); + PathBuf::from(home).join(constants::DEFAULT_CACHE_PATH) + } +} + +#[async_trait::async_trait] +impl ModelProviderTrait for OciProvider { + async fn download_model( + &self, + model_name: &str, + cache_dir: Option, + ignore_weights: bool, + ) -> Result { + let cache_root = Self::cache_root(cache_dir); + let reference = OciReference::parse(model_name)?; + let final_entry = CacheEntry::new(&cache_root, &reference); + + // Keep OCI cache reuse NGC-like for now: a non-empty artifact directory is + // considered usable, even if it was created by an ignore_weights download. + if let Some(existing) = final_entry.existing_files_dir()? { + info!( + "OCI model '{model_name}' found in cache at {}", + existing.display() + ); + return Ok(existing); + } + + let staging_entry = StagingCacheEntry::new(&cache_root); + staging_entry.create().await?; + + let downloader = Downloader::new(model_name, &reference); + downloader + .download_to_staging(&staging_entry, ignore_weights) + .await?; + + let files_dir = final_entry.publish_from(&staging_entry)?; + info!( + "Downloaded OCI artifact '{model_name}' to {}", + files_dir.display() + ); + Ok(files_dir) + } + + async fn delete_model(&self, model_name: &str, cache_dir: PathBuf) -> Result<()> { + OciProviderCache.clear_model(&cache_dir, model_name) + } + + async fn get_model_path(&self, model_name: &str, cache_dir: PathBuf) -> Result { + let reference = OciReference::parse(model_name)?; + let entry = CacheEntry::new(&cache_dir, &reference); + entry.existing_files_dir()?.ok_or_else(|| { + anyhow::anyhow!( + "OCI model '{model_name}' not found in cache (expected {})", + entry.files_dir().display() + ) + }) + } + + fn canonical_model_name(&self, model_name: &str) -> Result { + Ok(OciReference::parse(model_name)?.canonical_name()) + } + + fn provider_name(&self) -> &'static str { + "OCI" + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + use super::{cache_entry::FILES_DIR_NAME as FILES_DIR, reference::OciReference}; + + #[test] + fn test_canonical_model_name_accepts_oci_scheme() { + assert_eq!( + OciProvider + .canonical_model_name("oci://registry.example.com/team/model:v1") + .expect("canonical ref"), + "registry.example.com/team/model:v1" + ); + + let digest = "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"; + assert_eq!( + OciProvider + .canonical_model_name(&format!( + "oci://registry.example.com/team/model:v1@{digest}" + )) + .expect("canonical digest ref"), + format!("registry.example.com/team/model@{digest}") + ); + } + + #[tokio::test] + async fn test_get_model_path_rejects_missing_or_incomplete_cache() { + let dir = TempDir::new().expect("temp dir"); + let missing = OciProvider + .get_model_path( + "registry.example.com/team/model:v1", + dir.path().to_path_buf(), + ) + .await + .expect_err("missing cache should fail"); + assert!(missing.to_string().contains("not found in cache")); + + let reference = OciReference::parse("registry.example.com/team/model:v1") + .expect("reference should parse"); + let entry = CacheEntry::path_for(dir.path(), &reference); + fs::create_dir_all(entry.join(FILES_DIR)).expect("create incomplete files dir"); + + let incomplete = OciProvider + .get_model_path( + "registry.example.com/team/model:v1", + dir.path().to_path_buf(), + ) + .await + .expect_err("incomplete cache should fail"); + assert!(incomplete.to_string().contains("incomplete or corrupt")); + } +} diff --git a/modelexpress_common/src/providers/oci/archive_format.rs b/modelexpress_common/src/providers/oci/archive_format.rs new file mode 100644 index 00000000..b97c552b --- /dev/null +++ b/modelexpress_common/src/providers/oci/archive_format.rs @@ -0,0 +1,325 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::path::ArtifactPath; +use anyhow::{Context, Result}; +use std::{ + collections::HashSet, + fs, + io::{BufReader, Read}, + path::{Path, PathBuf}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ArchiveFormat { + Tar, + TarZstd, +} + +impl ArchiveFormat { + pub fn is_archive_media_type(media_type: &str) -> bool { + let media_type = media_type.to_ascii_lowercase(); + + media_type == "application/x-tar" + || media_type.ends_with(".tar") + || media_type.ends_with("+tar") + || media_type.contains(".tar+") + || media_type.contains("+tar+") + || media_type.contains("-tar+") + } + + pub fn from_media_type(media_type: &str) -> Result> { + let media_type = media_type.to_ascii_lowercase(); + + if media_type == "application/x-tar" + || media_type.ends_with(".tar") + || media_type.ends_with("+tar") + { + return Ok(Some(Self::Tar)); + } + + if media_type.ends_with(".tar+zstd") + || media_type.ends_with("+tar+zstd") + || media_type.ends_with("-tar+zstd") + { + return Ok(Some(Self::TarZstd)); + } + + if media_type.contains(".tar+") + || media_type.contains("+tar+") + || media_type.contains("-tar+") + { + anyhow::bail!( + "OCI archive layer media type '{media_type}' is not supported; supported archive formats are tar and tar+zstd" + ); + } + + Ok(None) + } + + pub fn extract_blob(self, blob_path: &Path, output_root: &Path) -> Result> { + let file = fs::File::open(blob_path) + .with_context(|| format!("Failed to open OCI archive blob {blob_path:?}"))?; + let reader = BufReader::new(file); + + match self { + Self::Tar => TarExtractor::new(output_root).extract(reader), + Self::TarZstd => { + let decoder = zstd::stream::read::Decoder::new(reader) + .with_context(|| format!("Failed to create zstd decoder for {blob_path:?}"))?; + TarExtractor::new(output_root).extract(decoder) + } + } + } +} + +struct TarExtractor<'a> { + output_root: &'a Path, + files: Vec, + seen_paths: HashSet, +} + +impl<'a> TarExtractor<'a> { + fn new(output_root: &'a Path) -> Self { + Self { + output_root, + files: Vec::new(), + seen_paths: HashSet::new(), + } + } + + fn extract(mut self, reader: R) -> Result> { + let mut archive = tar::Archive::new(reader); + + for entry in archive + .entries() + .context("Failed to read OCI tar entries")? + { + self.extract_entry(entry.context("Failed to read OCI tar entry")?)?; + } + + Ok(self.files) + } + + fn extract_entry(&mut self, mut entry: tar::Entry<'_, R>) -> Result<()> { + let entry_type = entry.header().entry_type(); + + if entry_type.is_dir() { + return Ok(()); + } + + if !entry_type.is_file() { + anyhow::bail!( + "OCI archive entry '{}' has unsupported type {:?}; only regular files are supported", + Self::entry_path_for_error(&entry), + entry_type + ); + } + + let relative_path = Self::member_path(&entry)?; + + if relative_path.is_skipped(false) { + tracing::debug!("Skipping OCI archive file: {relative_path}"); + return Ok(()); + } + + self.ensure_unique(&relative_path)?; + let output_path = self.create_output_path(&relative_path)?; + let mut output = Self::create_output_file(&output_path)?; + + std::io::copy(&mut entry, &mut output) + .with_context(|| format!("Failed to extract OCI archive file '{relative_path}'"))?; + output + .sync_all() + .with_context(|| format!("Failed to sync OCI archive output file {output_path:?}"))?; + + self.files.push(relative_path.to_string()); + Ok(()) + } + + fn member_path(entry: &tar::Entry<'_, R>) -> Result { + let path = entry.path().context("Failed to read OCI tar entry path")?; + ArtifactPath::from_relative_path( + path.as_ref(), + &format!("OCI archive member '{}'", path.display()), + ) + } + + fn entry_path_for_error(entry: &tar::Entry<'_, R>) -> String { + entry.path().map_or_else( + |_| "".to_string(), + |path| path.display().to_string(), + ) + } + + fn ensure_unique(&mut self, relative_path: &ArtifactPath) -> Result<()> { + if !self.seen_paths.insert(relative_path.clone()) { + anyhow::bail!("Duplicate OCI archive file path '{relative_path}'"); + } + Ok(()) + } + + fn create_output_path(&self, relative_path: &ArtifactPath) -> Result { + let output_path = self.output_root.join(relative_path.as_path()); + if output_path.exists() { + anyhow::bail!("Duplicate OCI artifact file path '{relative_path}'"); + } + + if let Some(parent) = output_path.parent() { + fs::create_dir_all(parent).with_context(|| { + format!("Failed to create OCI archive output directory {parent:?}") + })?; + } + + Ok(output_path) + } + + fn create_output_file(output_path: &Path) -> Result { + fs::OpenOptions::new() + .write(true) + .create_new(true) + .open(output_path) + .with_context(|| format!("Failed to create OCI archive output file {output_path:?}")) + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::*; + use std::{fs, io::Write, path::PathBuf}; + use tempfile::TempDir; + + fn tar_bytes(entries: &[(&str, &[u8])]) -> Vec { + let mut bytes = Vec::new(); + { + let mut builder = tar::Builder::new(&mut bytes); + for (path, contents) in entries { + let mut header = tar::Header::new_gnu(); + header.set_size(contents.len() as u64); + header.set_mode(0o644); + header.set_cksum(); + builder + .append_data(&mut header, path, *contents) + .expect("append tar entry"); + } + builder.finish().expect("finish tar"); + } + bytes + } + + fn write_tar_octal(field: &mut [u8], value: u64) { + let width = field.len().checked_sub(1).expect("tar field has room"); + let encoded = format!("{value:0width$o}\0"); + field.copy_from_slice(encoded.as_bytes()); + } + + fn unsafe_tar_bytes(path: &str, contents: &[u8]) -> Vec { + let mut header = [0_u8; 512]; + header[..path.len()].copy_from_slice(path.as_bytes()); + write_tar_octal(&mut header[100..108], 0o644); + write_tar_octal(&mut header[108..116], 0); + write_tar_octal(&mut header[116..124], 0); + write_tar_octal(&mut header[124..136], contents.len() as u64); + write_tar_octal(&mut header[136..148], 0); + header[148..156].fill(b' '); + header[156] = b'0'; + header[257..263].copy_from_slice(b"ustar\0"); + header[263..265].copy_from_slice(b"00"); + let checksum: u64 = header.iter().map(|byte| u64::from(*byte)).sum(); + let checksum = format!("{checksum:06o}\0 "); + header[148..156].copy_from_slice(checksum.as_bytes()); + + let mut bytes = Vec::new(); + bytes.extend_from_slice(&header); + bytes.extend_from_slice(contents); + let padding = match contents.len() % 512 { + 0 => 0, + remainder => 512usize + .checked_sub(remainder) + .expect("tar padding remainder is smaller than block size"), + }; + bytes.extend(std::iter::repeat_n(0, padding)); + bytes.extend_from_slice(&[0_u8; 1024]); + bytes + } + + fn write_blob(dir: &TempDir, bytes: &[u8]) -> PathBuf { + let path = dir.path().join("blob"); + let mut file = fs::File::create(&path).expect("create blob"); + file.write_all(bytes).expect("write blob"); + path + } + + #[test] + fn test_archive_format_accepts_x_tar_zstd_media_type() { + assert_eq!( + ArchiveFormat::from_media_type("application/x-tar+zstd") + .expect("media type should parse"), + Some(ArchiveFormat::TarZstd) + ); + } + + #[test] + fn test_extract_tar_archive_applies_ignore_rules() { + let dir = TempDir::new().expect("temp dir"); + let output = dir.path().join("out"); + fs::create_dir_all(&output).expect("create output"); + let tar = tar_bytes(&[ + ("program.0.gas", b"gas"), + ("README.md", b"readme"), + (".hidden", b"hidden"), + ("diagram.png", b"image"), + ]); + let blob = write_blob(&dir, &tar); + + let files = ArchiveFormat::Tar + .extract_blob(&blob, &output) + .expect("extract archive"); + + assert_eq!(files, vec!["program.0.gas".to_string()]); + assert_eq!( + fs::read(output.join("program.0.gas")).expect("read gas"), + b"gas" + ); + assert!(!output.join("README.md").exists()); + assert!(!output.join(".hidden").exists()); + assert!(!output.join("diagram.png").exists()); + } + + #[test] + fn test_extract_zstd_tar_archive() { + let dir = TempDir::new().expect("temp dir"); + let output = dir.path().join("out"); + fs::create_dir_all(&output).expect("create output"); + let tar = tar_bytes(&[("config.json", br#"{"ok":true}"#)]); + let compressed = zstd::stream::encode_all(tar.as_slice(), 3).expect("compress tar"); + let blob = write_blob(&dir, &compressed); + + let files = ArchiveFormat::TarZstd + .extract_blob(&blob, &output) + .expect("extract archive"); + + assert_eq!(files, vec!["config.json".to_string()]); + assert_eq!( + fs::read(output.join("config.json")).expect("read config"), + br#"{"ok":true}"# + ); + } + + #[test] + fn test_extract_tar_archive_rejects_unsafe_paths() { + let dir = TempDir::new().expect("temp dir"); + let output = dir.path().join("out"); + fs::create_dir_all(&output).expect("create output"); + let tar = unsafe_tar_bytes("../escape", b"bad"); + let blob = write_blob(&dir, &tar); + + let err = ArchiveFormat::Tar + .extract_blob(&blob, &output) + .expect_err("unsafe archive path should fail"); + + assert!(err.to_string().contains("..")); + assert!(!dir.path().join("escape").exists()); + } +} diff --git a/modelexpress_common/src/providers/oci/cache_entry.rs b/modelexpress_common/src/providers/oci/cache_entry.rs new file mode 100644 index 00000000..7305c050 --- /dev/null +++ b/modelexpress_common/src/providers/oci/cache_entry.rs @@ -0,0 +1,352 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::reference::OciReference; +use anyhow::{Context, Result}; +use std::{ + fs, io, + path::{Path, PathBuf}, +}; +use tracing::warn; +use uuid::Uuid; + +pub const CACHE_ROOT_DIR_NAME: &str = "oci"; +pub const TMP_DIR_NAME: &str = ".tmp"; +pub const FILES_DIR_NAME: &str = "files"; +const BLOBS_DIR_NAME: &str = ".blobs"; + +#[derive(Debug, Clone)] +pub struct CacheEntry { + path: PathBuf, +} + +impl CacheEntry { + pub fn new(cache_root: &Path, reference: &OciReference) -> Self { + Self { + path: Self::path_for(cache_root, reference), + } + } + + pub fn path_for(cache_root: &Path, reference: &OciReference) -> PathBuf { + let mut path = cache_root + .join(CACHE_ROOT_DIR_NAME) + .join(reference.registry()) + .join(repository_cache_key(reference.repository())); + + if let Some(digest) = reference.digest() { + path = path.join("digests").join(digest.replace(':', "-")); + } else if let Some(tag) = reference.tag() { + path = path.join("tags").join(tag); + } + + path + } + + pub fn path(&self) -> &Path { + &self.path + } + + pub fn files_dir(&self) -> PathBuf { + Self::files_dir_for(&self.path) + } + + pub fn existing_files_dir(&self) -> Result> { + Self::existing_files_dir_at(&self.path) + } + + pub fn publish_from(&self, staging: &StagingCacheEntry) -> Result { + let result = self.publish(staging); + staging.cleanup(); + result + } + + pub fn files_dir_is_non_empty(files_dir: &Path) -> Result { + Ok(fs::read_dir(files_dir) + .with_context(|| format!("Failed to read OCI files directory {files_dir:?}"))? + .next() + .is_some()) + } + + fn files_dir_for(entry_path: &Path) -> PathBuf { + entry_path.join(FILES_DIR_NAME) + } + + fn publish(&self, staging: &StagingCacheEntry) -> Result { + if let Some(existing) = Self::existing_files_dir_at(&self.path)? { + return Ok(existing); + } + + if let Some(parent) = self.path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create OCI cache parent {parent:?}"))?; + } + + match fs::rename(staging.path(), &self.path) { + Ok(()) => Ok(self.files_dir()), + Err(rename_error) if self.path.exists() => { + self.recover_existing_after_rename(rename_error) + } + Err(rename_error) => Err(anyhow::anyhow!(rename_error).context(format!( + "Failed to publish OCI cache entry from {} to {}", + staging.path().display(), + self.path.display() + ))), + } + } + + fn recover_existing_after_rename(&self, rename_error: io::Error) -> Result { + match Self::existing_files_dir_at(&self.path) { + Ok(Some(existing)) => Ok(existing), + Ok(None) => Err(anyhow::anyhow!( + "OCI cache publish failed for {}: {rename_error}", + self.path.display() + )), + Err(error) => Err(error).with_context(|| { + format!( + "OCI cache publish raced with an incomplete or corrupt cache entry at {}", + self.path.display() + ) + }), + } + } + + fn validate_path(entry_path: &Path) -> Result { + let files_dir = Self::files_dir_for(entry_path); + if !files_dir.is_dir() { + anyhow::bail!("missing files directory at {files_dir:?}"); + } + + if !Self::files_dir_is_non_empty(&files_dir)? { + anyhow::bail!("OCI files directory at {files_dir:?} is empty"); + } + + Ok(files_dir) + } + + fn existing_files_dir_at(entry_path: &Path) -> Result> { + if !entry_path.exists() { + return Ok(None); + } + + let files_dir = Self::validate_path(entry_path).with_context(|| { + format!( + "OCI cache entry at {} is incomplete or corrupt; remove it before retrying", + entry_path.display() + ) + })?; + + Ok(Some(files_dir)) + } +} + +pub fn repository_cache_key(repository: &str) -> String { + let mut key = String::with_capacity(repository.len()); + for character in repository.chars() { + match character { + '%' => key.push_str("%25"), + '/' => key.push_str("%2F"), + _ => key.push(character), + } + } + key +} + +pub fn repository_from_cache_key(key: &str) -> Result { + let mut repository = String::with_capacity(key.len()); + let mut characters = key.chars(); + + while let Some(character) = characters.next() { + if character != '%' { + repository.push(character); + continue; + } + + let first = characters + .next() + .ok_or_else(|| anyhow::anyhow!("Invalid OCI repository cache key '{key}'"))?; + let second = characters + .next() + .ok_or_else(|| anyhow::anyhow!("Invalid OCI repository cache key '{key}'"))?; + + match (first.to_ascii_uppercase(), second.to_ascii_uppercase()) { + ('2', '5') => repository.push('%'), + ('2', 'F') => repository.push('/'), + _ => anyhow::bail!("Invalid OCI repository cache key '{key}'"), + } + } + + Ok(repository) +} + +#[derive(Debug)] +pub struct StagingCacheEntry { + path: PathBuf, +} + +impl StagingCacheEntry { + pub fn new(cache_root: &Path) -> Self { + let path = cache_root + .join(CACHE_ROOT_DIR_NAME) + .join(TMP_DIR_NAME) + .join(Uuid::new_v4().to_string()); + Self { path } + } + + pub async fn create(&self) -> Result<()> { + let tmp_root = self.path.parent().ok_or_else(|| { + anyhow::anyhow!("OCI staging entry '{}' has no parent", self.path.display()) + })?; + tokio::fs::create_dir_all(tmp_root) + .await + .with_context(|| format!("Failed to create OCI temporary cache root {tmp_root:?}"))?; + tokio::fs::create_dir(&self.path) + .await + .with_context(|| format!("Failed to create OCI staging entry {:?}", self.path)) + } + + pub fn path(&self) -> &Path { + &self.path + } + + pub fn files_dir(&self) -> PathBuf { + CacheEntry::files_dir_for(&self.path) + } + + pub fn blob_root(&self) -> PathBuf { + self.path.join(BLOBS_DIR_NAME) + } + + pub fn cleanup(&self) { + if self.path.exists() + && let Err(error) = fs::remove_dir_all(&self.path) + { + warn!( + "Failed to remove OCI staging directory {}: {error}", + self.path.display() + ); + } + } +} + +impl Drop for StagingCacheEntry { + fn drop(&mut self) { + self.cleanup(); + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::*; + use crate::providers::oci::reference::OciReference; + + #[test] + fn test_cache_path_generation() { + let root = Path::new("/cache"); + let tagged = OciReference::parse("registry.example.com/team/model:v1") + .expect("tagged reference should parse"); + assert_eq!( + CacheEntry::path_for(root, &tagged), + PathBuf::from("/cache/oci/registry.example.com/team%2Fmodel/tags/v1") + ); + + let digest = "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"; + let by_digest = OciReference::parse(&format!("registry.example.com/team/model@{digest}")) + .expect("digest reference should parse"); + assert_eq!( + CacheEntry::path_for(root, &by_digest), + PathBuf::from( + "/cache/oci/registry.example.com/team%2Fmodel/digests/sha256-ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + ) + ); + } + + #[test] + fn test_cache_path_generation_keeps_repository_from_overlapping_layout() { + let root = Path::new("/cache"); + let nested = OciReference::parse("registry.example.com/team/model/tags/dev/files/other:v1") + .expect("nested reference should parse"); + let tagged = OciReference::parse("registry.example.com/team/model:dev") + .expect("tagged reference should parse"); + + assert_eq!( + CacheEntry::path_for(root, &nested), + PathBuf::from( + "/cache/oci/registry.example.com/team%2Fmodel%2Ftags%2Fdev%2Ffiles%2Fother/tags/v1" + ) + ); + assert!( + !CacheEntry::path_for(root, &nested) + .starts_with(CacheEntry::path_for(root, &tagged).join(FILES_DIR_NAME)) + ); + } + + #[test] + fn test_cache_entry_validation_requires_non_empty_files() { + let dir = tempfile::TempDir::new().expect("temp dir"); + let entry = dir.path().join("entry"); + let files = entry.join(FILES_DIR_NAME); + fs::create_dir_all(&files).expect("create files dir"); + + let err = + CacheEntry::validate_path(&entry).expect_err("empty cache should fail validation"); + assert!(err.to_string().contains("is empty")); + + fs::write(files.join("config.json"), b"{}").expect("write model file"); + assert_eq!( + CacheEntry::validate_path(&entry).expect("valid cache"), + files + ); + } + + #[test] + fn test_existing_incomplete_cache_rejected() { + let dir = tempfile::TempDir::new().expect("temp dir"); + let reference = OciReference::parse("registry.example.com/team/model:v1") + .expect("reference should parse"); + let entry = CacheEntry::path_for(dir.path(), &reference); + fs::create_dir_all(entry.join(FILES_DIR_NAME)).expect("create incomplete files dir"); + + let result = CacheEntry::existing_files_dir_at(&entry); + let err = result.expect_err("incomplete cache should fail"); + assert!(err.to_string().contains("incomplete or corrupt")); + } + + #[test] + fn test_publish_cleans_staging_when_existing_cache_is_corrupt() { + let dir = tempfile::TempDir::new().expect("temp dir"); + let reference = OciReference::parse("registry.example.com/team/model:v1") + .expect("reference should parse"); + let final_entry = CacheEntry::new(dir.path(), &reference); + fs::create_dir_all(final_entry.path().join(FILES_DIR_NAME)) + .expect("create incomplete final files dir"); + + let staging_entry = StagingCacheEntry::new(dir.path()); + let staging_files = staging_entry.files_dir(); + fs::create_dir_all(&staging_files).expect("create staging files dir"); + fs::write(staging_files.join("config.json"), b"{}").expect("write staging model file"); + + let err = final_entry + .publish_from(&staging_entry) + .expect_err("corrupt final cache should fail publish"); + + assert!(err.to_string().contains("incomplete or corrupt")); + assert!(!staging_entry.path().exists()); + } + + #[test] + fn test_existing_cache_entry_uses_non_empty_files_dir() { + let dir = tempfile::TempDir::new().expect("temp dir"); + let reference = OciReference::parse("registry.example.com/team/model:v1") + .expect("reference should parse"); + let entry = CacheEntry::path_for(dir.path(), &reference); + let files = entry.join(FILES_DIR_NAME); + fs::create_dir_all(&files).expect("create files dir"); + fs::write(files.join("config.json"), b"{}").expect("write model file"); + + let files_dir = CacheEntry::existing_files_dir_at(&entry) + .expect("cache lookup should succeed") + .expect("cache should exist"); + assert_eq!(files_dir, files); + } +} diff --git a/modelexpress_common/src/providers/oci/downloader.rs b/modelexpress_common/src/providers/oci/downloader.rs new file mode 100644 index 00000000..b59510ba --- /dev/null +++ b/modelexpress_common/src/providers/oci/downloader.rs @@ -0,0 +1,533 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{ + cache_entry::StagingCacheEntry, + layer_download::{LayerDownload, LayerDownloadKind, LayerDownloads}, + path::ArtifactPath, + reference::OciReference, + registry_auth, +}; +use anyhow::{Context, Result}; +use oci_client::{ + Client, + client::{ClientConfig, ClientProtocol}, + manifest::{OciDescriptor, OciImageManifest, OciManifest}, + secrets::RegistryAuth, +}; +use std::path::{Path, PathBuf}; +use tracing::info; + +const MANIFEST_FILE_NAME: &str = "manifest.json"; + +pub struct Downloader<'a> { + original_ref: &'a str, + reference: &'a OciReference, + auth: RegistryAuth, + client: Client, +} + +impl<'a> Downloader<'a> { + pub fn new(original_ref: &'a str, reference: &'a OciReference) -> Self { + Self { + original_ref, + reference, + auth: registry_auth::from_env(), + client: Self::client_for_reference(reference), + } + } + + pub async fn download_to_staging( + &self, + staging_entry: &StagingCacheEntry, + ignore_weights: bool, + ) -> Result<()> { + let staging_files = staging_entry.files_dir(); + tokio::fs::create_dir_all(&staging_files) + .await + .with_context(|| format!("Failed to create OCI staging directory {staging_files:?}"))?; + + let manifest = self.pull_image_manifest().await?; + if manifest.layers.is_empty() { + anyhow::bail!( + "OCI artifact '{}' contains no layer descriptors", + self.original_ref + ); + } + + let downloads = LayerDownloads::from_layers(&manifest.layers, ignore_weights)?; + self.download_layers(staging_entry, &staging_files, downloads.as_slice()) + .await?; + self.download_manifest_json(&manifest, &staging_files) + .await?; + + Ok(()) + } + + async fn pull_image_manifest(&self) -> Result { + let (manifest, _) = self + .client + .pull_manifest(self.reference.as_client_reference(), &self.auth) + .await + .with_context(|| format!("Failed to pull OCI manifest for '{}'", self.original_ref))?; + Self::image_manifest(manifest) + } + + async fn download_manifest_json( + &self, + manifest: &OciImageManifest, + staging_files: &Path, + ) -> Result<()> { + let output_path = staging_files.join(MANIFEST_FILE_NAME); + // The model artifact wins if it already provided manifest.json as a + // layer file or archive member; otherwise expose the OCI config blob as + // manifest.json so gbuild-produced models can carry model config there. + if tokio::fs::try_exists(&output_path) + .await + .with_context(|| format!("Failed to inspect OCI manifest.json {output_path:?}"))? + { + return Ok(()); + } + + self.pull_blob_to_file(&manifest.config, &output_path, "OCI manifest.json") + .await + .with_context(|| { + format!( + "Failed to download OCI config blob {} as manifest.json", + manifest.config.digest + ) + }) + } + + async fn download_layers( + &self, + staging_entry: &StagingCacheEntry, + staging_files: &Path, + downloads: &[LayerDownload], + ) -> Result { + let mut file_count = 0usize; + let blob_root = staging_entry.blob_root(); + + for download in downloads { + match &download.kind { + LayerDownloadKind::Raw { path } => { + self.download_raw_blob(download, staging_files, path) + .await?; + info!( + "Downloaded OCI blob {} for file '{}'", + download.descriptor.digest, path + ); + file_count = file_count.saturating_add(1); + } + LayerDownloadKind::Archive { format } => { + let path = self.download_archive_blob(download, &blob_root).await?; + // Archive member paths define the artifact layout. Layer title + // annotations are labels/debug metadata unless a manifest schema + // explicitly assigns placement semantics. + let extracted_files = + format.extract_blob(&path, staging_files).with_context(|| { + format!( + "Failed to extract OCI archive blob {}", + download.descriptor.digest + ) + })?; + + tokio::fs::remove_file(&path).await.with_context(|| { + format!("Failed to remove OCI temporary blob file {path:?}") + })?; + + file_count = file_count.saturating_add(extracted_files.len()); + } + } + } + + if tokio::fs::try_exists(&blob_root).await.with_context(|| { + format!("Failed to inspect OCI temporary blob directory {blob_root:?}") + })? { + tokio::fs::remove_dir_all(&blob_root) + .await + .with_context(|| { + format!("Failed to remove OCI temporary blob directory {blob_root:?}") + })?; + } + + Ok(file_count) + } + + async fn download_raw_blob( + &self, + download: &LayerDownload, + staging_files: &Path, + relative_path: &ArtifactPath, + ) -> Result<()> { + let output_path = staging_files.join(relative_path.as_path()); + self.pull_blob_to_file(&download.descriptor, &output_path, "OCI output file") + .await + .with_context(|| { + format!( + "Failed to download OCI blob {} for file '{}'", + download.descriptor.digest, relative_path + ) + })?; + + Ok(()) + } + + async fn download_archive_blob( + &self, + download: &LayerDownload, + blob_root: &Path, + ) -> Result { + let path = blob_root.join(download.descriptor.digest.replace(':', "-")); + self.pull_blob_to_file(&download.descriptor, &path, "OCI archive blob") + .await?; + + Ok(path) + } + + async fn pull_blob_to_file( + &self, + descriptor: &OciDescriptor, + output_path: &Path, + description: &str, + ) -> Result<()> { + if let Some(parent) = output_path.parent() { + tokio::fs::create_dir_all(parent) + .await + .with_context(|| format!("Failed to create {description} directory {parent:?}"))?; + } + + let mut output = tokio::fs::OpenOptions::new() + .write(true) + .create_new(true) + .open(output_path) + .await + .with_context(|| format!("Failed to create {description} {output_path:?}"))?; + + info!( + "Downloading OCI blob {} to {}", + descriptor.digest, + output_path.display() + ); + + self.client + .pull_blob( + self.reference.as_client_reference(), + descriptor, + &mut output, + ) + .await + .with_context(|| { + format!( + "Failed to download OCI blob {} to {}", + descriptor.digest, + output_path.display() + ) + })?; + output + .sync_all() + .await + .with_context(|| format!("Failed to sync {description} {output_path:?}"))?; + + Ok(()) + } + + fn client_for_reference(reference: &OciReference) -> Client { + let mut config = ClientConfig::default(); + let registry = reference.registry_endpoint(); + + if Self::is_loopback_registry(registry) { + config.protocol = ClientProtocol::HttpsExcept(vec![registry.to_string()]); + } + + Client::new(config) + } + + fn is_loopback_registry(registry: &str) -> bool { + let host = registry + .split_once(':') + .map_or(registry, |(host, _)| host) + .trim_matches(['[', ']']); + + host == "localhost" || host == "127.0.0.1" || host == "::1" + } + + fn image_manifest(manifest: OciManifest) -> Result { + match manifest { + OciManifest::Image(manifest) => Ok(manifest), + OciManifest::ImageIndex(_) => { + anyhow::bail!( + "OCI image index manifests are not supported for model artifacts; use an OCI image manifest" + ); + } + } + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::super::{ + OciProvider, + cache_entry::{CACHE_ROOT_DIR_NAME, TMP_DIR_NAME}, + layer_download::TITLE_ANNOTATION, + }; + use super::MANIFEST_FILE_NAME; + use crate::providers::ModelProviderTrait; + use serde_json::json; + use sha2::{Digest, Sha256}; + use std::fs; + use tempfile::TempDir; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn digest_bytes(bytes: &[u8]) -> String { + format!("sha256:{:x}", Sha256::digest(bytes)) + } + + fn tar_bytes(entries: &[(&str, &[u8])]) -> Vec { + let mut bytes = Vec::new(); + { + let mut builder = tar::Builder::new(&mut bytes); + for (path, contents) in entries { + let mut header = tar::Header::new_gnu(); + header.set_size(contents.len() as u64); + header.set_mode(0o644); + header.set_cksum(); + builder + .append_data(&mut header, path, *contents) + .expect("append tar entry"); + } + builder.finish().expect("finish tar"); + } + bytes + } + + #[tokio::test] + async fn test_mock_registry_download_publishes_final_cache_entry() { + let cache_dir = TempDir::new().expect("temp cache"); + let server = MockServer::start().await; + let registry = server + .uri() + .strip_prefix("http://") + .expect("wiremock should use http") + .to_string(); + let repo = "team/model"; + let config = b"{}"; + let artifact_manifest = br#"{"artifact":true}"#; + let tokenizer = b"{\"tokenizer\":true}"; + let weights = b"weights"; + let config_digest = digest_bytes(config); + let artifact_manifest_digest = digest_bytes(artifact_manifest); + let tokenizer_digest = digest_bytes(tokenizer); + let weights_digest = digest_bytes(weights); + + let manifest = json!({ + "schemaVersion": 2, + "mediaType": "application/vnd.oci.image.manifest.v1+json", + "config": { + "mediaType": "application/vnd.oci.image.config.v1+json", + "size": 2, + "digest": digest_bytes(b"{}") + }, + "layers": [ + { + "mediaType": "application/octet-stream", + "size": config.len(), + "digest": config_digest, + "annotations": { TITLE_ANNOTATION: "config.json" } + }, + { + "mediaType": "application/octet-stream", + "size": tokenizer.len(), + "digest": tokenizer_digest, + "annotations": { TITLE_ANNOTATION: "tokenizer.json" } + }, + { + "mediaType": "application/octet-stream", + "size": artifact_manifest.len(), + "digest": artifact_manifest_digest, + "annotations": { TITLE_ANNOTATION: "manifest.json" } + }, + { + "mediaType": "application/octet-stream", + "size": weights.len(), + "digest": weights_digest, + "annotations": { TITLE_ANNOTATION: "model.safetensors" } + } + ] + }); + + Mock::given(method("GET")) + .and(path(format!("/v2/{repo}/manifests/v1"))) + .respond_with(ResponseTemplate::new(200).set_body_json(manifest)) + .mount(&server) + .await; + + for (digest, body) in [ + (config_digest.as_str(), config.as_slice()), + ( + artifact_manifest_digest.as_str(), + artifact_manifest.as_slice(), + ), + (tokenizer_digest.as_str(), tokenizer.as_slice()), + (weights_digest.as_str(), weights.as_slice()), + ] { + Mock::given(method("GET")) + .and(path(format!("/v2/{repo}/blobs/{digest}"))) + .respond_with(ResponseTemplate::new(200).set_body_bytes(body.to_vec())) + .mount(&server) + .await; + } + + let model_ref = format!("{registry}/{repo}:v1"); + let path = OciProvider + .download_model(&model_ref, Some(cache_dir.path().to_path_buf()), true) + .await + .expect("download should succeed"); + + assert!(path.join("config.json").is_file()); + assert!(path.join("tokenizer.json").is_file()); + assert_eq!( + fs::read(path.join(MANIFEST_FILE_NAME)).expect("read artifact manifest.json"), + artifact_manifest + ); + assert!(!path.join("model.safetensors").exists()); + assert!( + !path + .parent() + .expect("files directory has a cache entry parent") + .join("metadata") + .exists() + ); + + let oci_root = cache_dir.path().join(CACHE_ROOT_DIR_NAME); + let tmp_root = oci_root.join(TMP_DIR_NAME); + assert!(!tmp_root.exists() || fs::read_dir(&tmp_root).expect("read tmp").next().is_none()); + } + + #[tokio::test] + async fn test_mock_registry_download_extracts_archive_layer() { + let cache_dir = TempDir::new().expect("temp cache"); + let server = MockServer::start().await; + let registry = server + .uri() + .strip_prefix("http://") + .expect("wiremock should use http") + .to_string(); + let repo = "team/archive-model"; + let manifest_json = br#"{"build":{"id":"archive-model"}}"#; + let archive = tar_bytes(&[ + ("config.json", b"{}"), + ("model.safetensors", b"weights"), + ("README.md", b"readme"), + ]); + let manifest_digest = digest_bytes(manifest_json); + let archive_digest = digest_bytes(&archive); + + let manifest = json!({ + "schemaVersion": 2, + "mediaType": "application/vnd.oci.image.manifest.v1+json", + "config": { + "mediaType": "application/vnd.kitops.modelkit.config.v1+json", + "size": manifest_json.len(), + "digest": manifest_digest + }, + "layers": [ + { + "mediaType": "application/vnd.kitops.modelkit.model.v1.tar", + "size": archive.len(), + "digest": archive_digest, + "annotations": { TITLE_ANNOTATION: "part-0" } + } + ] + }); + + Mock::given(method("GET")) + .and(path(format!("/v2/{repo}/manifests/v1"))) + .respond_with(ResponseTemplate::new(200).set_body_json(manifest)) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path(format!("/v2/{repo}/blobs/{manifest_digest}"))) + .respond_with(ResponseTemplate::new(200).set_body_bytes(manifest_json)) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path(format!("/v2/{repo}/blobs/{archive_digest}"))) + .respond_with(ResponseTemplate::new(200).set_body_bytes(archive)) + .mount(&server) + .await; + + let model_ref = format!("{registry}/{repo}:v1"); + let path = OciProvider + .download_model(&model_ref, Some(cache_dir.path().to_path_buf()), false) + .await + .expect("download should succeed"); + + assert_eq!( + fs::read(path.join(MANIFEST_FILE_NAME)).expect("read artifact manifest.json"), + manifest_json + ); + assert!(path.join("config.json").is_file()); + assert!(path.join("model.safetensors").is_file()); + assert!(!path.join("part-0/config.json").exists()); + assert!(!path.join("README.md").exists()); + } + + #[tokio::test] + async fn test_mock_registry_downloads_manifest_after_filtering_layers() { + let cache_dir = TempDir::new().expect("temp cache"); + let server = MockServer::start().await; + let registry = server + .uri() + .strip_prefix("http://") + .expect("wiremock should use http") + .to_string(); + let repo = "team/archive-model"; + let manifest_json = br#"{"build":{"id":"manifest-only"}}"#; + let manifest_digest = digest_bytes(manifest_json); + + let manifest = json!({ + "schemaVersion": 2, + "mediaType": "application/vnd.oci.image.manifest.v1+json", + "config": { + "mediaType": "application/vnd.kitops.modelkit.config.v1+json", + "size": manifest_json.len(), + "digest": manifest_digest + }, + "layers": [ + { + "mediaType": "application/vnd.kitops.modelkit.model.v1.tar", + "size": 7, + "digest": digest_bytes(b"archive") + } + ] + }); + + Mock::given(method("GET")) + .and(path(format!("/v2/{repo}/manifests/v1"))) + .respond_with(ResponseTemplate::new(200).set_body_json(manifest)) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path(format!("/v2/{repo}/blobs/{manifest_digest}"))) + .respond_with(ResponseTemplate::new(200).set_body_bytes(manifest_json)) + .mount(&server) + .await; + + let model_ref = format!("{registry}/{repo}:v1"); + let path = OciProvider + .download_model(&model_ref, Some(cache_dir.path().to_path_buf()), true) + .await + .expect("manifest-only download should publish manifest"); + + assert_eq!( + fs::read(path.join(MANIFEST_FILE_NAME)).expect("read artifact manifest.json"), + manifest_json + ); + } +} diff --git a/modelexpress_common/src/providers/oci/layer_download.rs b/modelexpress_common/src/providers/oci/layer_download.rs new file mode 100644 index 00000000..c281eb23 --- /dev/null +++ b/modelexpress_common/src/providers/oci/layer_download.rs @@ -0,0 +1,262 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{archive_format::ArchiveFormat, path::ArtifactPath}; +use anyhow::{Context, Result}; +use oci_client::manifest::OciDescriptor; +use std::collections::HashSet; + +pub const TITLE_ANNOTATION: &str = "org.opencontainers.image.title"; +const CNCF_FILEPATH_ANNOTATION: &str = "org.cncf.model.filepath"; + +#[derive(Debug, Clone)] +pub struct LayerDownload { + pub descriptor: OciDescriptor, + pub kind: LayerDownloadKind, +} + +#[derive(Debug, Clone)] +pub enum LayerDownloadKind { + Raw { path: ArtifactPath }, + Archive { format: ArchiveFormat }, +} + +pub struct LayerDownloads { + downloads: Vec, +} + +impl LayerDownloads { + pub fn from_layers(layers: &[OciDescriptor], ignore_weights: bool) -> Result { + let mut seen_paths = HashSet::new(); + let mut downloads = Vec::new(); + + for layer in layers { + if ignore_weights && ArchiveFormat::is_archive_media_type(&layer.media_type) { + tracing::debug!( + "Skipping OCI archive layer {} because ignore_weights=true", + layer.digest + ); + continue; + } + + if let Some(format) = ArchiveFormat::from_media_type(&layer.media_type)? { + downloads.push(LayerDownload::archive(layer, format)); + continue; + } + + let Some(path) = Self::raw_layer_path(layer, ignore_weights)? else { + continue; + }; + + if !seen_paths.insert(path.clone()) { + anyhow::bail!("Duplicate OCI artifact file path '{path}'"); + } + + downloads.push(LayerDownload::raw(layer, path)); + } + + Ok(Self { downloads }) + } + + pub fn as_slice(&self) -> &[LayerDownload] { + &self.downloads + } + + fn raw_layer_path(layer: &OciDescriptor, ignore_weights: bool) -> Result> { + let title = LayerDownload::output_path_annotation(layer).with_context(|| { + format!( + "OCI layer {} is missing required '{TITLE_ANNOTATION}' or '{CNCF_FILEPATH_ANNOTATION}' annotation", + layer.digest + ) + })?; + let path = ArtifactPath::from_title(title)?; + + if path.is_skipped(ignore_weights) { + tracing::debug!("Skipping OCI artifact file: {path}"); + return Ok(None); + } + + Ok(Some(path)) + } +} + +impl LayerDownload { + fn archive(layer: &OciDescriptor, format: ArchiveFormat) -> Self { + Self { + descriptor: layer.clone(), + kind: LayerDownloadKind::Archive { format }, + } + } + + fn raw(layer: &OciDescriptor, path: ArtifactPath) -> Self { + Self { + descriptor: layer.clone(), + kind: LayerDownloadKind::Raw { path }, + } + } + + fn output_path_annotation(layer: &OciDescriptor) -> Option<&str> { + layer.annotations.as_ref().and_then(|annotations| { + annotations + .get(TITLE_ANNOTATION) + .or_else(|| annotations.get(CNCF_FILEPATH_ANNOTATION)) + .map(String::as_str) + }) + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::super::archive_format::ArchiveFormat; + use super::*; + use sha2::{Digest, Sha256}; + use std::collections::BTreeMap; + + fn digest_bytes(bytes: &[u8]) -> String { + format!("sha256:{:x}", Sha256::digest(bytes)) + } + + fn descriptor(title: Option<&str>, bytes: &[u8]) -> OciDescriptor { + descriptor_with_media_type("application/octet-stream", title, bytes) + } + + fn descriptor_with_media_type( + media_type: &str, + title: Option<&str>, + bytes: &[u8], + ) -> OciDescriptor { + let annotations = + title.map(|title| BTreeMap::from([(TITLE_ANNOTATION.to_string(), title.to_string())])); + + OciDescriptor { + media_type: media_type.to_string(), + digest: digest_bytes(bytes), + size: bytes.len() as i64, + urls: None, + annotations, + } + } + + fn raw_download_path(download: &LayerDownload) -> &str { + match &download.kind { + LayerDownloadKind::Raw { path } => path.as_str(), + LayerDownloadKind::Archive { .. } => panic!("expected raw download"), + } + } + + #[test] + fn test_prepare_layers_requires_title_and_rejects_duplicates() { + let missing_title = vec![descriptor(None, b"config")]; + let Err(err) = LayerDownloads::from_layers(&missing_title, false) else { + panic!("missing title should fail"); + }; + assert!(err.to_string().contains(TITLE_ANNOTATION)); + + let duplicate = vec![ + descriptor(Some("config.json"), b"one"), + descriptor(Some("config.json"), b"two"), + ]; + let Err(err) = LayerDownloads::from_layers(&duplicate, false) else { + panic!("duplicate path should fail"); + }; + assert!(err.to_string().contains("Duplicate OCI artifact file path")); + } + + #[test] + fn test_prepare_layers_applies_ignore_rules() { + let layers = vec![ + descriptor(Some("README.md"), b"readme"), + descriptor(Some("README.md"), b"duplicate ignored readme"), + descriptor(Some(".gitattributes"), b"dotfile"), + descriptor(Some("diagram.png"), b"image"), + descriptor(Some("model.safetensors"), b"weights"), + descriptor(Some("config.json"), b"config"), + ]; + + let without_weights = + LayerDownloads::from_layers(&layers, true).expect("ignore_weights should succeed"); + assert_eq!(without_weights.as_slice().len(), 1); + assert_eq!( + raw_download_path(&without_weights.as_slice()[0]), + "config.json" + ); + + let with_weights = + LayerDownloads::from_layers(&layers, false).expect("download selection should succeed"); + assert_eq!(with_weights.as_slice().len(), 2); + assert_eq!( + raw_download_path(&with_weights.as_slice()[0]), + "model.safetensors" + ); + assert_eq!( + raw_download_path(&with_weights.as_slice()[1]), + "config.json" + ); + } + + #[test] + fn test_prepare_layers_accepts_archive_layers_without_title() { + let archive = descriptor_with_media_type( + "application/vnd.kitops.modelkit.model.v1.tar", + None, + b"archive", + ); + let downloads = + LayerDownloads::from_layers(&[archive], false).expect("archive should select"); + assert_eq!(downloads.as_slice().len(), 1); + match &downloads.as_slice()[0].kind { + LayerDownloadKind::Archive { format } => { + assert_eq!(*format, ArchiveFormat::Tar); + } + LayerDownloadKind::Raw { .. } => panic!("expected archive download"), + } + + let archive = descriptor_with_media_type( + "application/vnd.oci.image.layer.v1.tar+zstd", + Some("part-0"), + b"archive", + ); + let downloads = + LayerDownloads::from_layers(&[archive], false).expect("archive should select"); + match &downloads.as_slice()[0].kind { + LayerDownloadKind::Archive { format } => { + assert_eq!(*format, ArchiveFormat::TarZstd); + } + LayerDownloadKind::Raw { .. } => panic!("expected archive download"), + } + } + + #[test] + fn test_prepare_layers_skips_archives_when_ignoring_weights() { + let layers = vec![ + descriptor_with_media_type( + "application/vnd.kitops.modelkit.model.v1.tar", + None, + b"archive", + ), + descriptor_with_media_type( + "application/vnd.oci.image.layer.v1.tar+gzip", + None, + b"unsupported archive", + ), + descriptor(Some("config.json"), b"config"), + ]; + + let downloads = + LayerDownloads::from_layers(&layers, true).expect("download selection should succeed"); + + assert_eq!(downloads.as_slice().len(), 1); + assert_eq!(raw_download_path(&downloads.as_slice()[0]), "config.json"); + + let unsupported = descriptor_with_media_type( + "application/vnd.oci.image.layer.v1.tar+gzip", + None, + b"unsupported archive", + ); + let Err(err) = LayerDownloads::from_layers(&[unsupported], false) else { + panic!("unsupported archive should fail without ignore_weights"); + }; + assert!(err.to_string().contains("not supported")); + } +} diff --git a/modelexpress_common/src/providers/oci/path.rs b/modelexpress_common/src/providers/oci/path.rs new file mode 100644 index 00000000..f6d0ea7f --- /dev/null +++ b/modelexpress_common/src/providers/oci/path.rs @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::OciProvider; +use crate::providers::ModelProviderTrait; +use anyhow::Result; +use std::fmt; +use std::path::{Component, Path, PathBuf}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ArtifactPath { + path: PathBuf, + raw: String, +} + +impl ArtifactPath { + pub fn from_title(title: &str) -> Result { + if title.is_empty() { + anyhow::bail!("OCI layer title annotation must not be empty"); + } + + Self::from_relative_path(Path::new(title), &format!("OCI layer title '{title}'")) + } + + pub fn from_relative_path(path: &Path, description: &str) -> Result { + let mut relative_path = PathBuf::new(); + let mut parts = Vec::new(); + + for component in path.components() { + match component { + Component::Normal(part) => { + let Some(part) = part.to_str() else { + anyhow::bail!("{description} contains non-UTF-8 path data"); + }; + if part.contains('\\') { + anyhow::bail!("{description} must use forward-slash relative paths"); + } + relative_path.push(part); + parts.push(part.to_owned()); + } + Component::CurDir => { + anyhow::bail!("{description} must not contain '.' path components"); + } + Component::ParentDir => { + anyhow::bail!("{description} must not contain '..' path components"); + } + Component::RootDir | Component::Prefix(_) => { + anyhow::bail!("{description} must be a relative path"); + } + } + } + + if parts.is_empty() { + anyhow::bail!("{description} must name a path"); + } + + Ok(Self { + path: relative_path, + raw: parts.join("/"), + }) + } + + pub fn as_path(&self) -> &Path { + &self.path + } + + pub fn as_str(&self) -> &str { + &self.raw + } + + pub fn is_skipped(&self, ignore_weights: bool) -> bool { + OciProvider::is_ignored(self.as_str()) + || OciProvider::is_image(self.as_path()) + || (ignore_weights && OciProvider::is_weight_file(self.as_str())) + } +} + +impl fmt::Display for ArtifactPath { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str(self.as_str()) + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::*; + + #[test] + fn test_validate_title_path_rejects_unsafe_paths() { + for title in [ + "", + ".", + "..", + "../model.bin", + "/model.bin", + "a/../model.bin", + ] { + assert!( + ArtifactPath::from_title(title).is_err(), + "title should be rejected: {title}" + ); + } + + let artifact_path = + ArtifactPath::from_title("nested/config.json").expect("safe path should pass"); + assert_eq!(artifact_path.as_path(), Path::new("nested/config.json")); + assert_eq!(artifact_path.as_str(), "nested/config.json"); + } +} diff --git a/modelexpress_common/src/providers/oci/provider_cache.rs b/modelexpress_common/src/providers/oci/provider_cache.rs new file mode 100644 index 00000000..76adb772 --- /dev/null +++ b/modelexpress_common/src/providers/oci/provider_cache.rs @@ -0,0 +1,260 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{ + cache_entry::{ + CACHE_ROOT_DIR_NAME, CacheEntry, FILES_DIR_NAME, TMP_DIR_NAME, repository_from_cache_key, + }, + reference::OciReference, +}; +use crate::{ + cache::{ModelInfo, ProviderCache, directory_size}, + models::ModelProvider, +}; +use anyhow::{Context, Result}; +use std::{ + fs, + path::{Path, PathBuf}, +}; +use tracing::{info, warn}; + +pub(crate) struct OciProviderCache; + +struct CacheWalker<'a> { + root: &'a Path, + models: Vec, +} + +impl<'a> CacheWalker<'a> { + fn collect(root: &'a Path) -> Result> { + let mut walker = Self { + root, + models: Vec::new(), + }; + walker.visit(root)?; + Ok(walker.models) + } + + fn visit(&mut self, current: &Path) -> Result<()> { + if current + .file_name() + .is_some_and(|name| name == FILES_DIR_NAME) + && self.push_model(current)? + { + return Ok(()); + } + + for entry in fs::read_dir(current).with_context(|| format!("Failed to read {current:?}"))? { + let entry = entry.with_context(|| format!("Failed to read entry in {current:?}"))?; + let path = entry.path(); + let file_type = entry + .file_type() + .with_context(|| format!("Failed to inspect {path:?}"))?; + + if file_type.is_dir() && !self.is_staging_path(&path) { + self.visit(&path)?; + } + } + + Ok(()) + } + + fn push_model(&mut self, files_dir: &Path) -> Result { + let Some(name) = self.model_name_from_files_dir(files_dir)? else { + return Ok(false); + }; + + if CacheEntry::files_dir_is_non_empty(files_dir)? { + self.models.push(ModelInfo { + provider: ModelProvider::Oci, + name, + size: directory_size(files_dir)?, + path: files_dir.to_path_buf(), + }); + } + Ok(true) + } + + fn is_staging_path(&self, path: &Path) -> bool { + path.strip_prefix(self.root) + .ok() + .and_then(|relative| relative.components().next()) + .is_some_and(|component| component.as_os_str() == TMP_DIR_NAME) + } + + fn model_name_from_files_dir(&self, files_dir: &Path) -> Result> { + let relative = files_dir.strip_prefix(self.root).with_context(|| { + format!("Failed to strip prefix {:?} from {files_dir:?}", self.root) + })?; + let parts = Self::path_parts(relative)?; + + if parts.len() != 5 || parts.last().is_none_or(|part| part != FILES_DIR_NAME) { + return Ok(None); + } + + let registry = &parts[0]; + let repository = repository_from_cache_key(&parts[1])?; + let reference = &parts[3]; + + match parts[2].as_str() { + "tags" => Ok(Some(format!("{registry}/{repository}:{reference}"))), + "digests" => { + let (algorithm, digest) = reference.rsplit_once('-').ok_or_else(|| { + anyhow::anyhow!("Invalid OCI digest cache entry {reference:?}") + })?; + Ok(Some(format!( + "{registry}/{repository}@{algorithm}:{digest}" + ))) + } + _ => Ok(None), + } + } + + fn path_parts(path: &Path) -> Result> { + path.components() + .map(|component| { + component + .as_os_str() + .to_str() + .map(str::to_owned) + .ok_or_else(|| anyhow::anyhow!("OCI cache path contains non-UTF-8 data")) + }) + .collect() + } +} + +impl ProviderCache for OciProviderCache { + fn clear_model(&self, cache_root: &Path, model_name: &str) -> Result<()> { + let reference = OciReference::parse(model_name)?; + let entry = CacheEntry::new(cache_root, &reference); + + if entry.path().exists() { + fs::remove_dir_all(entry.path()) + .with_context(|| format!("Failed to remove OCI model cache {:?}", entry.path()))?; + info!("Cleared OCI model: {model_name}"); + } else { + warn!("OCI model '{model_name}' not found in cache"); + } + + Ok(()) + } + + fn resolve_model_path( + &self, + cache_root: &Path, + model_name: &str, + _revision: Option<&str>, + ) -> Result { + let reference = OciReference::parse(model_name)?; + // This is a deterministic destination path, not an existing-cache check. + // In no-shared-storage mode the client streams files directly here, so an + // interrupted transfer can leave a non-empty partial directory. Keep this + // method as a provider-specific path mapper; direct OCI downloads still use + // staging plus rename before publishing a final cache entry. + Ok(CacheEntry::new(cache_root, &reference).files_dir()) + } + + fn list_models(&self, cache_root: &Path) -> Result> { + let root = cache_root.join(CACHE_ROOT_DIR_NAME); + if !root.exists() { + return Ok(Vec::new()); + } + + CacheWalker::collect(&root) + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::*; + use crate::cache::ProviderCache; + use crate::models::ModelProvider; + use crate::providers::oci::{cache_entry::FILES_DIR_NAME, reference::OciReference}; + + #[test] + fn test_model_path_returns_layout_without_validating_cache() { + let dir = tempfile::TempDir::new().expect("temp dir"); + let path = OciProviderCache + .resolve_model_path(dir.path(), "registry.example.com/team/model:v1", None) + .expect("model path"); + + assert_eq!( + path, + dir.path() + .join("oci/registry.example.com/team%2Fmodel/tags/v1/files") + ); + } + + #[test] + fn test_cache_list_and_clear() { + let dir = tempfile::TempDir::new().expect("temp dir"); + let reference = OciReference::parse("registry.example.com/team/model:v1") + .expect("reference should parse"); + let entry = CacheEntry::path_for(dir.path(), &reference); + let files = entry.join(FILES_DIR_NAME); + fs::create_dir_all(&files).expect("create files dir"); + fs::write(files.join("config.json"), b"{}").expect("write model file"); + + let cache = OciProviderCache; + let models = cache.list_models(dir.path()).expect("list models"); + assert_eq!(models.len(), 1); + assert_eq!(models[0].provider, ModelProvider::Oci); + assert_eq!(models[0].name, "registry.example.com/team/model:v1"); + assert_eq!(models[0].path, files); + + cache + .clear_model(dir.path(), "registry.example.com/team/model:v1") + .expect("clear model"); + assert!(!entry.exists()); + } + + #[test] + fn test_cache_list_allows_files_in_repository_path() { + let dir = tempfile::TempDir::new().expect("temp dir"); + let reference = OciReference::parse("registry.example.com/team/files/model:v1") + .expect("reference should parse"); + let entry = CacheEntry::path_for(dir.path(), &reference); + let files = entry.join(FILES_DIR_NAME); + fs::create_dir_all(&files).expect("create files dir"); + fs::write(files.join("config.json"), b"{}").expect("write model file"); + + let models = OciProviderCache + .list_models(dir.path()) + .expect("list models"); + + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "registry.example.com/team/files/model:v1"); + assert_eq!(models[0].path, files); + } + + #[test] + fn test_cache_list_keeps_repository_from_overlapping_layout() { + let dir = tempfile::TempDir::new().expect("temp dir"); + let nested = OciReference::parse("registry.example.com/team/model/tags/dev/files/other:v1") + .expect("nested reference should parse"); + let nested_entry = CacheEntry::path_for(dir.path(), &nested); + let nested_files = nested_entry.join(FILES_DIR_NAME); + fs::create_dir_all(&nested_files).expect("create nested files dir"); + fs::write(nested_files.join("config.json"), b"{}").expect("write model file"); + + let alias = OciReference::parse("registry.example.com/team/model:dev") + .expect("alias reference should parse"); + assert!( + !CacheEntry::path_for(dir.path(), &alias) + .join(FILES_DIR_NAME) + .exists() + ); + + let models = OciProviderCache + .list_models(dir.path()) + .expect("list models"); + + assert_eq!(models.len(), 1); + assert_eq!( + models[0].name, + "registry.example.com/team/model/tags/dev/files/other:v1" + ); + assert_eq!(models[0].path, nested_files); + } +} diff --git a/modelexpress_common/src/providers/oci/reference.rs b/modelexpress_common/src/providers/oci/reference.rs new file mode 100644 index 00000000..b46e3c93 --- /dev/null +++ b/modelexpress_common/src/providers/oci/reference.rs @@ -0,0 +1,142 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use anyhow::{Context, Result}; +use oci_client::Reference; +use std::fmt; + +const OCI_SCHEME: &str = "oci://"; + +#[derive(Debug, Clone)] +pub struct OciReference { + inner: Reference, +} + +impl OciReference { + pub fn parse(model_name: &str) -> Result { + let reference = model_name.strip_prefix(OCI_SCHEME).unwrap_or(model_name); + + if !Self::has_explicit_registry(reference) { + anyhow::bail!( + "OCI reference '{model_name}' must be registry-qualified, for example registry.example.com/repo/model:tag" + ); + } + + if !Self::has_explicit_tag_or_digest(reference) { + anyhow::bail!("OCI reference '{model_name}' must include an explicit tag or digest"); + } + + let inner = reference + .parse::() + .with_context(|| format!("Failed to parse OCI reference '{model_name}'"))?; + Ok(Self { inner }) + } + + pub fn as_client_reference(&self) -> &Reference { + &self.inner + } + + pub fn registry(&self) -> &str { + self.inner.registry() + } + + pub fn repository(&self) -> &str { + self.inner.repository() + } + + pub fn tag(&self) -> Option<&str> { + self.inner.tag() + } + + pub fn digest(&self) -> Option<&str> { + self.inner.digest() + } + + pub fn registry_endpoint(&self) -> &str { + self.inner.resolve_registry() + } + + pub fn canonical_name(&self) -> String { + match self.digest() { + Some(digest) => format!("{}/{}@{}", self.registry(), self.repository(), digest), + None => self.to_string(), + } + } + + fn has_explicit_tag_or_digest(reference: &str) -> bool { + if reference.contains('@') { + return true; + } + + let last_slash = reference.rfind('/'); + let last_colon = reference.rfind(':'); + + match (last_slash, last_colon) { + (Some(slash), Some(colon)) => colon > slash, + (None, Some(_)) => true, + _ => false, + } + } + + fn has_explicit_registry(reference: &str) -> bool { + let Some((registry, _)) = reference.split_once('/') else { + return false; + }; + + registry.contains('.') || registry.contains(':') || registry == "localhost" + } +} + +impl fmt::Display for OciReference { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "{}", self.inner) + } +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::*; + + #[test] + fn test_parse_oci_reference_accepts_scheme_tag_and_digest() { + let tagged = OciReference::parse("oci://registry.example.com/team/model:v1") + .expect("tagged reference should parse"); + assert_eq!(tagged.registry(), "registry.example.com"); + assert_eq!(tagged.repository(), "team/model"); + assert_eq!(tagged.tag(), Some("v1")); + assert_eq!(tagged.to_string(), "registry.example.com/team/model:v1"); + + let digest = "sha256:ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"; + let by_digest = OciReference::parse(&format!("registry.example.com/team/model@{digest}")) + .expect("digest reference should parse"); + assert_eq!(by_digest.digest(), Some(digest)); + + let tagged_digest = + OciReference::parse(&format!("registry.example.com/team/model:v1@{digest}")) + .expect("tagged digest reference should parse"); + assert_eq!( + tagged_digest.canonical_name(), + format!("registry.example.com/team/model@{digest}") + ); + } + + #[test] + fn test_parse_oci_reference_rejects_missing_explicit_ref_or_registry() { + let missing_ref = OciReference::parse("registry.example.com/team/model") + .expect_err("missing tag or digest should fail"); + assert!( + missing_ref + .to_string() + .contains("must include an explicit tag or digest") + ); + + let missing_registry = + OciReference::parse("team/model:v1").expect_err("missing registry should fail"); + assert!( + missing_registry + .to_string() + .contains("must be registry-qualified") + ); + } +} diff --git a/modelexpress_common/src/providers/oci/registry_auth.rs b/modelexpress_common/src/providers/oci/registry_auth.rs new file mode 100644 index 00000000..e8bfabe3 --- /dev/null +++ b/modelexpress_common/src/providers/oci/registry_auth.rs @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use oci_client::secrets::RegistryAuth; +use std::env; + +const OCI_BEARER_TOKEN_ENV_VAR: &str = "MODEL_EXPRESS_OCI_BEARER_TOKEN"; +const OCI_USERNAME_ENV_VAR: &str = "MODEL_EXPRESS_OCI_USERNAME"; +const OCI_PASSWORD_ENV_VAR: &str = "MODEL_EXPRESS_OCI_PASSWORD"; +const OCI_TOKEN_ENV_VAR: &str = "MODEL_EXPRESS_OCI_TOKEN"; + +fn env_non_empty(key: &str) -> Option { + env::var(key).ok().filter(|value| !value.is_empty()) +} + +pub fn from_env() -> RegistryAuth { + if let Some(token) = env_non_empty(OCI_BEARER_TOKEN_ENV_VAR) { + return RegistryAuth::Bearer(token); + } + + if let Some(username) = env_non_empty(OCI_USERNAME_ENV_VAR) { + if let Some(password) = env_non_empty(OCI_PASSWORD_ENV_VAR) { + return RegistryAuth::Basic(username, password); + } + + if let Some(token) = env_non_empty(OCI_TOKEN_ENV_VAR) { + return RegistryAuth::Basic(username, token); + } + } + + RegistryAuth::Anonymous +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +mod tests { + use super::*; + use crate::test_support::{EnvVarGuard, acquire_env_mutex}; + + #[test] + fn test_auth_precedence() { + let env_lock = acquire_env_mutex(); + let _bearer = EnvVarGuard::set(&env_lock, OCI_BEARER_TOKEN_ENV_VAR, "bearer"); + let _username = EnvVarGuard::set(&env_lock, OCI_USERNAME_ENV_VAR, "user"); + let _password = EnvVarGuard::set(&env_lock, OCI_PASSWORD_ENV_VAR, "password"); + let _token = EnvVarGuard::set(&env_lock, OCI_TOKEN_ENV_VAR, "token"); + + assert_eq!(from_env(), RegistryAuth::Bearer("bearer".to_string())); + } + + #[test] + fn test_auth_uses_password_then_token_then_anonymous() { + let env_lock = acquire_env_mutex(); + let _bearer = EnvVarGuard::remove(&env_lock, OCI_BEARER_TOKEN_ENV_VAR); + let _username = EnvVarGuard::set(&env_lock, OCI_USERNAME_ENV_VAR, "user"); + let password = EnvVarGuard::set(&env_lock, OCI_PASSWORD_ENV_VAR, "password"); + let _token = EnvVarGuard::set(&env_lock, OCI_TOKEN_ENV_VAR, "token"); + + assert_eq!( + from_env(), + RegistryAuth::Basic("user".to_string(), "password".to_string()) + ); + + drop(password); + let _password = EnvVarGuard::remove(&env_lock, OCI_PASSWORD_ENV_VAR); + assert_eq!( + from_env(), + RegistryAuth::Basic("user".to_string(), "token".to_string()) + ); + + let _username = EnvVarGuard::remove(&env_lock, OCI_USERNAME_ENV_VAR); + assert_eq!(from_env(), RegistryAuth::Anonymous); + } +} diff --git a/modelexpress_server/src/registry/backend/kubernetes.rs b/modelexpress_server/src/registry/backend/kubernetes.rs index 7b2b06de..f82fc9a9 100644 --- a/modelexpress_server/src/registry/backend/kubernetes.rs +++ b/modelexpress_server/src/registry/backend/kubernetes.rs @@ -36,7 +36,7 @@ const NAME_BUDGET: usize = K8S_NAME_MAX - CR_NAME_PREFIX.len(); /// Hex chars of SHA256 suffix appended when the sanitized name exceeds the budget. const HASH_SUFFIX_LEN: usize = 12; -/// Sanitize a HuggingFace/NGC model name into a DNS-1123 `metadata.name` component. +/// Sanitize a model name into a DNS-1123 `metadata.name` component. /// /// Transform rules: /// - `/` → `--` @@ -111,6 +111,7 @@ impl KubernetesRegistryBackend { ModelProvider::HuggingFace => "HuggingFace", ModelProvider::Ngc => "Ngc", ModelProvider::Gcs => "Gcs", + ModelProvider::Oci => "Oci", } } @@ -119,6 +120,7 @@ impl KubernetesRegistryBackend { "HuggingFace" => Ok(ModelProvider::HuggingFace), "Ngc" => Ok(ModelProvider::Ngc), "Gcs" => Ok(ModelProvider::Gcs), + "Oci" => Ok(ModelProvider::Oci), other => Err(format!("unknown provider in CR spec: {other:?}").into()), } } @@ -489,6 +491,21 @@ impl RegistryBackend for KubernetesRegistryBackend { mod tests { use super::*; + #[test] + fn provider_roundtrip() { + for provider in [ + ModelProvider::HuggingFace, + ModelProvider::Ngc, + ModelProvider::Gcs, + ModelProvider::Oci, + ] { + let stored = KubernetesRegistryBackend::provider_str(provider); + let parsed = KubernetesRegistryBackend::provider_from_str(stored); + assert!(matches!(parsed, Ok(parsed) if parsed == provider)); + } + assert!(KubernetesRegistryBackend::provider_from_str("bogus").is_err()); + } + #[test] fn sanitize_preserves_readable_prefix() { // The readable prefix is still present; the hash suffix disambiguates collisions. diff --git a/modelexpress_server/src/registry/backend/redis.rs b/modelexpress_server/src/registry/backend/redis.rs index 461f9737..3a0a204d 100644 --- a/modelexpress_server/src/registry/backend/redis.rs +++ b/modelexpress_server/src/registry/backend/redis.rs @@ -42,6 +42,7 @@ fn provider_str(p: ModelProvider) -> &'static str { ModelProvider::HuggingFace => "HuggingFace", ModelProvider::Ngc => "Ngc", ModelProvider::Gcs => "Gcs", + ModelProvider::Oci => "Oci", } } @@ -50,6 +51,7 @@ fn provider_from_str(s: &str) -> RegistryResult { "HuggingFace" => Ok(ModelProvider::HuggingFace), "Ngc" => Ok(ModelProvider::Ngc), "Gcs" => Ok(ModelProvider::Gcs), + "Oci" => Ok(ModelProvider::Oci), other => Err(format!("unknown provider in Redis record: {other:?}").into()), } } @@ -449,6 +451,7 @@ mod tests { ModelProvider::HuggingFace, ModelProvider::Ngc, ModelProvider::Gcs, + ModelProvider::Oci, ] { let s = provider_str(p); assert_eq!(provider_from_str(s).expect("roundtrip"), p); diff --git a/modelexpress_server/src/registry/k8s_types.rs b/modelexpress_server/src/registry/k8s_types.rs index 88ad66fb..0980d430 100644 --- a/modelexpress_server/src/registry/k8s_types.rs +++ b/modelexpress_server/src/registry/k8s_types.rs @@ -30,7 +30,7 @@ pub struct ModelCacheEntrySpec { #[serde(rename = "modelName")] pub model_name: String, - /// Provider string — `"HuggingFace"`, `"Ngc"`, or `"Gcs"`. + /// Provider string — `"HuggingFace"`, `"Ngc"`, `"Gcs"`, or `"Oci"`. pub provider: String, }