Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions architectures/centralized/client/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,8 @@ impl App {
// it will be re-created during cooldown with the current coordinator state.
CheckpointUploader::new_hub(repo_id.clone(), token.clone()).await?;
}
Ok(CheckpointData::Gcs {
ref bucket,
ref prefix,
}) => {
CheckpointUploader::new_gcs(bucket.clone(), prefix.clone()).await?;
Ok(CheckpointData::Gcs { .. }) => {
// GCS uploads use run-down signed URLs; auth is validated at request time.
}
_ => {}
}
Expand Down
28 changes: 22 additions & 6 deletions architectures/decentralized/solana-client/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use psyche_coordinator::{
model_extra_data::CheckpointData,
};
use psyche_core::sha256;
use psyche_data_provider::RunDownClient;
use psyche_metrics::ClientMetrics;

use psyche_network::{DiscoveryMode, NetworkTUIState, NetworkTui, SecretKey, allowlist};
Expand Down Expand Up @@ -96,9 +97,19 @@ pub async fn build_app(

let eval_tasks = p.eval_tasks()?;
let hub_read_token = std::env::var("HF_TOKEN").ok();
let checkpoint_config = p.checkpoint_config()?;
let mut checkpoint_config = p.checkpoint_config()?;
let model_extra_data_override: Option<ModelExtraData> = p.model_extra_data_override()?;

// Construct RunDownClient using the wallet keypair for signing.
// This enables GCS checkpoint upload/download via run-down signed URLs.
let run_down_keypair = wallet_keypair.clone();
let run_down_client = Arc::new(RunDownClient::new(
p.run_id.clone(),
wallet_keypair.pubkey().to_string(),
move |msg| run_down_keypair.sign_message(msg).as_ref().to_vec(),
));
checkpoint_config.run_down_client = Some(run_down_client);

let solana_pubkey = wallet_keypair.pubkey();
let wandb_info = p.wandb_info(format!("{}-{solana_pubkey}", p.run_id))?;

Expand Down Expand Up @@ -250,11 +261,16 @@ impl App {
))?;
CheckpointUploader::new_hub(repo_id.clone(), token.clone()).await?;
}
Ok(CheckpointData::Gcs {
ref bucket,
ref prefix,
}) => {
CheckpointUploader::new_gcs(bucket.clone(), prefix.clone()).await?;
Ok(CheckpointData::Gcs { .. }) => {
// GCS uploads use run-down signed URLs; auth is validated at request time.
if self
.state_options
.checkpoint_config
.run_down_client
.is_none()
{
anyhow::bail!("RunDownClient not configured for GCS checkpoint upload");
}
}
_ => {}
}
Expand Down
1 change: 1 addition & 0 deletions shared/client/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ impl TrainArgs {
keep_steps: self.keep_steps,
hub_token,
skip_upload: self.skip_checkpoint_upload,
run_down_client: None,
})
}

Expand Down
33 changes: 17 additions & 16 deletions shared/client/src/state/cooldown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::CheckpointUploader;
use psyche_coordinator::{
CheckpointerSelection, Coordinator, model::Model, model_extra_data::CheckpointData,
};
use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs, upload_to_hub};
use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs_signed, upload_to_hub};
use psyche_event_sourcing::event;
#[cfg(feature = "python")]
use psyche_modeling::CausalLM;
Expand Down Expand Up @@ -216,6 +216,7 @@ impl CooldownStepMetadata {
keep_steps,
hub_token,
skip_upload,
run_down_client,
} = checkpoint_info;

