diff --git a/Cargo.lock b/Cargo.lock index 0312196d1..2584e0059 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7511,6 +7511,7 @@ dependencies = [ "anchor-lang", "anyhow", "async-trait", + "bs58", "bytemuck", "bytes", "chrono", @@ -7575,6 +7576,7 @@ dependencies = [ name = "psyche-deserialize-zerocopy-wasm" version = "0.2.0" dependencies = [ + "anchor-lang", "psyche-coordinator", "psyche-core", "psyche-solana-coordinator", diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 479d6628e..db9264d0a 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -224,11 +224,8 @@ impl App { // it will be re-created during cooldown with the current coordinator state. CheckpointUploader::new_hub(repo_id.clone(), token.clone()).await?; } - Ok(CheckpointData::Gcs { - ref bucket, - ref prefix, - }) => { - CheckpointUploader::new_gcs(bucket.clone(), prefix.clone()).await?; + Ok(CheckpointData::Gcs { .. }) => { + // GCS uploads use run-down signed URLs; auth is validated at request time. } _ => {} } diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 95947d17e..c9fbd2d39 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -19,6 +19,7 @@ use psyche_coordinator::{ model_extra_data::CheckpointData, }; use psyche_core::sha256; +use psyche_data_provider::RunDownClient; use psyche_metrics::ClientMetrics; use psyche_network::{DiscoveryMode, NetworkTUIState, NetworkTui, SecretKey, allowlist}; @@ -96,9 +97,19 @@ pub async fn build_app( let eval_tasks = p.eval_tasks()?; let hub_read_token = std::env::var("HF_TOKEN").ok(); - let checkpoint_config = p.checkpoint_config()?; + let mut checkpoint_config = p.checkpoint_config()?; let model_extra_data_override: Option = p.model_extra_data_override()?; + // Construct RunDownClient using the wallet keypair for signing. + // This enables GCS checkpoint upload/download via run-down signed URLs. + let run_down_keypair = wallet_keypair.clone(); + let run_down_client = Arc::new(RunDownClient::new( + p.run_id.clone(), + wallet_keypair.pubkey().to_string(), + move |msg| run_down_keypair.sign_message(msg).as_ref().to_vec(), + )); + checkpoint_config.run_down_client = Some(run_down_client); + let solana_pubkey = wallet_keypair.pubkey(); let wandb_info = p.wandb_info(format!("{}-{solana_pubkey}", p.run_id))?; @@ -250,11 +261,16 @@ impl App { ))?; CheckpointUploader::new_hub(repo_id.clone(), token.clone()).await?; } - Ok(CheckpointData::Gcs { - ref bucket, - ref prefix, - }) => { - CheckpointUploader::new_gcs(bucket.clone(), prefix.clone()).await?; + Ok(CheckpointData::Gcs { .. }) => { + // GCS uploads use run-down signed URLs; auth is validated at request time. + if self + .state_options + .checkpoint_config + .run_down_client + .is_none() + { + anyhow::bail!("RunDownClient not configured for GCS checkpoint upload"); + } } _ => {} } diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index e7b4237f3..d0434c987 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -253,6 +253,7 @@ impl TrainArgs { keep_steps: self.keep_steps, hub_token, skip_upload: self.skip_checkpoint_upload, + run_down_client: None, }) } diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index ab39708d0..e2bd03823 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -2,7 +2,7 @@ use crate::CheckpointUploader; use psyche_coordinator::{ CheckpointerSelection, Coordinator, model::Model, model_extra_data::CheckpointData, }; -use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs, upload_to_hub}; +use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs_signed, upload_to_hub}; use psyche_event_sourcing::event; #[cfg(feature = "python")] use psyche_modeling::CausalLM; @@ -216,6 +216,7 @@ impl CooldownStepMetadata { keep_steps, hub_token, skip_upload, + run_down_client, } = checkpoint_info; // When skip_upload is true (testing), skip all checkpoint saving @@ -240,16 +241,12 @@ impl CooldownStepMetadata { None } } - Some(CheckpointData::Gcs { ref bucket, ref prefix }) => { - match CheckpointUploader::new_gcs( - bucket.clone(), - prefix.clone(), - ).await { - Ok(uploader) => Some(uploader), - Err(err) => { - error!("Failed to create GCS uploader: {}", err); - None - } + Some(CheckpointData::Gcs { .. }) => { + if let Some(ref client) = run_down_client { + Some(CheckpointUploader::Gcs(client.clone())) + } else { + warn!("RunDownClient not configured, skipping GCS signed URL upload"); + None } } _ => None, @@ -350,11 +347,15 @@ async fn upload_checkpoint( ) -> Result<(), CheckpointError> { event!(cooldown::CheckpointUploadStarted); let result = match uploader { - CheckpointUploader::Gcs(gcs_info) => { - upload_to_gcs(gcs_info, manifest_metadata, local, step, cancellation_token) - .await - .map_err(CheckpointError::UploadError) - } + CheckpointUploader::Gcs(run_down_client) => upload_to_gcs_signed( + &run_down_client, + manifest_metadata, + local, + step, + cancellation_token, + ) + .await + .map_err(CheckpointError::UploadError), CheckpointUploader::Hub(hub_info) => { upload_to_hub(hub_info, local, step, cancellation_token) .await diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index d0aa33bfb..b7ddc0eaf 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -11,8 +11,8 @@ use psyche_core::{ use psyche_data_provider::{ DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider, PreprocessedDataProvider, Split, WeightedDataProvider, download_dataset_repo_async, - download_model_from_gcs_async, download_model_repo_async, fetch_json_from_gcs, - fetch_json_from_hub, + download_model_from_gcs_async, download_model_from_gcs_signed_async, download_model_repo_async, + fetch_json_from_gcs, fetch_json_from_hub, http::{FileURLs, HttpDataProvider}, }; use psyche_event_sourcing::event; @@ -377,6 +377,7 @@ impl RunInitConfigAndIO { }) } else { let checkpoint_data = checkpoint_data.unwrap_or(CheckpointData::Dummy); + let run_down_client = init_config.checkpoint_config.run_down_client.clone(); tokio::spawn(async move { let (source, tokenizer, checkpoint_extra_files) = if is_p2p { let (tx_model_config_response, rx_model_config_response) = @@ -516,21 +517,41 @@ impl RunInitConfigAndIO { ); event!(warmup::CheckpointDownloadStarted { size_bytes: 0 }); - let repo_files = match download_model_from_gcs_async( - &bucket, - prefix.as_deref(), - ) - .await + let repo_files = if let Some(ref run_down_client) = + run_down_client { - Ok(files) => { - event!(warmup::CheckpointDownloadComplete(Ok(()))); - files + info!("Using run-down signed URLs for GCS download"); + match download_model_from_gcs_signed_async(run_down_client) + .await + { + Ok(files) => { + event!(warmup::CheckpointDownloadComplete(Ok(()))); + files + } + Err(e) => { + event!(warmup::CheckpointDownloadComplete(Err( + e.to_string() + ))); + return Err(e.into()); + } } - Err(e) => { - event!(warmup::CheckpointDownloadComplete(Err( - e.to_string() - ))); - return Err(e.into()); + } else { + match download_model_from_gcs_async( + &bucket, + prefix.as_deref(), + ) + .await + { + Ok(files) => { + event!(warmup::CheckpointDownloadComplete(Ok(()))); + files + } + Err(e) => { + event!(warmup::CheckpointDownloadComplete(Err( + e.to_string() + ))); + return Err(e.into()); + } } }; diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 94ae80f60..abd711bce 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -1,9 +1,9 @@ use std::path::PathBuf; +use std::sync::Arc; -use google_cloud_storage::client::{Storage, StorageControl}; use psyche_coordinator::CommitteeProof; use psyche_core::{BatchId, MerkleRoot, NodeIdentity}; -use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; +use psyche_data_provider::{HubUploadInfo, RunDownClient}; use psyche_modeling::DistroResult; use psyche_network::{BlobTicket, TransmittableDistroResult}; use tch::TchError; @@ -15,7 +15,7 @@ use tokio::task::JoinHandle; #[derive(Debug, Clone)] pub enum CheckpointUploader { Hub(HubUploadInfo), - Gcs(GcsUploadInfo), + Gcs(Arc), Dummy, } @@ -37,50 +37,6 @@ impl CheckpointUploader { hub_token: token, })) } - - /// Creates a new GCS uploader after validating bucket permissions. - pub async fn new_gcs(bucket: String, prefix: Option) -> anyhow::Result { - let _storage = Storage::builder() - .build() - .await - .map_err(|e| anyhow::anyhow!("Failed to create GCS client: {}", e))?; - - let client = StorageControl::builder() - .build() - .await - .map_err(|e| anyhow::anyhow!("Failed to create GCS control client: {}", e))?; - - let permissions_to_test = vec![ - "storage.objects.list", - "storage.objects.get", - "storage.objects.create", - "storage.objects.delete", - ]; - - let resource = format!("projects/_/buckets/{}", bucket); - let perms_vec: Vec = permissions_to_test.iter().map(|s| s.to_string()).collect(); - let response = client - .test_iam_permissions() - .set_resource(&resource) - .set_permissions(perms_vec) - .send() - .await?; - - let correct_permissions = permissions_to_test - .into_iter() - .all(|p| response.permissions.contains(&p.to_string())); - if !correct_permissions { - anyhow::bail!( - "GCS bucket {} does not have the required permissions for checkpoint upload. Make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly and have the correct permissions to the bucket.", - bucket - ) - } - - Ok(Self::Gcs(GcsUploadInfo { - gcs_bucket: bucket, - gcs_prefix: prefix, - })) - } } #[derive(Debug, Clone)] @@ -91,6 +47,8 @@ pub struct CheckpointConfig { pub hub_token: Option, /// Skip saving and uploading checkpoints (for testing). pub skip_upload: bool, + /// RunDownClient for GCS signed URL uploads. If None, GCS uploads are skipped. + pub run_down_client: Option>, } #[derive(Debug)] diff --git a/shared/data-provider/Cargo.toml b/shared/data-provider/Cargo.toml index de836ef24..d4cb0277d 100644 --- a/shared/data-provider/Cargo.toml +++ b/shared/data-provider/Cargo.toml @@ -27,6 +27,7 @@ postcard.workspace = true bytemuck.workspace = true google-cloud-storage.workspace = true reqwest = "0.12.12" +bs58 = "0.5" bytes = "1" google-cloud-auth = "0.16" google-cloud-gax = "1.4.0" diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs index ed82bba07..4a2881844 100644 --- a/shared/data-provider/src/errors.rs +++ b/shared/data-provider/src/errors.rs @@ -16,6 +16,9 @@ pub enum UploadError { #[error("GCS error: {0}")] Gcs(String), + #[error("run-down service error: {0}")] + RunDown(String), + // Common errors #[error("IO error: {0}")] Io(#[from] std::io::Error), @@ -38,6 +41,9 @@ pub enum DownloadError { #[error("GCS error: {0}")] Gcs(String), + #[error("run-down service error: {0}")] + RunDown(String), + #[error("IO error: {0}")] Io(#[from] std::io::Error), diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index a812a39bf..c70567c48 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -43,9 +43,9 @@ pub struct GcsManifestMetadata { pub run_id: String, } -const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; +pub(crate) const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; -fn get_cache_base(bucket: &str) -> PathBuf { +pub(crate) fn get_cache_base(bucket: &str) -> PathBuf { // Use HF_HOME if set, otherwise fall back to ~/.cache std::env::var("HF_HOME") .map(PathBuf::from) @@ -59,7 +59,7 @@ fn get_cache_base(bucket: &str) -> PathBuf { .join(bucket) } -fn get_cache_dir( +pub(crate) fn get_cache_dir( bucket: &str, prefix: Option<&str>, step: u32, @@ -74,7 +74,7 @@ fn get_cache_dir( } } -fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf { +pub(crate) fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf { let base = get_cache_base(bucket); match prefix { @@ -83,7 +83,7 @@ fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf { } } -fn collect_cached_files( +pub(crate) fn collect_cached_files( cache_dir: &Path, manifest: &GcsCheckpointManifest, ) -> Option> { diff --git a/shared/data-provider/src/gcs_signed.rs b/shared/data-provider/src/gcs_signed.rs new file mode 100644 index 000000000..829dbbc8a --- /dev/null +++ b/shared/data-provider/src/gcs_signed.rs @@ -0,0 +1,428 @@ +use crate::errors::{DownloadError, UploadError}; +use crate::gcs::{ + GcsCheckpointManifest, GcsManifestMetadata, MODEL_EXTENSIONS, ManifestFileEntry, + ManifestMetadata, collect_cached_files, get_cache_dir, get_cache_dir_no_manifest, +}; +use crate::run_down::{DownloadUrlEntry, RunDownClient}; +use chrono::Utc; +use futures::TryStreamExt; +use std::path::{Path, PathBuf}; +use tokio::io::AsyncWriteExt; +use tracing::{info, warn}; + +pub async fn upload_to_gcs_signed( + run_down: &RunDownClient, + manifest_metadata: GcsManifestMetadata, + local: Vec, + step: u64, + cancellation_token: tokio_util::sync::CancellationToken, +) -> Result<(), UploadError> { + let http = reqwest::Client::new(); + + let mut manifest = GcsCheckpointManifest { + metadata: ManifestMetadata { + timestamp: Utc::now(), + epoch: manifest_metadata.epoch, + step: step as u32, + run_id: manifest_metadata.run_id, + }, + files: Vec::new(), + }; + + for path in local + .iter() + .filter(|p| p.extension() == Some("safetensors".as_ref())) + { + if cancellation_token.is_cancelled() { + info!("Upload cancelled before uploading {}", path.display()); + return Ok(()); + } + + let file_name = path + .file_name() + .and_then(|n| n.to_str()) + .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?; + + let file = tokio::fs::File::open(&path).await?; + let size = file.metadata().await?.len(); + + let upload_url = run_down + .get_upload_url(file_name) + .await + .map_err(|e| UploadError::RunDown(e.to_string()))?; + + info!(file = file_name, size, "Uploading file via signed URL"); + + let upload_future = http + .put(&upload_url.url) + .header("Content-Type", "application/octet-stream") + .header("Content-Length", size) + .body(reqwest::Body::from(file)) + .send(); + + let response = tokio::select! { + biased; + + _ = cancellation_token.cancelled() => { + info!("Upload cancelled during upload of {}", path.display()); + return Ok(()); + } + result = upload_future => { + result.map_err(|e| UploadError::RunDown(e.to_string()))? + } + }; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(UploadError::RunDown(format!( + "Signed URL upload failed for {}: {} {}", + file_name, status, error_text + ))); + } + + let generation = match response + .headers() + .get("x-goog-generation") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()) + { + Some(g) => g, + None => { + warn!( + file = file_name, + "x-goog-generation header missing or invalid in upload response, using 0" + ); + 0 + } + }; + + info!( + file = file_name, + size, generation, "Successfully uploaded file via signed URL" + ); + + manifest.files.push(ManifestFileEntry { + filename: file_name.to_string(), + generation, + size_bytes: size, + }); + } + + let manifest_json = serde_json::to_string_pretty(&manifest)?; + let manifest_bytes = manifest_json.into_bytes(); + + let manifest_upload_url = run_down + .get_upload_url("manifest.json") + .await + .map_err(|e| UploadError::RunDown(e.to_string()))?; + + let response = http + .put(&manifest_upload_url.url) + .header("Content-Type", "application/json") + .body(manifest_bytes) + .send() + .await + .map_err(|e| UploadError::RunDown(e.to_string()))?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(UploadError::RunDown(format!( + "Signed URL upload failed for manifest.json: {} {}", + status, error_text + ))); + } + + info!( + run_id = run_down.run_id(), + "Upload via signed URLs complete" + ); + + Ok(()) +} + +pub async fn download_model_from_gcs_signed_async( + run_down: &RunDownClient, +) -> Result, DownloadError> { + let http = reqwest::Client::new(); + let run_id = run_down.run_id(); + + let download_response = run_down + .get_download_urls() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))?; + + info!( + "Got {} download URLs from run-down for run {}", + download_response.urls.len(), + run_id + ); + + let manifest_entry = download_response + .urls + .iter() + .find(|e| e.path.ends_with("manifest.json")); + + let cache_key = run_id; + + match manifest_entry { + Some(manifest_entry) => { + let response = http + .get(&manifest_entry.url) + .send() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))?; + + if !response.status().is_success() { + return Err(DownloadError::RunDown(format!( + "Failed to download manifest.json: {}", + response.status() + ))); + } + + // Get GCS generation number from manifest response + let manifest_generation = response + .headers() + .get("x-goog-generation") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()) + .unwrap_or(0); + + let manifest_data = response + .bytes() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))?; + + let manifest: GcsCheckpointManifest = serde_json::from_slice(&manifest_data)?; + + info!( + "Found manifest: step {}, epoch {}, generation {}", + manifest.metadata.step, manifest.metadata.epoch, manifest_generation + ); + + let cache_dir = + get_cache_dir(cache_key, None, manifest.metadata.step, manifest_generation); + + let mut files = if let Some(cached) = collect_cached_files(&cache_dir, &manifest) { + info!("Using cached checkpoint at {:?}", cache_dir); + cached + } else { + info!("Downloading checkpoint via signed URLs to {:?}", cache_dir); + std::fs::create_dir_all(&cache_dir)?; + download_files_from_signed_urls( + &http, + &download_response.urls, + &cache_dir, + &manifest, + ) + .await? + }; + + let config_files = download_non_manifest_files_from_signed_urls( + &http, + &download_response.urls, + &cache_dir, + &[".json", ".py"], + &manifest, + ) + .await?; + files.extend(config_files); + + Ok(files) + } + None => { + info!("No manifest found in signed URLs, downloading all model files"); + let cache_dir = get_cache_dir_no_manifest(cache_key, None); + std::fs::create_dir_all(&cache_dir)?; + download_all_model_files_from_signed_urls( + &http, + &download_response.urls, + &cache_dir, + &MODEL_EXTENSIONS, + ) + .await + } + } +} + +async fn download_files_from_signed_urls( + http: &reqwest::Client, + urls: &[DownloadUrlEntry], + cache_dir: &Path, + manifest: &GcsCheckpointManifest, +) -> Result, DownloadError> { + let mut downloaded_files = Vec::new(); + + for file_entry in &manifest.files { + let local_path = cache_dir.join(&file_entry.filename); + + if local_path.exists() { + info!("Using cached: {}", file_entry.filename); + downloaded_files.push(local_path); + continue; + } + + let url_entry = urls + .iter() + .find(|e| e.path.ends_with(&file_entry.filename)) + .ok_or_else(|| { + DownloadError::RunDown(format!( + "No signed URL found for file: {}", + file_entry.filename + )) + })?; + + info!("Downloading via signed URL: {}", file_entry.filename); + + let response = http + .get(&url_entry.url) + .send() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))?; + + if !response.status().is_success() { + return Err(DownloadError::RunDown(format!( + "Failed to download {}: {}", + file_entry.filename, + response.status() + ))); + } + + let mut stream = response.bytes_stream(); + let mut file = tokio::fs::File::create(&local_path) + .await + .map_err(DownloadError::Io)?; + while let Some(chunk) = stream + .try_next() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))? + { + file.write_all(&chunk).await.map_err(DownloadError::Io)?; + } + info!("Downloaded: {}", file_entry.filename); + downloaded_files.push(local_path); + } + + Ok(downloaded_files) +} + +async fn download_non_manifest_files_from_signed_urls( + http: &reqwest::Client, + urls: &[DownloadUrlEntry], + cache_dir: &Path, + extensions: &[&str], + manifest: &GcsCheckpointManifest, +) -> Result, DownloadError> { + let manifest_filenames: std::collections::HashSet<&str> = + manifest.files.iter().map(|f| f.filename.as_str()).collect(); + + let mut downloaded_files = Vec::new(); + + for url_entry in urls { + let filename = url_entry.path.rsplit('/').next().unwrap_or(&url_entry.path); + + if manifest_filenames.contains(filename) { + continue; + } + + if !extensions.iter().any(|ext| filename.ends_with(ext)) { + continue; + } + + let local_path = cache_dir.join(filename); + if local_path.exists() { + info!("Using cached: {}", filename); + downloaded_files.push(local_path); + continue; + } + + info!("Downloading config via signed URL: {}", filename); + + let response = http + .get(&url_entry.url) + .send() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))?; + + if !response.status().is_success() { + return Err(DownloadError::RunDown(format!( + "Failed to download {}: {}", + filename, + response.status() + ))); + } + + let mut stream = response.bytes_stream(); + let mut file = tokio::fs::File::create(&local_path) + .await + .map_err(DownloadError::Io)?; + while let Some(chunk) = stream + .try_next() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))? + { + file.write_all(&chunk).await.map_err(DownloadError::Io)?; + } + info!("Downloaded: {}", filename); + downloaded_files.push(local_path); + } + + Ok(downloaded_files) +} + +async fn download_all_model_files_from_signed_urls( + http: &reqwest::Client, + urls: &[DownloadUrlEntry], + cache_dir: &Path, + extensions: &[&str], +) -> Result, DownloadError> { + let mut downloaded_files = Vec::new(); + + for url_entry in urls { + let filename = url_entry.path.rsplit('/').next().unwrap_or(&url_entry.path); + + if !extensions.iter().any(|ext| filename.ends_with(ext)) { + continue; + } + + let local_path = cache_dir.join(filename); + if local_path.exists() { + info!("Using cached: {}", filename); + downloaded_files.push(local_path); + continue; + } + + info!("Downloading via signed URL: {}", filename); + + let response = http + .get(&url_entry.url) + .send() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))?; + + if !response.status().is_success() { + return Err(DownloadError::RunDown(format!( + "Failed to download {}: {}", + filename, + response.status() + ))); + } + + let mut stream = response.bytes_stream(); + let mut file = tokio::fs::File::create(&local_path) + .await + .map_err(DownloadError::Io)?; + while let Some(chunk) = stream + .try_next() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))? + { + file.write_all(&chunk).await.map_err(DownloadError::Io)?; + } + info!("Downloaded: {}", filename); + downloaded_files.push(local_path); + } + + Ok(downloaded_files) +} diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index 8899313e3..0cb103b54 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -4,11 +4,13 @@ mod dummy; mod errors; mod file_extensions; mod gcs; +mod gcs_signed; pub mod http; mod hub; mod local; mod preprocessed; mod remote; +pub mod run_down; mod traits; mod weighted; @@ -22,6 +24,7 @@ pub use gcs::{ download_model_from_gcs_async, download_model_from_gcs_sync, fetch_json_from_gcs, upload_json_to_gcs, upload_to_gcs, }; +pub use gcs_signed::{download_model_from_gcs_signed_async, upload_to_gcs_signed}; pub use hub::{ HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync, download_model_repo_async, download_model_repo_sync, fetch_json_from_hub, @@ -31,5 +34,6 @@ pub use local::LocalDataProvider; pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor}; pub use preprocessed::PreprocessedDataProvider; pub use remote::{DataProviderTcpClient, DataProviderTcpServer, DataServerTui}; +pub use run_down::RunDownClient; pub use traits::{LengthKnownDataProvider, TokenizedData, TokenizedDataProvider}; pub use weighted::{WeightedDataProvider, http::WeightedHttpProvidersConfig}; diff --git a/shared/data-provider/src/run_down.rs b/shared/data-provider/src/run_down.rs new file mode 100644 index 000000000..adc5c3018 --- /dev/null +++ b/shared/data-provider/src/run_down.rs @@ -0,0 +1,177 @@ +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; +use tracing::info; + +const DEFAULT_RUN_DOWN_BASE_URL: &str = "https://run-down.nousresearch.com/v1"; +const SIGNED_URL_EXPIRY_SECONDS: u64 = 3600; + +fn base_url() -> String { + std::env::var("RUN_DOWN_URL").unwrap_or_else(|_| DEFAULT_RUN_DOWN_BASE_URL.to_string()) +} + +/// Client for the Nous run-down service that provides signed URLs for GCS checkpoint +/// upload/download. Uses a generic signing function to decouple from specific wallet +/// implementations. +type SignFn = dyn Fn(&[u8]) -> Vec + Send + Sync; + +pub struct RunDownClient { + http: reqwest::Client, + run_id: String, + wallet_address: String, + sign_fn: Arc, +} + +impl std::fmt::Debug for RunDownClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RunDownClient") + .field("run_id", &self.run_id) + .finish() + } +} + +impl RunDownClient { + pub fn new( + run_id: String, + wallet_address: String, + sign_fn: impl Fn(&[u8]) -> Vec + Send + Sync + 'static, + ) -> Self { + Self { + http: reqwest::Client::new(), + run_id, + wallet_address, + sign_fn: Arc::new(sign_fn), + } + } + + pub fn run_id(&self) -> &str { + &self.run_id + } + + fn generate_signature(&self, expires_in_seconds: u64, nonce: u64) -> String { + let message = format!( + "nous-run-down-service:{}:{}:{}:{}", + self.run_id, self.wallet_address, expires_in_seconds, nonce + ); + let signature_bytes = (self.sign_fn)(message.as_bytes()); + bs58::encode(&signature_bytes).into_string() + } + + fn nonce() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as u64 + } + + /// Get a signed upload URL for a single file. + pub async fn get_upload_url(&self, filename: &str) -> Result { + let expires_in_seconds = SIGNED_URL_EXPIRY_SECONDS; + let nonce = Self::nonce(); + let signature = self.generate_signature(expires_in_seconds, nonce); + + let url = format!("{}/upload/{}", base_url(), self.run_id); + let body = serde_json::json!({ + "walletAddress": self.wallet_address, + "filename": filename, + "expiresInSeconds": expires_in_seconds, + "nonce": nonce.to_string(), + }); + + info!(filename, url, "Requesting signed upload URL from run-down"); + + let response = self + .http + .post(&url) + .header("X-Solana-Signature", &signature) + .json(&body) + .send() + .await + .map_err(|e| RunDownError::Request(e.to_string()))?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(RunDownError::Api(format!( + "Upload URL request failed with status {}: {}", + status, error_text + ))); + } + + response + .json() + .await + .map_err(|e| RunDownError::Parse(e.to_string())) + } + + /// Get signed download URLs for all files in the run. + pub async fn get_download_urls(&self) -> Result { + let expires_in_seconds = SIGNED_URL_EXPIRY_SECONDS; + let nonce = Self::nonce(); + let signature = self.generate_signature(expires_in_seconds, nonce); + + let url = format!("{}/download/{}", base_url(), self.run_id); + let body = serde_json::json!({ + "walletAddress": self.wallet_address, + "expiresInSeconds": expires_in_seconds, + "nonce": nonce.to_string(), + }); + + info!(url, "Requesting signed download URLs from run-down"); + + let response = self + .http + .post(&url) + .header("X-Solana-Signature", &signature) + .json(&body) + .send() + .await + .map_err(|e| RunDownError::Request(e.to_string()))?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(RunDownError::Api(format!( + "Download URLs request failed with status {}: {}", + status, error_text + ))); + } + + response + .json() + .await + .map_err(|e| RunDownError::Parse(e.to_string())) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UploadUrlResponse { + pub url: String, + pub expires_at: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DownloadUrlEntry { + pub path: String, + pub url: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DownloadUrlsResponse { + pub urls: Vec, + pub expires_at: String, +} + +#[derive(Debug, thiserror::Error)] +pub enum RunDownError { + #[error("run-down request failed: {0}")] + Request(String), + + #[error("run-down API error: {0}")] + Api(String), + + #[error("failed to parse run-down response: {0}")] + Parse(String), +} diff --git a/tools/rust-tools/run-manager/src/docker/manager.rs b/tools/rust-tools/run-manager/src/docker/manager.rs index 7377c73a8..2bc364dbe 100644 --- a/tools/rust-tools/run-manager/src/docker/manager.rs +++ b/tools/rust-tools/run-manager/src/docker/manager.rs @@ -26,7 +26,6 @@ pub struct RunManager { local_docker: bool, coordinator_client: CoordinatorClient, scratch_dir: Option, - gcs_credentials_path: Option, client_authorizer: Pubkey, } @@ -64,23 +63,6 @@ impl RunManager { let rpc = get_env_var("RPC")?; let scratch_dir = get_env_var("SCRATCH_DIR").ok(); - // Check for GCS credentials file path - will be mounted into container - let gcs_credentials_path = get_env_var("GOOGLE_CREDENTIALS_FILE_PATH") - .ok() - .map(PathBuf::from) - .and_then(|path| { - if path.exists() { - info!("Found GCS credentials file at: {}", path.display()); - Some(path) - } else { - warn!( - "GOOGLE_CREDENTIALS_FILE_PATH set to {} but file does not exist", - path.display() - ); - None - } - }); - let coordinator_client = CoordinatorClient::new(rpc, coordinator_program_id); // Read delegate key from AUTHORIZER env var (separate from --authorizer flag) @@ -103,7 +85,6 @@ impl RunManager { env_file, local_docker, scratch_dir, - gcs_credentials_path, client_authorizer, }); } @@ -130,7 +111,6 @@ impl RunManager { env_file, local_docker, scratch_dir, - gcs_credentials_path, client_authorizer, }) } @@ -226,27 +206,6 @@ impl RunManager { .arg(format!("type=bind,src={dir},dst=/scratch")); } - // Mount GCS credentials file if provided and set the env var inside container - if let Some(creds_path) = &self.gcs_credentials_path { - let container_creds_path = "/scratch/application_default_credentials.json"; - cmd.arg("--mount") - .arg(format!( - "type=bind,src={},dst={},readonly", - creds_path.display(), - container_creds_path - )) - .arg("--env") - .arg(format!( - "GOOGLE_APPLICATION_CREDENTIALS={}", - container_creds_path - )); - info!( - "Mounting GCS credentials from {} to {}", - creds_path.display(), - container_creds_path - ); - } - if let Some(Entrypoint { entrypoint, .. }) = entrypoint { cmd.arg("--entrypoint").arg(entrypoint); } @@ -362,13 +321,8 @@ impl RunManager { match checkpoint_data { CheckpointData::Gcs { .. } => { - if self.gcs_credentials_path.is_none() { - bail!( - "This run uses GCS checkpointing but no GCS credentials found. \ - Set GOOGLE_CREDENTIALS_FILE_PATH in your env file." - ); - } - info!("GCS credentials validated for checkpoint upload"); + // GCS uploads use run-down signed URLs; no local credentials required. + info!("GCS checkpointing via run-down signed URLs, no local credentials required"); } CheckpointData::Hub { .. } => { // HF_TOKEN should be in the env file