// When skip_upload is true (testing), skip all checkpoint saving
Expand All @@ -240,16 +241,12 @@ impl CooldownStepMetadata {
None
}
}
Some(CheckpointData::Gcs { ref bucket, ref prefix }) => {
match CheckpointUploader::new_gcs(
bucket.clone(),
prefix.clone(),
).await {
Ok(uploader) => Some(uploader),
Err(err) => {
error!("Failed to create GCS uploader: {}", err);
None
}
Some(CheckpointData::Gcs { .. }) => {
if let Some(ref client) = run_down_client {
Some(CheckpointUploader::Gcs(client.clone()))
} else {
warn!("RunDownClient not configured, skipping GCS signed URL upload");
None
}
}
_ => None,
Expand Down Expand Up @@ -350,11 +347,15 @@ async fn upload_checkpoint(
) -> Result<(), CheckpointError> {
event!(cooldown::CheckpointUploadStarted);
let result = match uploader {
CheckpointUploader::Gcs(gcs_info) => {
upload_to_gcs(gcs_info, manifest_metadata, local, step, cancellation_token)
.await
.map_err(CheckpointError::UploadError)
}
CheckpointUploader::Gcs(run_down_client) => upload_to_gcs_signed(
&run_down_client,
manifest_metadata,
local,
step,
cancellation_token,
)
.await
.map_err(CheckpointError::UploadError),
CheckpointUploader::Hub(hub_info) => {
upload_to_hub(hub_info, local, step, cancellation_token)
.await
Expand Down
51 changes: 36 additions & 15 deletions shared/client/src/state/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use psyche_core::{
use psyche_data_provider::{
DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider,
PreprocessedDataProvider, Split, WeightedDataProvider, download_dataset_repo_async,
download_model_from_gcs_async, download_model_repo_async, fetch_json_from_gcs,
fetch_json_from_hub,
download_model_from_gcs_async, download_model_from_gcs_signed_async, download_model_repo_async,
fetch_json_from_gcs, fetch_json_from_hub,
http::{FileURLs, HttpDataProvider},
};
use psyche_event_sourcing::event;
Expand Down Expand Up @@ -377,6 +377,7 @@ impl RunInitConfigAndIO {
})
} else {
let checkpoint_data = checkpoint_data.unwrap_or(CheckpointData::Dummy);
let run_down_client = init_config.checkpoint_config.run_down_client.clone();
tokio::spawn(async move {
let (source, tokenizer, checkpoint_extra_files) = if is_p2p {
let (tx_model_config_response, rx_model_config_response) =
Expand Down Expand Up @@ -516,21 +517,41 @@ impl RunInitConfigAndIO {
);

event!(warmup::CheckpointDownloadStarted { size_bytes: 0 });
let repo_files = match download_model_from_gcs_async(
&bucket,
prefix.as_deref(),
)
.await
let repo_files = if let Some(ref run_down_client) =
run_down_client
{
Ok(files) => {
event!(warmup::CheckpointDownloadComplete(Ok(())));
files
info!("Using run-down signed URLs for GCS download");
match download_model_from_gcs_signed_async(run_down_client)
.await
{
Ok(files) => {
event!(warmup::CheckpointDownloadComplete(Ok(())));
files
}
Err(e) => {
event!(warmup::CheckpointDownloadComplete(Err(
e.to_string()
)));
return Err(e.into());
}
}
Err(e) => {
event!(warmup::CheckpointDownloadComplete(Err(
e.to_string()
)));
return Err(e.into());
} else {
match download_model_from_gcs_async(
&bucket,
prefix.as_deref(),
)
.await
{
Ok(files) => {
event!(warmup::CheckpointDownloadComplete(Ok(())));
files
}
Err(e) => {
event!(warmup::CheckpointDownloadComplete(Err(
e.to_string()
)));
return Err(e.into());
}
}
};

Expand Down
52 changes: 5 additions & 47 deletions shared/client/src/state/types.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::path::PathBuf;
use std::sync::Arc;

use google_cloud_storage::client::{Storage, StorageControl};
use psyche_coordinator::CommitteeProof;
use psyche_core::{BatchId, MerkleRoot, NodeIdentity};
use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
use psyche_data_provider::{HubUploadInfo, RunDownClient};
use psyche_modeling::DistroResult;
use psyche_network::{BlobTicket, TransmittableDistroResult};
use tch::TchError;
Expand All @@ -15,7 +15,7 @@ use tokio::task::JoinHandle;
#[derive(Debug, Clone)]
pub enum CheckpointUploader {
Hub(HubUploadInfo),
Gcs(GcsUploadInfo),
Gcs(Arc<RunDownClient>),
Dummy,
}

Expand All @@ -37,50 +37,6 @@ impl CheckpointUploader {
hub_token: token,
}))
}

/// Builds a GCS-backed uploader for `bucket`, optionally scoped to `prefix`.
///
/// Before returning, this calls `test_iam_permissions` on the bucket resource
/// and requires that every object permission listed in `REQUIRED_PERMISSIONS`
/// (list, get, create, delete) is present in the response.
///
/// # Errors
///
/// Returns an error if either GCS client cannot be built, if the permission
/// test request fails, or if any required permission is missing from the
/// response.
pub async fn new_gcs(bucket: String, prefix: Option<String>) -> anyhow::Result<Self> {
    /// Object permissions that must all be granted on the bucket.
    const REQUIRED_PERMISSIONS: [&str; 4] = [
        "storage.objects.list",
        "storage.objects.get",
        "storage.objects.create",
        "storage.objects.delete",
    ];

    // The storage client is not used afterwards, but a failure to build it
    // aborts construction with a descriptive error.
    let _storage = Storage::builder()
        .build()
        .await
        .map_err(|err| anyhow::anyhow!("Failed to create GCS client: {}", err))?;

    let control = StorageControl::builder()
        .build()
        .await
        .map_err(|err| anyhow::anyhow!("Failed to create GCS control client: {}", err))?;

    let bucket_resource = format!("projects/_/buckets/{}", bucket);
    let requested: Vec<String> = REQUIRED_PERMISSIONS
        .iter()
        .map(|name| name.to_string())
        .collect();

    let granted = control
        .test_iam_permissions()
        .set_resource(&bucket_resource)
        .set_permissions(requested)
        .send()
        .await?;

    // The response lists only the permissions the caller actually holds, so
    // any required name absent from it means the check fails.
    let missing_permission = REQUIRED_PERMISSIONS.iter().any(|required| {
        !granted
            .permissions
            .iter()
            .any(|held| held.as_str() == *required)
    });
    if missing_permission {
        anyhow::bail!(
            "GCS bucket {} does not have the required permissions for checkpoint upload. Make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly and have the correct permissions to the bucket.",
            bucket
        )
    }

    let upload_info = GcsUploadInfo {
        gcs_bucket: bucket,
        gcs_prefix: prefix,
    };
    Ok(Self::Gcs(upload_info))
}
}

#[derive(Debug, Clone)]
Expand All @@ -91,6 +47,8 @@ pub struct CheckpointConfig {
pub hub_token: Option<String>,
/// Skip saving and uploading checkpoints (for testing).
pub skip_upload: bool,
/// RunDownClient for GCS signed URL uploads. If None, GCS uploads are skipped.
pub run_down_client: Option<Arc<RunDownClient>>,
}

#[derive(Debug)]
Expand Down
1 change: 1 addition & 0 deletions shared/data-provider/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ postcard.workspace = true
bytemuck.workspace = true
google-cloud-storage.workspace = true
reqwest = "0.12.12"
bs58 = "0.5"
bytes = "1"
google-cloud-auth = "0.16"
google-cloud-gax = "1.4.0"
Expand Down
6 changes: 6 additions & 0 deletions shared/data-provider/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ pub enum UploadError {
#[error("GCS error: {0}")]
Gcs(String),

#[error("run-down service error: {0}")]
RunDown(String),

// Common errors
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
Expand All @@ -38,6 +41,9 @@ pub enum DownloadError {
#[error("GCS error: {0}")]
Gcs(String),

#[error("run-down service error: {0}")]
RunDown(String),

#[error("IO error: {0}")]
Io(#[from] std::io::Error),

Expand Down
10 changes: 5 additions & 5 deletions shared/data-provider/src/gcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ pub struct GcsManifestMetadata {
pub run_id: String,
}

const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
pub(crate) const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];

fn get_cache_base(bucket: &str) -> PathBuf {
pub(crate) fn get_cache_base(bucket: &str) -> PathBuf {
// Use HF_HOME if set, otherwise fall back to ~/.cache
std::env::var("HF_HOME")
.map(PathBuf::from)
Expand All @@ -59,7 +59,7 @@ fn get_cache_base(bucket: &str) -> PathBuf {
.join(bucket)
}

fn get_cache_dir(
pub(crate) fn get_cache_dir(
bucket: &str,
prefix: Option<&str>,
step: u32,
Expand All @@ -74,7 +74,7 @@ fn get_cache_dir(
}
}

fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf {
pub(crate) fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf {
let base = get_cache_base(bucket);

match prefix {
Expand All @@ -83,7 +83,7 @@ fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf {
}
}

fn collect_cached_files(
pub(crate) fn collect_cached_files(
cache_dir: &Path,
manifest: &GcsCheckpointManifest,
) -> Option<Vec<PathBuf>> {
Expand Down
Loading
Loading