From 54e2411894416ab57479052de32bae2699906bec Mon Sep 17 00:00:00 2001 From: Andrew Yuan Date: Wed, 24 Sep 2025 14:24:51 -0700 Subject: [PATCH 1/5] Runtime/namespace/client wide worker heartbeat (#983) * worker heartbeat * Address Spencer's comments * wip use client_identity_override as part of key, added test * Refactor almost complete, need to plumb through telemetry to SharedNamespaceWorker * Verified client replacement works, need to update tests and cleanup * formating * clean up * forgot to remove new() now that using builder pattern * Switch to worker_set_key * Replace client test passes, need to write unit tests in worker_registry * cargo test-lint * limit nexus to 1 poller, add tests for worker_registry for heartbeat * PR comments * new test helper * Return error on multi worker register for same namespace and task queue on same client * cargo fmt * Fix registration order, unique task queue for test worker * Remove TEST_Q variable * Missing quotes * CI lint and docker test fix, rename worker_set_key to worker_grouping_key * clippy bug --- client/src/lib.rs | 20 +- client/src/raw.rs | 12 +- client/src/worker_registry/mod.rs | 634 ++++++++++++++---- client/src/workflow_handle/mod.rs | 3 +- core-api/src/worker.rs | 6 - .../include/temporal-sdk-core-c-bridge.h | 5 +- core-c-bridge/src/client.rs | 10 +- core-c-bridge/src/runtime.rs | 21 +- core-c-bridge/src/tests/context.rs | 15 +- core-c-bridge/src/worker.rs | 13 +- core/src/core_tests/activity_tasks.rs | 9 +- core/src/core_tests/workers.rs | 2 +- core/src/core_tests/workflow_tasks.rs | 4 +- core/src/ephemeral_server/mod.rs | 5 +- core/src/lib.rs | 83 ++- core/src/pollers/poll_buffer.rs | 14 +- core/src/protosext/mod.rs | 2 +- core/src/protosext/protocol_messages.rs | 2 +- core/src/replay/mod.rs | 2 +- core/src/telemetry/metrics.rs | 17 +- core/src/telemetry/prometheus_meter.rs | 3 +- core/src/test_help/integ_helpers.rs | 12 +- core/src/worker/client.rs | 80 +-- core/src/worker/client/mocks.rs | 26 +- core/src/worker/heartbeat.rs | 315 ++++----- core/src/worker/mod.rs | 306 +++++++-- core/src/worker/slot_provider.rs | 11 +- core/src/worker/workflow/mod.rs | 7 +- sdk-core-protos/src/history_builder.rs | 2 +- sdk/src/interceptors.rs | 5 +- sdk/src/lib.rs | 15 +- sdk/src/workflow_future.rs | 2 +- tests/common/mod.rs | 32 +- tests/global_metric_tests.rs | 21 +- tests/heavy_tests.rs | 3 +- tests/integ_tests/metrics_tests.rs | 39 +- tests/integ_tests/polling_tests.rs | 4 +- tests/integ_tests/worker_tests.rs | 11 +- tests/integ_tests/workflow_tests.rs | 4 +- tests/main.rs | 7 +- tests/manual_tests.rs | 7 +- tests/workflow_replay_bench.rs | 6 +- 42 files changed, 1241 insertions(+), 556 deletions(-) diff --git a/client/src/lib.rs b/client/src/lib.rs index 41cd2b79f..4d6088f87 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -32,7 +32,9 @@ pub use temporal_sdk_core_protos::temporal::api::{ }, }; pub use tonic; -pub use worker_registry::{Slot, SlotManager, SlotProvider, WorkerKey}; +pub use worker_registry::{ + ClientWorker, ClientWorkerSet, HeartbeatCallback, SharedNamespaceWorkerTrait, Slot, +}; pub use workflow_handle::{ GetWorkflowResultOpts, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle, }; @@ -388,7 +390,7 @@ pub struct ConfiguredClient { headers: Arc>, /// Capabilities as read from the `get_system_info` RPC call made on client connection capabilities: Option, - workers: Arc, + workers: Arc, } impl ConfiguredClient { @@ -438,9 +440,14 @@ impl ConfiguredClient { } /// Returns a cloned reference to a registry with workers using this client instance - pub fn workers(&self) -> Arc { + pub fn workers(&self) -> Arc { self.workers.clone() } + + /// Returns the worker grouping key, this should be unique across each client + pub fn worker_grouping_key(&self) -> Uuid { + self.workers.worker_grouping_key() + } } #[derive(Debug)] @@ -584,7 +591,7 @@ impl ClientOptions { client: TemporalServiceClient::new(svc), options: Arc::new(self.clone()), capabilities: None, - workers: Arc::new(SlotManager::new()), + workers: Arc::new(ClientWorkerSet::new()), }; if !self.skip_get_system_info { match client @@ -901,6 +908,11 @@ impl Client { pub fn into_inner(self) -> ConfiguredClient { self.inner } + + /// Returns the client-wide key + pub fn worker_grouping_key(&self) -> Uuid { + self.inner.worker_grouping_key() + } } impl NamespacedClient for Client { diff --git a/client/src/raw.rs b/client/src/raw.rs index aa5500ee1..92e6a7956 100644 --- a/client/src/raw.rs +++ b/client/src/raw.rs @@ -7,7 +7,7 @@ use crate::{ TEMPORAL_NAMESPACE_HEADER_KEY, TemporalServiceClient, metrics::{namespace_kv, task_queue_kv}, raw::sealed::RawClientLike, - worker_registry::{Slot, SlotManager}, + worker_registry::{ClientWorkerSet, Slot}, }; use futures_util::{FutureExt, TryFutureExt, future::BoxFuture}; use std::sync::Arc; @@ -68,7 +68,7 @@ pub(super) mod sealed { fn health_client_mut(&mut self) -> &mut HealthClient; /// Return a registry with workers using this client instance - fn get_workers_info(&self) -> Option>; + fn get_workers_info(&self) -> Option>; async fn call( &mut self, @@ -134,7 +134,7 @@ where self.get_client_mut().health_client_mut() } - fn get_workers_info(&self) -> Option> { + fn get_workers_info(&self) -> Option> { self.get_client().get_workers_info() } @@ -213,7 +213,7 @@ where self.health_svc_mut() } - fn get_workers_info(&self) -> Option> { + fn get_workers_info(&self) -> Option> { None } } @@ -268,7 +268,7 @@ where self.client.health_client_mut() } - fn get_workers_info(&self) -> Option> { + fn get_workers_info(&self) -> Option> { Some(self.workers()) } } @@ -316,7 +316,7 @@ impl RawClientLike for Client { self.inner.health_client_mut() } - fn get_workers_info(&self) -> Option> { + fn get_workers_info(&self) -> Option> { self.inner.get_workers_info() } } diff --git a/client/src/worker_registry/mod.rs b/client/src/worker_registry/mod.rs index 90882a718..f10b128ce 100644 --- a/client/src/worker_registry/mod.rs +++ b/client/src/worker_registry/mod.rs @@ -2,27 +2,16 @@ //! This is needed to implement Eager Workflow Start, a latency optimization in which the client, //! after reserving a slot, directly forwards a WFT to a local worker. +use anyhow::bail; use parking_lot::RwLock; -use slotmap::SlotMap; -use std::collections::{HashMap, hash_map::Entry::Vacant}; - +use std::collections::{ + HashMap, + hash_map::Entry::{Occupied, Vacant}, +}; +use std::sync::Arc; +use temporal_sdk_core_protos::temporal::api::worker::v1::WorkerHeartbeat; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse; - -slotmap::new_key_type! { - /// Registration key for a worker - pub struct WorkerKey; -} - -/// This trait is implemented by an object associated with a worker, which provides WFT processing slots. -#[cfg_attr(test, mockall::automock)] -pub trait SlotProvider: std::fmt::Debug { - /// The namespace for the WFTs that it can process. - fn namespace(&self) -> &str; - /// The task queue this provider listens to. - fn task_queue(&self) -> &str; - /// Try to reserve a slot on this worker. - fn try_reserve_wft_slot(&self) -> Option>; -} +use uuid::Uuid; /// This trait represents a slot reserved for processing a WFT by a worker. #[cfg_attr(test, mockall::automock)] @@ -49,21 +38,23 @@ impl SlotKey { } } -/// This is an inner class for [SlotManager] needed to hide the mutex. -#[derive(Default, Debug)] -struct SlotManagerImpl { - /// Maps keys, i.e., namespace#task_queue, to provider. - providers: HashMap>, - /// Maps ids to keys in `providers`. - index: SlotMap, +/// This is an inner class for [ClientWorkerSet] needed to hide the mutex. +struct ClientWorkerSetImpl { + /// Maps slot keys to slot provider worker. + slot_providers: HashMap, + /// Maps worker_instance_key to registered workers + all_workers: HashMap>, + /// Maps namespace to shared worker for worker heartbeating + shared_worker: HashMap>, } -impl SlotManagerImpl { +impl ClientWorkerSetImpl { /// Factory method. fn new() -> Self { Self { - index: Default::default(), - providers: Default::default(), + slot_providers: Default::default(), + all_workers: Default::default(), + shared_worker: Default::default(), } } @@ -73,55 +64,140 @@ impl SlotManagerImpl { task_queue: String, ) -> Option> { let key = SlotKey::new(namespace, task_queue); - if let Some(p) = self.providers.get(&key) - && let Some(slot) = p.try_reserve_wft_slot() + if let Some(p) = self.slot_providers.get(&key) + && let Some(worker) = self.all_workers.get(p) + && let Some(slot) = worker.try_reserve_wft_slot() { return Some(slot); } None } - fn register(&mut self, provider: Box) -> Option { - let key = SlotKey::new( - provider.namespace().to_string(), - provider.task_queue().to_string(), + fn register( + &mut self, + worker: Arc, + ) -> Result<(), anyhow::Error> { + let slot_key = SlotKey::new( + worker.namespace().to_string(), + worker.task_queue().to_string(), ); - if let Vacant(p) = self.providers.entry(key.clone()) { - p.insert(provider); - Some(self.index.insert(key)) - } else { - warn!("Ignoring registration for worker: {key:?}."); - None + if self.slot_providers.contains_key(&slot_key) { + bail!( + "Registration of multiple workers on the same namespace and task queue for the same client not allowed: {slot_key:?}, worker_instance_key: {:?}.", + worker.worker_instance_key() + ); } + + if worker.heartbeat_enabled() + && let Some(heartbeat_callback) = worker.heartbeat_callback() + { + let worker_instance_key = worker.worker_instance_key(); + let namespace = worker.namespace().to_string(); + + let shared_worker = match self.shared_worker.entry(namespace.clone()) { + Occupied(o) => o.into_mut(), + Vacant(v) => { + let shared_worker = worker.new_shared_namespace_worker()?; + v.insert(shared_worker) + } + }; + shared_worker.register_callback(worker_instance_key, heartbeat_callback); + } + + self.slot_providers + .insert(slot_key.clone(), worker.worker_instance_key()); + + self.all_workers + .insert(worker.worker_instance_key(), worker); + + Ok(()) } - fn unregister(&mut self, id: WorkerKey) -> Option> { - if let Some(key) = self.index.remove(id) { - self.providers.remove(&key) - } else { - None + fn unregister( + &mut self, + worker_instance_key: Uuid, + ) -> Result, anyhow::Error> { + let worker = self + .all_workers + .remove(&worker_instance_key) + .ok_or_else(|| { + anyhow::anyhow!("Worker with worker_instance_key {worker_instance_key} not found") + })?; + + let slot_key = SlotKey::new( + worker.namespace().to_string(), + worker.task_queue().to_string(), + ); + + self.slot_providers.remove(&slot_key); + + if let Some(w) = self.shared_worker.get_mut(worker.namespace()) { + let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key()); + if let Some(cb) = callback { + if is_empty { + self.shared_worker.remove(worker.namespace()); + } + + // To maintain single ownership of the callback, we must re-register the callback + // back to the ClientWorker + worker.register_callback(cb); + } } + + Ok(worker) + } + + #[cfg(test)] + fn num_providers(&self) -> usize { + self.slot_providers.len() } #[cfg(test)] - fn num_providers(&self) -> (usize, usize) { - (self.index.len(), self.providers.len()) + fn num_heartbeat_workers(&self) -> usize { + self.shared_worker.values().map(|v| v.num_workers()).sum() } } +/// This trait represents a shared namespace worker that sends worker heartbeats and +/// receives worker commands. +pub trait SharedNamespaceWorkerTrait { + /// Namespace that the shared namespace worker is connected to. + fn namespace(&self) -> String; + + /// Registers a heartbeat callback. + fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatCallback); + + /// Unregisters a heartbeat callback. Returns the callback removed, as well as a bool that + /// indicates if there are no remaining callbacks in the SharedNamespaceWorker, indicating + /// the shared worker itself can be shut down. + fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option, bool); + + /// Returns the number of workers registered to this shared worker. + fn num_workers(&self) -> usize; +} + /// Enables local workers to make themselves visible to a shared client instance. -/// There can only be one worker registered per namespace+queue_name+client, others will get ignored. +/// +/// For slot managing, there can only be one worker registered per +/// namespace+queue_name+client, others will return an error. /// It also provides a convenient method to find compatible slots within the collection. -#[derive(Default, Debug)] -pub struct SlotManager { - manager: RwLock, +pub struct ClientWorkerSet { + worker_grouping_key: Uuid, + worker_manager: RwLock, } -impl SlotManager { +impl Default for ClientWorkerSet { + fn default() -> Self { + Self::new() + } +} + +impl ClientWorkerSet { /// Factory method. pub fn new() -> Self { Self { - manager: RwLock::new(SlotManagerImpl::new()), + worker_grouping_key: Uuid::new_v4(), + worker_manager: RwLock::new(ClientWorkerSetImpl::new()), } } @@ -131,29 +207,93 @@ impl SlotManager { namespace: String, task_queue: String, ) -> Option> { - self.manager + self.worker_manager .read() .try_reserve_wft_slot(namespace, task_queue) } - /// Register a local worker that can provide WFT processing slots. - pub fn register(&self, provider: Box) -> Option { - self.manager.write().register(provider) + /// Unregisters a local worker, typically when that worker starts shutdown. + pub fn unregister_worker( + &self, + worker_instance_key: Uuid, + ) -> Result, anyhow::Error> { + self.worker_manager.write().unregister(worker_instance_key) + } + + /// Register a local worker that can provide WFT processing slots and potentially worker heartbeating. + pub fn register_worker( + &self, + worker: Arc, + ) -> Result<(), anyhow::Error> { + self.worker_manager.write().register(worker) } - /// Unregister a provider, typically when its worker starts shutdown. - pub fn unregister(&self, id: WorkerKey) -> Option> { - self.manager.write().unregister(id) + /// Returns the worker grouping key, which is unique for each worker. + pub fn worker_grouping_key(&self) -> Uuid { + self.worker_grouping_key } #[cfg(test)] /// Returns (num_providers, num_buckets), where a bucket key is namespace+task_queue. /// There is only one provider per bucket so `num_providers` should be equal to `num_buckets`. - pub fn num_providers(&self) -> (usize, usize) { - self.manager.read().num_providers() + pub fn num_providers(&self) -> usize { + self.worker_manager.read().num_providers() + } + + #[cfg(test)] + /// Returns the total number of heartbeat workers registered across all namespaces. + pub fn num_heartbeat_workers(&self) -> usize { + self.worker_manager.read().num_heartbeat_workers() } } +impl std::fmt::Debug for ClientWorkerSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClientWorkerSet") + .field("worker_grouping_key", &self.worker_grouping_key) + .finish() + } +} + +/// Contains a worker heartbeat callback, wrapped for mocking +pub type HeartbeatCallback = Box WorkerHeartbeat + Send + Sync>; + +/// Represents a complete worker that can handle both slot management +/// and worker heartbeat functionality. +#[cfg_attr(test, mockall::automock)] +pub trait ClientWorker: Send + Sync { + /// The namespace this worker operates in + fn namespace(&self) -> &str; + + /// The task queue this worker listens to + fn task_queue(&self) -> &str; + + /// Try to reserve a slot for workflow task processing. + /// + /// This method should return `Some(slot)` if a workflow task slot is available, + /// or `None` if all slots are currently in use. The returned slot will be used + /// to process exactly one workflow task. + fn try_reserve_wft_slot(&self) -> Option>; + + /// Unique identifier for this worker instance. + /// This must be stable across the worker's lifetime but unique per instance. + fn worker_instance_key(&self) -> Uuid; + + /// Indicates if worker heartbeating is enabled for this client worker. + fn heartbeat_enabled(&self) -> bool; + + /// Returns the heartbeat callback that can be used to get WorkerHeartbeat data. + fn heartbeat_callback(&self) -> Option; + + /// Creates a new worker that implements the [SharedNamespaceWorkerTrait] + fn new_shared_namespace_worker( + &self, + ) -> Result, anyhow::Error>; + + /// Registers a worker heartbeat callback, typically when a worker is unregistered from a client + fn register_callback(&self, callback: HeartbeatCallback); +} + #[cfg(test)] mod tests { use super::*; @@ -175,8 +315,9 @@ mod tests { task_queue: String, with_error: bool, no_slots: bool, - ) -> MockSlotProvider { - let mut mock_provider = MockSlotProvider::new(); + heartbeat_enabled: bool, + ) -> MockClientWorker { + let mut mock_provider = MockClientWorker::new(); mock_provider .expect_try_reserve_wft_slot() .returning(move || { @@ -189,78 +330,315 @@ mod tests { mock_provider.expect_namespace().return_const(namespace); mock_provider.expect_task_queue().return_const(task_queue); mock_provider + .expect_heartbeat_enabled() + .return_const(heartbeat_enabled); + mock_provider + .expect_worker_instance_key() + .return_const(Uuid::new_v4()); + mock_provider } #[test] - fn registry_respects_registration_order() { - let mock_provider1 = - new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false); - let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true); - - let manager = SlotManager::new(); - let some_slots = manager.register(Box::new(mock_provider1)); - let no_slots = manager.register(Box::new(mock_provider2)); - assert!(no_slots.is_none()); - - let mut found = 0; - for _ in 0..10 { - if manager - .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string()) - .is_some() - { - found += 1; + fn registry_keeps_one_provider_per_namespace() { + let manager = ClientWorkerSet::new(); + let mut worker_keys = vec![]; + let mut successful_registrations = 0; + + for i in 0..10 { + let namespace = format!("myId{}", i % 3); + let mock_provider = + new_mock_provider(namespace, "bar_q".to_string(), false, false, false); + let worker_instance_key = mock_provider.worker_instance_key(); + + let result = manager.register_worker(Arc::new(mock_provider)); + if result.is_ok() { + successful_registrations += 1; + worker_keys.push(worker_instance_key); + } else { + // Should get error for duplicate namespace+task_queue combinations + assert!(result.unwrap_err().to_string().contains( + "Registration of multiple workers on the same namespace and task queue" + )); } } - assert_eq!(found, 10); - assert_eq!((1, 1), manager.num_providers()); - - manager.unregister(some_slots.unwrap()); - assert_eq!((0, 0), manager.num_providers()); - - let mock_provider1 = - new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false); - let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true); - - let no_slots = manager.register(Box::new(mock_provider2)); - let some_slots = manager.register(Box::new(mock_provider1)); - assert!(some_slots.is_none()); - - let mut not_found = 0; - for _ in 0..10 { - if manager - .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string()) - .is_none() - { - not_found += 1; + + assert_eq!(successful_registrations, 3); + assert_eq!(3, manager.num_providers()); + + let count = worker_keys.iter().fold(0, |count, key| { + manager.unregister_worker(*key).unwrap(); + // expect error since worker is already unregistered + let result = manager.unregister_worker(*key); + assert!(result.is_err()); + count + 1 + }); + assert_eq!(3, count); + assert_eq!(0, manager.num_providers()); + } + + struct MockSharedNamespaceWorker { + namespace: String, + callbacks: Arc>>, + } + + impl std::fmt::Debug for MockSharedNamespaceWorker { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MockSharedNamespaceWorker") + .field("namespace", &self.namespace) + .field("callbacks_count", &self.callbacks.read().len()) + .finish() + } + } + + impl MockSharedNamespaceWorker { + fn new(namespace: String) -> Self { + Self { + namespace, + callbacks: Arc::new(RwLock::new(HashMap::new())), } } - assert_eq!(not_found, 10); - assert_eq!((1, 1), manager.num_providers()); - manager.unregister(no_slots.unwrap()); - assert_eq!((0, 0), manager.num_providers()); } - #[test] - fn registry_keeps_one_provider_per_namespace() { - let manager = SlotManager::new(); - let mut worker_keys = vec![]; - for i in 0..10 { - let namespace = format!("myId{}", i % 3); - let mock_provider = new_mock_provider(namespace, "bar_q".to_string(), false, false); - worker_keys.push(manager.register(Box::new(mock_provider))); + impl SharedNamespaceWorkerTrait for MockSharedNamespaceWorker { + fn namespace(&self) -> String { + self.namespace.clone() } - assert_eq!((3, 3), manager.num_providers()); - - let count = worker_keys - .iter() - .filter(|key| key.is_some()) - .fold(0, |count, key| { - manager.unregister(key.unwrap()); - // Should be idempotent - manager.unregister(key.unwrap()); - count + 1 - }); - assert_eq!(3, count); - assert_eq!((0, 0), manager.num_providers()); + + fn register_callback( + &self, + worker_instance_key: Uuid, + heartbeat_callback: HeartbeatCallback, + ) { + self.callbacks + .write() + .insert(worker_instance_key, heartbeat_callback); + } + + fn unregister_callback( + &self, + worker_instance_key: Uuid, + ) -> (Option, bool) { + let mut callbacks = self.callbacks.write(); + let callback = callbacks.remove(&worker_instance_key); + let is_empty = callbacks.is_empty(); + (callback, is_empty) + } + + fn num_workers(&self) -> usize { + self.callbacks.read().len() + } + } + + fn new_mock_provider_with_heartbeat( + namespace: String, + task_queue: String, + heartbeat_enabled: bool, + worker_instance_key: Uuid, + ) -> MockClientWorker { + let mut mock_provider = MockClientWorker::new(); + mock_provider + .expect_try_reserve_wft_slot() + .returning(|| Some(new_mock_slot(false))); + mock_provider + .expect_namespace() + .return_const(namespace.clone()); + mock_provider.expect_task_queue().return_const(task_queue); + mock_provider + .expect_heartbeat_enabled() + .return_const(heartbeat_enabled); + mock_provider + .expect_worker_instance_key() + .return_const(worker_instance_key); + + if heartbeat_enabled { + mock_provider + .expect_heartbeat_callback() + .returning(|| Some(Box::new(WorkerHeartbeat::default))); + + let namespace_clone = namespace.clone(); + mock_provider + .expect_new_shared_namespace_worker() + .returning(move || { + Ok(Box::new(MockSharedNamespaceWorker::new( + namespace_clone.clone(), + ))) + }); + + mock_provider.expect_register_callback().returning(|_| {}); + } + + mock_provider + } + + #[test] + fn duplicate_namespace_task_queue_registration_fails() { + let manager = ClientWorkerSet::new(); + + let worker1 = new_mock_provider_with_heartbeat( + "test_namespace".to_string(), + "test_queue".to_string(), + true, + Uuid::new_v4(), + ); + + // Same namespace+task_queue but different worker instance + let worker2 = new_mock_provider_with_heartbeat( + "test_namespace".to_string(), + "test_queue".to_string(), + true, + Uuid::new_v4(), + ); + + manager.register_worker(Arc::new(worker1)).unwrap(); + + // second worker register should fail due to duplicate namespace+task_queue + let result = manager.register_worker(Arc::new(worker2)); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Registration of multiple workers on the same namespace and task queue") + ); + + assert_eq!(1, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 1); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.shared_worker.len(), 1); + assert!(impl_ref.shared_worker.contains_key("test_namespace")); + } + + #[test] + fn multiple_workers_same_namespace_share_heartbeat_manager() { + let manager = ClientWorkerSet::new(); + + let worker1 = new_mock_provider_with_heartbeat( + "shared_namespace".to_string(), + "queue1".to_string(), + true, + Uuid::new_v4(), + ); + + // Same namespace but different task queue + let worker2 = new_mock_provider_with_heartbeat( + "shared_namespace".to_string(), + "queue2".to_string(), + true, + Uuid::new_v4(), + ); + + manager.register_worker(Arc::new(worker1)).unwrap(); + manager.register_worker(Arc::new(worker2)).unwrap(); + + assert_eq!(2, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 2); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.shared_worker.len(), 1); + assert!(impl_ref.shared_worker.contains_key("shared_namespace")); + + let shared_worker = impl_ref.shared_worker.get("shared_namespace").unwrap(); + assert_eq!(shared_worker.namespace(), "shared_namespace"); + } + + #[test] + fn different_namespaces_get_separate_heartbeat_managers() { + let manager = ClientWorkerSet::new(); + let worker1 = new_mock_provider_with_heartbeat( + "namespace1".to_string(), + "queue1".to_string(), + true, + Uuid::new_v4(), + ); + let worker2 = new_mock_provider_with_heartbeat( + "namespace2".to_string(), + "queue1".to_string(), + true, + Uuid::new_v4(), + ); + + manager.register_worker(Arc::new(worker1)).unwrap(); + manager.register_worker(Arc::new(worker2)).unwrap(); + + assert_eq!(2, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 2); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.num_heartbeat_workers(), 2); + assert!(impl_ref.shared_worker.contains_key("namespace1")); + assert!(impl_ref.shared_worker.contains_key("namespace2")); + } + + #[test] + fn unregister_heartbeat_workers_cleans_up_shared_worker_when_last_removed() { + let manager = ClientWorkerSet::new(); + + // Create two workers with same namespace but different task queues + let worker1 = new_mock_provider_with_heartbeat( + "test_namespace".to_string(), + "queue1".to_string(), + true, + Uuid::new_v4(), + ); + let worker2 = new_mock_provider_with_heartbeat( + "test_namespace".to_string(), + "queue2".to_string(), + true, + Uuid::new_v4(), + ); + let worker_instance_key1 = worker1.worker_instance_key(); + let worker_instance_key2 = worker2.worker_instance_key(); + + manager.register_worker(Arc::new(worker1)).unwrap(); + manager.register_worker(Arc::new(worker2)).unwrap(); + + // Verify initial state: 2 slot providers, 2 heartbeat workers, 1 shared worker + assert_eq!(2, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 2); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.shared_worker.len(), 1); + assert!(impl_ref.shared_worker.contains_key("test_namespace")); + assert_eq!( + impl_ref + .shared_worker + .get("test_namespace") + .unwrap() + .num_workers(), + 2 + ); + drop(impl_ref); + + // Unregister first worker + manager.unregister_worker(worker_instance_key1).unwrap(); + + // After unregistering first worker: 1 slot provider, 1 heartbeat worker, shared worker still exists + assert_eq!(1, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 1); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.num_heartbeat_workers(), 1); // SharedNamespaceWorker still exists + assert!(impl_ref.shared_worker.contains_key("test_namespace")); + assert_eq!( + impl_ref + .shared_worker + .get("test_namespace") + .unwrap() + .num_workers(), + 1 + ); + drop(impl_ref); + + // Unregister second worker + manager.unregister_worker(worker_instance_key2).unwrap(); + + // After unregistering last worker: 0 slot providers, 0 heartbeat workers, shared worker is removed + assert_eq!(0, manager.num_providers()); + assert_eq!(manager.num_heartbeat_workers(), 0); + + let impl_ref = manager.worker_manager.read(); + assert_eq!(impl_ref.shared_worker.len(), 0); // SharedNamespaceWorker is cleaned up + assert!(!impl_ref.shared_worker.contains_key("test_namespace")); } } diff --git a/client/src/workflow_handle/mod.rs b/client/src/workflow_handle/mod.rs index 6d284711b..af6d7ab16 100644 --- a/client/src/workflow_handle/mod.rs +++ b/client/src/workflow_handle/mod.rs @@ -200,8 +200,7 @@ where o => Err(anyhow!( "Server returned an event that didn't match the CloseEvent filter. \ This is either a server bug or a new event the SDK does not understand. \ - Event details: {:?}", - o + Event details: {o:?}" )), }; } diff --git a/core-api/src/worker.rs b/core-api/src/worker.rs index b7304c5c9..d92efeec0 100644 --- a/core-api/src/worker.rs +++ b/core-api/src/worker.rs @@ -161,12 +161,6 @@ pub struct WorkerConfig { /// A versioning strategy for this worker. pub versioning_strategy: WorkerVersioningStrategy, - - /// The interval within which the worker will send a heartbeat. - /// The timer is reset on each existing RPC call that also happens to send this data, like - /// `PollWorkflowTaskQueueRequest`. - #[builder(default)] - pub heartbeat_interval: Option, } impl WorkerConfig { diff --git a/core-c-bridge/include/temporal-sdk-core-c-bridge.h b/core-c-bridge/include/temporal-sdk-core-c-bridge.h index 803879f7d..b6ee3e275 100644 --- a/core-c-bridge/include/temporal-sdk-core-c-bridge.h +++ b/core-c-bridge/include/temporal-sdk-core-c-bridge.h @@ -397,6 +397,7 @@ typedef struct TemporalCoreTelemetryOptions { typedef struct TemporalCoreRuntimeOptions { const struct TemporalCoreTelemetryOptions *telemetry; + uint64_t worker_heartbeat_duration_millis; } TemporalCoreRuntimeOptions; typedef struct TemporalCoreTestServerOptions { @@ -868,8 +869,8 @@ void temporal_core_worker_validate(struct TemporalCoreWorker *worker, void *user_data, TemporalCoreWorkerCallback callback); -void temporal_core_worker_replace_client(struct TemporalCoreWorker *worker, - struct TemporalCoreClient *new_client); +const struct TemporalCoreByteArray *temporal_core_worker_replace_client(struct TemporalCoreWorker *worker, + struct TemporalCoreClient *new_client); void temporal_core_worker_poll_workflow_activation(struct TemporalCoreWorker *worker, void *user_data, diff --git a/core-c-bridge/src/client.rs b/core-c-bridge/src/client.rs index fde3548ce..ccdd660fd 100644 --- a/core-c-bridge/src/client.rs +++ b/core-c-bridge/src/client.rs @@ -685,7 +685,7 @@ async fn call_workflow_service( "UpdateWorkerBuildIdCompatibility" => { rpc_call!(client, call, update_worker_build_id_compatibility) } - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -715,7 +715,7 @@ async fn call_operator_service( "UpdateNexusEndpoint" => { rpc_call_on_trait!(client, call, OperatorService, update_nexus_endpoint) } - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -785,7 +785,7 @@ async fn call_cloud_service(client: &CoreClient, call: &RpcCallOptions) -> anyho "GetConnectivityRule" => rpc_call!(client, call, get_connectivity_rule), "GetConnectivityRules" => rpc_call!(client, call, get_connectivity_rules), "DeleteConnectivityRule" => rpc_call!(client, call, delete_connectivity_rule), - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -799,7 +799,7 @@ async fn call_test_service(client: &CoreClient, call: &RpcCallOptions) -> anyhow "Sleep" => rpc_call!(client, call, sleep), "UnlockTimeSkippingWithSleep" => rpc_call!(client, call, unlock_time_skipping_with_sleep), "UnlockTimeSkipping" => rpc_call!(client, call, unlock_time_skipping), - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -814,7 +814,7 @@ async fn call_health_service( "Watch" => Err(anyhow::anyhow!( "Health service Watch method is not implemented in C bridge" )), - rpc => Err(anyhow::anyhow!("Unknown RPC call {}", rpc)), + rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } diff --git a/core-c-bridge/src/runtime.rs b/core-c-bridge/src/runtime.rs index 94fe46929..5a330e268 100644 --- a/core-c-bridge/src/runtime.rs +++ b/core-c-bridge/src/runtime.rs @@ -16,7 +16,8 @@ use std::{ time::{Duration, UNIX_EPOCH}, }; use temporal_sdk_core::{ - CoreRuntime, TokioRuntimeBuilder, + CoreRuntime, RuntimeOptions as CoreRuntimeOptions, + RuntimeOptionsBuilder as CoreRuntimeOptionsBuilder, TokioRuntimeBuilder, telemetry::{build_otlp_metric_exporter, start_prometheus_metric_exporter}, }; use temporal_sdk_core_api::telemetry::{ @@ -30,6 +31,7 @@ use url::Url; #[repr(C)] pub struct RuntimeOptions { pub telemetry: *const TelemetryOptions, + pub worker_heartbeat_duration_millis: u64, } #[repr(C)] @@ -142,7 +144,7 @@ pub extern "C" fn temporal_core_runtime_new(options: *const RuntimeOptions) -> R let mut runtime = Runtime { core: Arc::new( CoreRuntime::new( - CoreTelemetryOptions::default(), + CoreRuntimeOptions::default(), TokioRuntimeBuilder::default(), ) .unwrap(), @@ -238,8 +240,21 @@ impl Runtime { CoreTelemetryOptions::default() }; + let heartbeat_interval = if options.worker_heartbeat_duration_millis == 0 { + None + } else { + Some(Duration::from_millis( + options.worker_heartbeat_duration_millis, + )) + }; + + let core_runtime_options = CoreRuntimeOptionsBuilder::default() + .telemetry_options(telemetry_options) + .heartbeat_interval(heartbeat_interval) + .build()?; + // Build core runtime - let mut core = CoreRuntime::new(telemetry_options, TokioRuntimeBuilder::default())?; + let mut core = CoreRuntime::new(core_runtime_options, TokioRuntimeBuilder::default())?; // We late-bind the metrics after core runtime is created since it needs // the Tokio handle diff --git a/core-c-bridge/src/tests/context.rs b/core-c-bridge/src/tests/context.rs index 5f889d442..0fb7aaa04 100644 --- a/core-c-bridge/src/tests/context.rs +++ b/core-c-bridge/src/tests/context.rs @@ -153,6 +153,7 @@ impl Context { let RuntimeOrFail { runtime, fail } = temporal_core_runtime_new(&RuntimeOptions { telemetry: std::ptr::null(), + worker_heartbeat_duration_millis: 0, }); if let Some(fail) = byte_array_to_string(runtime, fail) { @@ -162,11 +163,7 @@ impl Context { temporal_core_runtime_free(runtime); "" }; - Err(anyhow!( - "Runtime creation failed: {}{}", - runtime_is_null, - fail - )) + Err(anyhow!("Runtime creation failed: {runtime_is_null}{fail}")) } else if runtime.is_null() { Err(anyhow!("Runtime creation failed: runtime is null")) } else { @@ -522,8 +519,7 @@ extern "C" fn ephemeral_server_start_callback( if let Some(fail) = fail { ContextOperationState::CallbackError(anyhow!( - "Ephemeral server start failed: {}", - fail + "Ephemeral server start failed: {fail}" )) } else if server.is_null() { ContextOperationState::CallbackError(anyhow!( @@ -568,8 +564,7 @@ extern "C" fn ephemeral_server_shutdown_callback( let _ = context.complete_operation_catch_unwind(|guard| { if let Some(fail) = byte_array_to_string(guard.runtime, std::mem::take(&mut fail)) { ContextOperationState::CallbackError(anyhow!( - "Ephemeral server shutdown failed: {}", - fail + "Ephemeral server shutdown failed: {fail}" )) } else { ContextOperationState::CallbackOk(None) @@ -591,7 +586,7 @@ extern "C" fn client_connect_callback( if let Some(context) = user_data.context.upgrade() { let _ = context.complete_operation_catch_unwind(|guard| { if let Some(fail) = byte_array_to_string(guard.runtime, std::mem::take(&mut fail)) { - ContextOperationState::CallbackError(anyhow!("Client connect failed: {}", fail)) + ContextOperationState::CallbackError(anyhow!("Client connect failed: {fail}")) } else { guard.client = std::mem::take(&mut client); ContextOperationState::CallbackOk(None) diff --git a/core-c-bridge/src/worker.rs b/core-c-bridge/src/worker.rs index 8eaa13f26..c752fdae0 100644 --- a/core-c-bridge/src/worker.rs +++ b/core-c-bridge/src/worker.rs @@ -522,11 +522,20 @@ pub extern "C" fn temporal_core_worker_validate( pub extern "C" fn temporal_core_worker_replace_client( worker: *mut Worker, new_client: *mut Client, -) { +) -> *const ByteArray { let worker = unsafe { &*worker }; let core_worker = worker.worker.as_ref().expect("missing worker").clone(); let client = unsafe { &*new_client }; - core_worker.replace_client(client.core.get_client().clone()); + + match core_worker.replace_client(client.core.get_client().clone()) { + Ok(()) => std::ptr::null(), + Err(err) => worker + .runtime + .clone() + .alloc_utf8(&format!("Replace client failed: {err}")) + .into_raw() + .cast_const(), + } } /// If success or fail are present, they must be freed. They will both be null diff --git a/core/src/core_tests/activity_tasks.rs b/core/src/core_tests/activity_tasks.rs index 6a3acdaaf..b508a10e0 100644 --- a/core/src/core_tests/activity_tasks.rs +++ b/core/src/core_tests/activity_tasks.rs @@ -1,7 +1,7 @@ use crate::{ ActivityHeartbeat, Worker, advance_fut, job_assert, prost_dur, test_help::{ - MockPollCfg, MockWorkerInputs, MocksHolder, QueueResponse, TEST_Q, WorkerExt, + MockPollCfg, MockWorkerInputs, MocksHolder, QueueResponse, WorkerExt, WorkflowCachingPolicy, build_fake_worker, build_mock_pollers, fanout_tasks, gen_assert_and_reply, mock_manual_poller, mock_poller, mock_worker, poll_and_reply, single_hist_mock_sg, test_worker_cfg, @@ -734,7 +734,7 @@ async fn no_eager_activities_requested_when_worker_options_disable_it( ScheduleActivity { seq: 1, activity_id: "act_id".to_string(), - task_queue: TEST_Q.to_string(), + task_queue: core.get_config().task_queue.clone(), cancellation_type: ActivityCancellationType::TryCancel as i32, ..Default::default() } @@ -821,6 +821,7 @@ async fn activity_tasks_from_completion_are_delivered() { let mut mock = build_mock_pollers(mh); mock.worker_cfg(|wc| wc.max_cached_workflows = 2); let core = mock_worker(mock); + let task_queue = core.get_config().task_queue.clone(); // Test start let wf_task = core.poll_workflow_activation().await.unwrap(); @@ -829,7 +830,7 @@ async fn activity_tasks_from_completion_are_delivered() { ScheduleActivity { seq, activity_id: format!("act_id_{seq}_same_queue"), - task_queue: TEST_Q.to_string(), + task_queue: task_queue.clone(), cancellation_type: ActivityCancellationType::TryCancel as i32, ..Default::default() } @@ -840,7 +841,7 @@ async fn activity_tasks_from_completion_are_delivered() { ScheduleActivity { seq: 4, activity_id: "act_id_same_queue_not_eager".to_string(), - task_queue: TEST_Q.to_string(), + task_queue: task_queue.clone(), cancellation_type: ActivityCancellationType::TryCancel as i32, ..Default::default() } diff --git a/core/src/core_tests/workers.rs b/core/src/core_tests/workers.rs index b29fbe4b5..f5288a442 100644 --- a/core/src/core_tests/workers.rs +++ b/core/src/core_tests/workers.rs @@ -315,7 +315,7 @@ async fn worker_shutdown_api(#[case] use_cache: bool, #[case] api_success: bool) mock.expect_is_mock().returning(|| true); mock.expect_sdk_name_and_version() .returning(|| ("test-core".to_string(), "0.0.0".to_string())); - mock.expect_get_identity() + mock.expect_identity() .returning(|| "test-identity".to_string()); if use_cache { if api_success { diff --git a/core/src/core_tests/workflow_tasks.rs b/core/src/core_tests/workflow_tasks.rs index e26b6c887..e5550938b 100644 --- a/core/src/core_tests/workflow_tasks.rs +++ b/core/src/core_tests/workflow_tasks.rs @@ -2996,7 +2996,9 @@ async fn both_normal_and_sticky_pollers_poll_concurrently() { Arc::new(mock_client), None, None, - ); + false, + ) + .unwrap(); for _ in 1..50 { let activation = worker.poll_workflow_activation().await.unwrap(); diff --git a/core/src/ephemeral_server/mod.rs b/core/src/ephemeral_server/mod.rs index f8f0ee57b..e64aa6482 100644 --- a/core/src/ephemeral_server/mod.rs +++ b/core/src/ephemeral_server/mod.rs @@ -241,8 +241,7 @@ impl EphemeralServer { } } Err(anyhow!( - "Failed connecting to test server after 5 seconds, last error: {:?}", - last_error + "Failed connecting to test server after 5 seconds, last error: {last_error:?}" )) } @@ -368,7 +367,7 @@ impl EphemeralExe { let arch = match std::env::consts::ARCH { "x86_64" => "amd64", "arm" | "aarch64" => "arm64", - other => return Err(anyhow!("Unsupported arch: {}", other)), + other => return Err(anyhow!("Unsupported arch: {other}")), }; let mut get_info_params = vec![("arch", arch), ("platform", platform)]; if let Some(format) = preferred_format { diff --git a/core/src/lib.rs b/core/src/lib.rs index db8f80da4..3995b7562 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -61,7 +61,8 @@ use crate::{ }; use anyhow::bail; use futures_util::Stream; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; +use std::time::Duration; use temporal_client::{ConfiguredClient, NamespacedClient, TemporalServiceClientWithMetrics}; use temporal_sdk_core_api::{ Worker as WorkerTrait, @@ -89,39 +90,41 @@ pub fn init_worker( where CT: Into, { - let client = init_worker_client(&worker_config, *client.into().into_inner()); - if client.namespace() != worker_config.namespace { + let client_inner = *client.into().into_inner(); + let client = init_worker_client( + worker_config.namespace.clone(), + worker_config.client_identity_override.clone(), + client_inner, + ); + let namespace = worker_config.namespace.clone(); + if client.namespace() != namespace { bail!("Passed in client is not bound to the same namespace as the worker"); } if client.namespace() == "" { bail!("Client namespace cannot be empty"); } let client_ident = client.get_identity().to_owned(); - let sticky_q = sticky_q_name_for_worker(&client_ident, &worker_config); + let sticky_q = sticky_q_name_for_worker(&client_ident, worker_config.max_cached_workflows); if client_ident.is_empty() { bail!("Client identity cannot be empty. Either lang or user should be setting this value"); } - let heartbeat_fn = worker_config - .heartbeat_interval - .map(|_| Arc::new(OnceLock::new())); - let client_bag = Arc::new(WorkerClientBag::new( client, - worker_config.namespace.clone(), - client_ident, + namespace.clone(), + client_ident.clone(), worker_config.versioning_strategy.clone(), - heartbeat_fn.clone(), )); - Ok(Worker::new( - worker_config, + Worker::new( + worker_config.clone(), sticky_q, - client_bag, + client_bag.clone(), Some(&runtime.telemetry), - heartbeat_fn, - )) + runtime.heartbeat_interval, + false, + ) } /// Create a worker for replaying one or more existing histories. It will auto-shutdown as soon as @@ -142,11 +145,12 @@ where } pub(crate) fn init_worker_client( - config: &WorkerConfig, + namespace: String, + client_identity_override: Option, client: ConfiguredClient, ) -> RetryClient { - let mut client = Client::new(client, config.namespace.clone()); - if let Some(ref id_override) = config.client_identity_override { + let mut client = Client::new(client, namespace); + if let Some(ref id_override) = client_identity_override { client.options_mut().identity.clone_from(id_override); } RetryClient::new(client, RetryConfig::default()) @@ -156,9 +160,9 @@ pub(crate) fn init_worker_client( /// workflows. pub(crate) fn sticky_q_name_for_worker( process_identity: &str, - config: &WorkerConfig, + max_cached_workflows: usize, ) -> Option { - if config.max_cached_workflows > 0 { + if max_cached_workflows > 0 { Some(format!( "{}-{}", &process_identity, @@ -220,6 +224,21 @@ pub struct CoreRuntime { telemetry: TelemetryInstance, runtime: Option, runtime_handle: tokio::runtime::Handle, + heartbeat_interval: Option, +} + +/// Holds telemetry options, as well as worker heartbeat_interval. Construct with [RuntimeOptionsBuilder] +#[derive(derive_builder::Builder)] +#[non_exhaustive] +#[derive(Default)] +pub struct RuntimeOptions { + /// Telemetry configuration options. + #[builder(default)] + telemetry_options: TelemetryOptions, + /// Optional worker heartbeat interval - This configures the heartbeat setting of all + /// workers created using this runtime. + #[builder(default = "Some(Duration::from_secs(60))")] + heartbeat_interval: Option, } /// Wraps a [tokio::runtime::Builder] to allow layering multiple on_thread_start functions @@ -254,13 +273,13 @@ impl CoreRuntime { /// If a tokio runtime has already been initialized. To re-use an existing runtime, call /// [CoreRuntime::new_assume_tokio]. pub fn new( - telemetry_options: TelemetryOptions, + runtime_options: RuntimeOptions, mut tokio_builder: TokioRuntimeBuilder, ) -> Result where F: Fn() + Send + Sync + 'static, { - let telemetry = telemetry_init(telemetry_options)?; + let telemetry = telemetry_init(runtime_options.telemetry_options)?; let subscriber = telemetry.trace_subscriber(); let runtime = tokio_builder .inner @@ -275,7 +294,8 @@ impl CoreRuntime { }) .build()?; let _rg = runtime.enter(); - let mut me = Self::new_assume_tokio_initialized_telem(telemetry); + let mut me = + Self::new_assume_tokio_initialized_telem(telemetry, runtime_options.heartbeat_interval); me.runtime = Some(runtime); Ok(me) } @@ -285,9 +305,12 @@ impl CoreRuntime { /// /// # Panics /// If there is no currently active Tokio runtime - pub fn new_assume_tokio(telemetry_options: TelemetryOptions) -> Result { - let telemetry = telemetry_init(telemetry_options)?; - Ok(Self::new_assume_tokio_initialized_telem(telemetry)) + pub fn new_assume_tokio(runtime_options: RuntimeOptions) -> Result { + let telemetry = telemetry_init(runtime_options.telemetry_options)?; + Ok(Self::new_assume_tokio_initialized_telem( + telemetry, + runtime_options.heartbeat_interval, + )) } /// Construct a runtime from an already-initialized telemetry instance, assuming a tokio runtime @@ -295,7 +318,10 @@ impl CoreRuntime { /// /// # Panics /// If there is no currently active Tokio runtime - pub fn new_assume_tokio_initialized_telem(telemetry: TelemetryInstance) -> Self { + pub fn new_assume_tokio_initialized_telem( + telemetry: TelemetryInstance, + heartbeat_interval: Option, + ) -> Self { let runtime_handle = tokio::runtime::Handle::current(); if let Some(sub) = telemetry.trace_subscriber() { set_trace_subscriber_for_current_thread(sub); @@ -304,6 +330,7 @@ impl CoreRuntime { telemetry, runtime: None, runtime_handle, + heartbeat_interval, } } diff --git a/core/src/pollers/poll_buffer.rs b/core/src/pollers/poll_buffer.rs index 72f2e8a41..7bc4311fb 100644 --- a/core/src/pollers/poll_buffer.rs +++ b/core/src/pollers/poll_buffer.rs @@ -203,6 +203,7 @@ impl LongPollBuffer { permit_dealer: MeteredPermitDealer, shutdown: CancellationToken, num_pollers_handler: Option, + send_heartbeat: bool, ) -> Self { let no_retry = if matches!(poller_behavior, PollerBehavior::Autoscaling { .. }) { Some(NoRetryOnMatching { @@ -216,11 +217,14 @@ impl LongPollBuffer { let task_queue = task_queue.clone(); async move { client - .poll_nexus_task(PollOptions { - task_queue, - no_retry, - timeout_override, - }) + .poll_nexus_task( + PollOptions { + task_queue, + no_retry, + timeout_override, + }, + send_heartbeat, + ) .await } }; diff --git a/core/src/protosext/mod.rs b/core/src/protosext/mod.rs index fe0d60686..1a1119033 100644 --- a/core/src/protosext/mod.rs +++ b/core/src/protosext/mod.rs @@ -135,7 +135,7 @@ impl TryFrom for ValidPollWFTQResponse { _cant_construct_me: (), }) } - _ => Err(anyhow!("Unable to interpret poll response: {:?}", value)), + _ => Err(anyhow!("Unable to interpret poll response: {value:?}",)), } } } diff --git a/core/src/protosext/protocol_messages.rs b/core/src/protosext/protocol_messages.rs index 3962edec5..af2aab47d 100644 --- a/core/src/protosext/protocol_messages.rs +++ b/core/src/protosext/protocol_messages.rs @@ -116,7 +116,7 @@ impl TryFrom> for IncomingProtocolMessageBody { v.unpack_as(update::v1::Request::default())?.try_into()?, ) } - o => bail!("Could not understand protocol message type {}", o), + o => bail!("Could not understand protocol message type {o}"), }) } } diff --git a/core/src/replay/mod.rs b/core/src/replay/mod.rs index 650070b20..1e4990000 100644 --- a/core/src/replay/mod.rs +++ b/core/src/replay/mod.rs @@ -114,7 +114,7 @@ where hist_allow_tx.send("Failed".to_string()).unwrap(); async move { Ok(RespondWorkflowTaskFailedResponse::default()) }.boxed() }); - let mut worker = Worker::new(self.config, None, Arc::new(client), None, None); + let mut worker = Worker::new(self.config, None, Arc::new(client), None, None, false)?; worker.set_post_activate_hook(post_activate); shutdown_tok(worker.shutdown_token()); Ok(worker) diff --git a/core/src/telemetry/metrics.rs b/core/src/telemetry/metrics.rs index 28bf18845..fe6aec8e1 100644 --- a/core/src/telemetry/metrics.rs +++ b/core/src/telemetry/metrics.rs @@ -1,4 +1,6 @@ -use crate::{abstractions::dbg_panic, telemetry::TelemetryInstance}; +#[cfg(test)] +use crate::TelemetryInstance; +use crate::abstractions::dbg_panic; use std::{ fmt::{Debug, Display}, @@ -11,7 +13,7 @@ use temporal_sdk_core_api::telemetry::metrics::{ GaugeF64, GaugeF64Base, Histogram, HistogramBase, HistogramDuration, HistogramDurationBase, HistogramF64, HistogramF64Base, LazyBufferInstrument, MetricAttributable, MetricAttributes, MetricCallBufferer, MetricEvent, MetricKeyValue, MetricKind, MetricParameters, MetricUpdateVal, - NewAttributes, NoOpCoreMeter, + NewAttributes, NoOpCoreMeter, TemporalMeter, }; use temporal_sdk_core_protos::temporal::api::{ enums::v1::WorkflowTaskFailedCause, failure::v1::Failure, @@ -76,8 +78,17 @@ impl MetricsContext { } } + #[cfg(test)] pub(crate) fn top_level(namespace: String, tq: String, telemetry: &TelemetryInstance) -> Self { - if let Some(mut meter) = telemetry.get_temporal_metric_meter() { + MetricsContext::top_level_with_meter(namespace, tq, telemetry.get_temporal_metric_meter()) + } + + pub(crate) fn top_level_with_meter( + namespace: String, + tq: String, + temporal_meter: Option, + ) -> Self { + if let Some(mut meter) = temporal_meter { meter .default_attribs .attributes diff --git a/core/src/telemetry/prometheus_meter.rs b/core/src/telemetry/prometheus_meter.rs index 0b53e8c12..37810a534 100644 --- a/core/src/telemetry/prometheus_meter.rs +++ b/core/src/telemetry/prometheus_meter.rs @@ -315,8 +315,7 @@ where Ok(labels) } else { let e = anyhow!( - "Must use Prometheus attributes with a Prometheus metric implementation. Got: {:?}", - attributes + "Must use Prometheus attributes with a Prometheus metric implementation. Got: {attributes:?}" ); dbg_panic!("{:?}", e); Err(e) diff --git a/core/src/test_help/integ_helpers.rs b/core/src/test_help/integ_helpers.rs index 2c3417dc6..2f3cded78 100644 --- a/core/src/test_help/integ_helpers.rs +++ b/core/src/test_help/integ_helpers.rs @@ -62,13 +62,11 @@ use temporal_sdk_core_protos::{ }; use tokio::sync::{Notify, mpsc::unbounded_channel}; use tokio_stream::wrappers::UnboundedReceiverStream; +use uuid::Uuid; /// Default namespace for testing pub const NAMESPACE: &str = "default"; -/// Default task queue for testing -pub const TEST_Q: &str = "q"; - /// Initiate shutdown, drain the pollers (handling evictions), and wait for shutdown to complete. pub async fn drain_pollers_and_shutdown(worker: &dyn WorkerTrait) { worker.initiate_shutdown(); @@ -102,7 +100,7 @@ pub async fn drain_pollers_and_shutdown(worker: &dyn WorkerTrait) { pub fn test_worker_cfg() -> WorkerConfigBuilder { let mut wcb = WorkerConfigBuilder::default(); wcb.namespace(NAMESPACE) - .task_queue(TEST_Q) + .task_queue(Uuid::new_v4().to_string()) .versioning_strategy(WorkerVersioningStrategy::None { build_id: "test_bin_id".to_string(), }) @@ -185,7 +183,7 @@ pub fn build_fake_worker( } pub fn mock_worker(mocks: MocksHolder) -> Worker { - let sticky_q = sticky_q_name_for_worker("unit-test", &mocks.inputs.config); + let sticky_q = sticky_q_name_for_worker("unit-test", mocks.inputs.config.max_cached_workflows); let act_poller = if mocks.inputs.config.no_remote_activities { None } else { @@ -205,7 +203,9 @@ pub fn mock_worker(mocks: MocksHolder) -> Worker { }, None, None, + false, ) + .unwrap() } pub struct FakeWfResponses { @@ -275,7 +275,7 @@ impl MocksHolder { } } - /// Uses the provided list of tasks to create a mock poller for the `TEST_Q` + /// Uses the provided list of tasks to create a mock poller with a randomly generated task queue pub fn from_client_with_activities( client: impl WorkerClient + 'static, act_tasks: ACT, diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs index 519994fc3..5d773330e 100644 --- a/core/src/worker/client.rs +++ b/core/src/worker/client.rs @@ -1,17 +1,12 @@ //! Worker-specific client needs pub(crate) mod mocks; -use crate::{ - abstractions::dbg_panic, protosext::legacy_query_failure, worker::heartbeat::HeartbeatFn, -}; +use crate::protosext::legacy_query_failure; use parking_lot::RwLock; -use std::{ - sync::{Arc, OnceLock}, - time::Duration, -}; +use std::{sync::Arc, time::Duration}; use temporal_client::{ - Client, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching, RetryClient, - SlotManager, WorkflowService, + Client, ClientWorkerSet, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching, + RetryClient, WorkflowService, }; use temporal_sdk_core_api::worker::WorkerVersioningStrategy; use temporal_sdk_core_protos::{ @@ -38,6 +33,7 @@ use temporal_sdk_core_protos::{ }, }; use tonic::IntoRequest; +use uuid::Uuid; type Result = std::result::Result; @@ -52,7 +48,6 @@ pub(crate) struct WorkerClientBag { namespace: String, identity: String, worker_versioning_strategy: WorkerVersioningStrategy, - heartbeat_data: Option>>, } impl WorkerClientBag { @@ -61,14 +56,12 @@ impl WorkerClientBag { namespace: String, identity: String, worker_versioning_strategy: WorkerVersioningStrategy, - heartbeat_data: Option>>, ) -> Self { Self { replaceable_client: RwLock::new(client), namespace, identity, worker_versioning_strategy, - heartbeat_data, } } @@ -129,19 +122,6 @@ impl WorkerClientBag { None } } - - fn capture_heartbeat(&self) -> Option { - if let Some(heartbeat_data) = self.heartbeat_data.as_ref() { - if let Some(hb) = heartbeat_data.get() { - hb() - } else { - dbg_panic!("Heartbeat function never set"); - None - } - } else { - None - } - } } /// This trait contains everything workers need to interact with Temporal, and hence provides a @@ -165,6 +145,7 @@ pub trait WorkerClient: Sync + Send { async fn poll_nexus_task( &self, poll_options: PollOptions, + send_heartbeat: bool, ) -> Result; /// Complete a workflow task async fn complete_workflow_task( @@ -234,7 +215,8 @@ pub trait WorkerClient: Sync + Send { /// Record a worker heartbeat async fn record_worker_heartbeat( &self, - heartbeat: WorkerHeartbeat, + namespace: String, + worker_heartbeat: Vec, ) -> Result; /// Replace the underlying client @@ -242,13 +224,15 @@ pub trait WorkerClient: Sync + Send { /// Return server capabilities fn capabilities(&self) -> Option; /// Return workers using this client - fn workers(&self) -> Arc; + fn workers(&self) -> Arc; /// Indicates if this is a mock client fn is_mock(&self) -> bool; /// Return name and version of the SDK fn sdk_name_and_version(&self) -> (String, String); /// Get worker identity - fn get_identity(&self) -> String; + fn identity(&self) -> String; + /// Get worker grouping key + fn worker_grouping_key(&self) -> Uuid; } /// Configuration options shared by workflow, activity, and Nexus polling calls @@ -360,6 +344,7 @@ impl WorkerClient for WorkerClientBag { async fn poll_nexus_task( &self, poll_options: PollOptions, + _send_heartbeat: bool, ) -> Result { #[allow(deprecated)] // want to list all fields explicitly let mut request = PollNexusTaskQueueRequest { @@ -372,7 +357,7 @@ impl WorkerClient for WorkerClientBag { identity: self.identity.clone(), worker_version_capabilities: self.worker_version_capabilities(), deployment_options: self.deployment_options(), - worker_heartbeat: self.capture_heartbeat().into_iter().collect(), + worker_heartbeat: Vec::new(), } .into_request(); request.extensions_mut().insert(IsWorkerTaskLongPoll); @@ -661,7 +646,7 @@ impl WorkerClient for WorkerClientBag { identity: self.identity.clone(), sticky_task_queue, reason: "graceful shutdown".to_string(), - worker_heartbeat: self.capture_heartbeat(), + worker_heartbeat: None, }; Ok( @@ -671,32 +656,34 @@ impl WorkerClient for WorkerClientBag { ) } - fn replace_client(&self, new_client: RetryClient) { - let mut replaceable_client = self.replaceable_client.write(); - *replaceable_client = new_client; - } - async fn record_worker_heartbeat( &self, - heartbeat: WorkerHeartbeat, + namespace: String, + worker_heartbeat: Vec, ) -> Result { + let request = RecordWorkerHeartbeatRequest { + namespace, + identity: self.identity.clone(), + worker_heartbeat, + }; Ok(self .cloned_client() - .record_worker_heartbeat(RecordWorkerHeartbeatRequest { - namespace: self.namespace.clone(), - identity: self.identity.clone(), - worker_heartbeat: vec![heartbeat], - }) + .record_worker_heartbeat(request) .await? .into_inner()) } + fn replace_client(&self, new_client: RetryClient) { + let mut replaceable_client = self.replaceable_client.write(); + *replaceable_client = new_client; + } + fn capabilities(&self) -> Option { let client = self.replaceable_client.read(); client.get_client().inner().capabilities().cloned() } - fn workers(&self) -> Arc { + fn workers(&self) -> Arc { let client = self.replaceable_client.read(); client.get_client().inner().workers() } @@ -711,9 +698,16 @@ impl WorkerClient for WorkerClientBag { (opts.client_name.clone(), opts.client_version.clone()) } - fn get_identity(&self) -> String { + fn identity(&self) -> String { self.identity.clone() } + + fn worker_grouping_key(&self) -> Uuid { + self.replaceable_client + .read() + .get_client() + .worker_grouping_key() + } } impl NamespacedClient for WorkerClientBag { diff --git a/core/src/worker/client/mocks.rs b/core/src/worker/client/mocks.rs index f6407f2a4..93984c364 100644 --- a/core/src/worker/client/mocks.rs +++ b/core/src/worker/client/mocks.rs @@ -1,10 +1,10 @@ use super::*; use futures_util::Future; use std::sync::{Arc, LazyLock}; -use temporal_client::SlotManager; +use temporal_client::ClientWorkerSet; -pub(crate) static DEFAULT_WORKERS_REGISTRY: LazyLock> = - LazyLock::new(|| Arc::new(SlotManager::new())); +pub(crate) static DEFAULT_WORKERS_REGISTRY: LazyLock> = + LazyLock::new(|| Arc::new(ClientWorkerSet::new())); pub(crate) static DEFAULT_TEST_CAPABILITIES: &Capabilities = &Capabilities { signal_and_query_header: true, @@ -33,8 +33,9 @@ pub fn mock_worker_client() -> MockWorkerClient { .returning(|_| Ok(ShutdownWorkerResponse {})); r.expect_sdk_name_and_version() .returning(|| ("test-core".to_string(), "0.0.0".to_string())); - r.expect_get_identity() + r.expect_identity() .returning(|| "test-identity".to_string()); + r.expect_worker_grouping_key().returning(Uuid::new_v4); r } @@ -48,7 +49,7 @@ pub(crate) fn mock_manual_worker_client() -> MockManualWorkerClient { r.expect_is_mock().returning(|| true); r.expect_sdk_name_and_version() .returning(|| ("test-core".to_string(), "0.0.0".to_string())); - r.expect_get_identity() + r.expect_identity() .returning(|| "test-identity".to_string()); r } @@ -68,7 +69,7 @@ mockall::mock! { -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; - fn poll_nexus_task<'a, 'b>(&self, poll_options: PollOptions) + fn poll_nexus_task<'a, 'b>(&self, poll_options: PollOptions, send_heartbeat: bool) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; @@ -139,7 +140,7 @@ mockall::mock! { fn respond_legacy_query<'a, 'b>( &self, task_token: TaskToken, - query_result: LegacyQueryResult, + query_result: LegacyQueryResult, ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; @@ -150,13 +151,18 @@ mockall::mock! { fn shutdown_worker<'a, 'b>(&self, sticky_task_queue: String) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; - fn record_worker_heartbeat<'a, 'b>(&self, heartbeat: WorkerHeartbeat) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; + fn record_worker_heartbeat<'a, 'b>( + &self, + namespace: String, + heartbeat: Vec + ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; fn replace_client(&self, new_client: RetryClient); fn capabilities(&self) -> Option; - fn workers(&self) -> Arc; + fn workers(&self) -> Arc; fn is_mock(&self) -> bool; fn sdk_name_and_version(&self) -> (String, String); - fn get_identity(&self) -> String; + fn identity(&self) -> String; + fn worker_grouping_key(&self) -> Uuid; } } diff --git a/core/src/worker/heartbeat.rs b/core/src/worker/heartbeat.rs index f7c8d5694..88774647f 100644 --- a/core/src/worker/heartbeat.rs +++ b/core/src/worker/heartbeat.rs @@ -1,55 +1,113 @@ -use crate::{WorkerClient, abstractions::dbg_panic}; -use gethostname::gethostname; +use crate::WorkerClient; +use crate::worker::{TaskPollers, WorkerTelemetry}; use parking_lot::Mutex; use prost_types::Duration as PbDuration; +use std::collections::HashMap; use std::{ - sync::{Arc, OnceLock}, + sync::Arc, time::{Duration, SystemTime}, }; -use temporal_sdk_core_api::worker::WorkerConfig; -use temporal_sdk_core_protos::temporal::api::worker::v1::{WorkerHeartbeat, WorkerHostInfo}; -use tokio::{sync::Notify, task::JoinHandle, time::MissedTickBehavior}; +use temporal_client::SharedNamespaceWorkerTrait; +use temporal_sdk_core_api::worker::{ + PollerBehavior, WorkerConfigBuilder, WorkerVersioningStrategy, +}; +use temporal_sdk_core_protos::temporal::api::worker::v1::WorkerHeartbeat; +use tokio::sync::Notify; +use tokio_util::sync::CancellationToken; use uuid::Uuid; -pub(crate) type HeartbeatFn = Box Option + Send + Sync>; +/// Callback used to collect heartbeat data from each worker at the time of heartbeat +pub(crate) type HeartbeatFn = Box WorkerHeartbeat + Send + Sync>; -pub(crate) struct WorkerHeartbeatManager { - heartbeat_handle: JoinHandle<()>, +/// SharedNamespaceWorker is responsible for polling nexus-delivered worker commands and sending +/// worker heartbeats to the server. This invokes callbacks on all workers in the same process that +/// share the same namespace. +pub(crate) struct SharedNamespaceWorker { + heartbeat_map: Arc>>, + namespace: String, + cancel: CancellationToken, } -impl WorkerHeartbeatManager { +impl SharedNamespaceWorker { pub(crate) fn new( - config: WorkerConfig, - identity: String, - heartbeat_fn: Arc>, client: Arc, - ) -> Self { - let sdk_name_and_ver = client.sdk_name_and_version(); - let reset_notify = Arc::new(Notify::new()); - let data = Arc::new(Mutex::new(WorkerHeartbeatData::new( + namespace: String, + heartbeat_interval: Duration, + telemetry: Option, + ) -> Result { + let config = WorkerConfigBuilder::default() + .namespace(namespace.clone()) + .task_queue(format!( + "temporal-sys/worker-commands/{namespace}/{}", + client.worker_grouping_key(), + )) + .no_remote_activities(true) + .max_outstanding_nexus_tasks(5_usize) + .versioning_strategy(WorkerVersioningStrategy::None { + build_id: "1.0".to_owned(), + }) + .nexus_task_poller_behavior(PollerBehavior::SimpleMaximum(1_usize)) + .build() + .expect("all required fields should be implemented"); + let worker = crate::worker::Worker::new_with_pollers_inner( config, - identity, - sdk_name_and_ver, - reset_notify.clone(), - ))); - let data_clone = data.clone(); - - let heartbeat_handle = tokio::spawn(async move { - let mut ticker = tokio::time::interval(data_clone.lock().heartbeat_interval); - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + None, + client.clone(), + TaskPollers::Real, + telemetry, + None, + true, + )?; + + let last_heartbeat_time_map = Mutex::new(HashMap::new()); + + let reset_notify = Arc::new(Notify::new()); + let cancel = CancellationToken::new(); + let cancel_clone = cancel.clone(); + + let client_clone = client; + let namespace_clone = namespace.clone(); + + let heartbeat_map = Arc::new(Mutex::new(HashMap::::new())); + let heartbeat_map_clone = heartbeat_map.clone(); + + tokio::spawn(async move { + let mut ticker = tokio::time::interval(heartbeat_interval); loop { tokio::select! { _ = ticker.tick() => { - let heartbeat = if let Some(heartbeat) = data_clone.lock().capture_heartbeat_if_needed() { - heartbeat - } else { - continue - }; - if let Err(e) = client.clone().record_worker_heartbeat(heartbeat).await { - if matches!( - e.code(), - tonic::Code::Unimplemented - ) { + let mut hb_to_send = Vec::new(); + for (instance_key, heartbeat_callback) in heartbeat_map_clone.lock().iter() { + let mut heartbeat = heartbeat_callback(); + let mut last_heartbeat_time_map = last_heartbeat_time_map.lock(); + let now = SystemTime::now(); + let elapsed_since_last_heartbeat = last_heartbeat_time_map.get(instance_key).cloned().map( + |hb_time| { + let dur = now.duration_since(hb_time).unwrap_or(Duration::ZERO); + PbDuration { + seconds: dur.as_secs() as i64, + nanos: dur.subsec_nanos() as i32, + } + } + ); + + heartbeat.elapsed_since_last_heartbeat = elapsed_since_last_heartbeat; + heartbeat.heartbeat_time = Some(now.into()); + + // All of these heartbeat details rely on a client. To avoid circular + // dependencies, this must be populated from within SharedNamespaceWorker + // to get info from the current client + heartbeat.worker_identity = client_clone.identity(); + let sdk_name_and_ver = client_clone.sdk_name_and_version(); + heartbeat.sdk_name = sdk_name_and_ver.0; + heartbeat.sdk_version = sdk_name_and_ver.1; + + hb_to_send.push(heartbeat); + + last_heartbeat_time_map.insert(*instance_key, now); + } + if let Err(e) = client_clone.record_worker_heartbeat(namespace_clone.clone(), hb_to_send).await { + if matches!(e.code(), tonic::Code::Unimplemented) { return; } warn!(error=?e, "Network error while sending worker heartbeat"); @@ -58,131 +116,83 @@ impl WorkerHeartbeatManager { _ = reset_notify.notified() => { ticker.reset(); } + _ = cancel_clone.cancelled() => { + worker.shutdown().await; + return; + } } } }); - let data_clone = data.clone(); - if heartbeat_fn - .set(Box::new(move || { - data_clone.lock().capture_heartbeat_if_needed() - })) - .is_err() - { - dbg_panic!( - "Failed to set heartbeat_fn, heartbeat_fn should only be set once, when a singular WorkerHeartbeatInfo is created" - ); - } - - Self { heartbeat_handle } - } - - pub(crate) fn shutdown(&self) { - self.heartbeat_handle.abort() + Ok(Self { + heartbeat_map, + namespace, + cancel, + }) } } -#[derive(Debug, Clone)] -struct WorkerHeartbeatData { - worker_instance_key: String, - worker_identity: String, - host_info: WorkerHostInfo, - // Time of the last heartbeat. This is used to both for heartbeat_time and last_heartbeat_time - heartbeat_time: Option, - task_queue: String, - /// SDK name - sdk_name: String, - /// SDK version - sdk_version: String, - /// Worker start time - start_time: SystemTime, - heartbeat_interval: Duration, - reset_notify: Arc, -} +impl SharedNamespaceWorkerTrait for SharedNamespaceWorker { + fn namespace(&self) -> String { + self.namespace.clone() + } -impl WorkerHeartbeatData { - fn new( - worker_config: WorkerConfig, - worker_identity: String, - sdk_name_and_ver: (String, String), - reset_notify: Arc, - ) -> Self { - Self { - worker_identity, - host_info: WorkerHostInfo { - host_name: gethostname().to_string_lossy().to_string(), - process_id: std::process::id().to_string(), - ..Default::default() - }, - sdk_name: sdk_name_and_ver.0, - sdk_version: sdk_name_and_ver.1, - task_queue: worker_config.task_queue.clone(), - start_time: SystemTime::now(), - heartbeat_time: None, - worker_instance_key: Uuid::new_v4().to_string(), - heartbeat_interval: worker_config - .heartbeat_interval - .expect("WorkerHeartbeatData is only called when heartbeat_interval is Some"), - reset_notify, + fn register_callback( + &self, + worker_instance_key: Uuid, + heartbeat_callback: Box WorkerHeartbeat + Send + Sync>, + ) { + self.heartbeat_map + .lock() + .insert(worker_instance_key, heartbeat_callback); + } + fn unregister_callback( + &self, + worker_instance_key: Uuid, + ) -> (Option WorkerHeartbeat + Send + Sync>>, bool) { + let mut heartbeat_map = self.heartbeat_map.lock(); + let heartbeat_callback = heartbeat_map.remove(&worker_instance_key); + if heartbeat_map.is_empty() { + self.cancel.cancel(); } + (heartbeat_callback, heartbeat_map.is_empty()) } - fn capture_heartbeat_if_needed(&mut self) -> Option { - let now = SystemTime::now(); - let elapsed_since_last_heartbeat = if let Some(heartbeat_time) = self.heartbeat_time { - let dur = now.duration_since(heartbeat_time).unwrap_or(Duration::ZERO); - - // Only send poll data if it's nearly been a full interval since this data has been sent - // In this case, "nearly" is 90% of the interval - if dur.as_secs_f64() < 0.9 * self.heartbeat_interval.as_secs_f64() { - return None; - } - Some(PbDuration { - seconds: dur.as_secs() as i64, - nanos: dur.subsec_nanos() as i32, - }) - } else { - None - }; - - self.heartbeat_time = Some(now); - - self.reset_notify.notify_one(); - - Some(WorkerHeartbeat { - worker_instance_key: self.worker_instance_key.clone(), - worker_identity: self.worker_identity.clone(), - host_info: Some(self.host_info.clone()), - task_queue: self.task_queue.clone(), - sdk_name: self.sdk_name.clone(), - sdk_version: self.sdk_version.clone(), - status: 0, - start_time: Some(self.start_time.into()), - heartbeat_time: Some(SystemTime::now().into()), - elapsed_since_last_heartbeat, - ..Default::default() - }) + fn num_workers(&self) -> usize { + self.heartbeat_map.lock().len() } } #[cfg(test)] mod tests { - use super::*; use crate::{ test_help::{WorkerExt, test_worker_cfg}, worker, worker::client::mocks::mock_worker_client, }; - use std::{sync::Arc, time::Duration}; + use std::{ + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, + }; use temporal_sdk_core_api::worker::PollerBehavior; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::RecordWorkerHeartbeatResponse; #[tokio::test] - async fn worker_heartbeat() { + async fn worker_heartbeat_basic() { let mut mock = mock_worker_client(); - mock.expect_record_worker_heartbeat() - .times(2) - .returning(move |heartbeat| { + let heartbeat_count = Arc::new(AtomicUsize::new(0)); + let heartbeat_count_clone = heartbeat_count.clone(); + mock.expect_poll_workflow_task() + .returning(move |_namespace, _task_queue| Ok(Default::default())); + mock.expect_poll_nexus_task() + .returning(move |_poll_options, _send_heartbeat| Ok(Default::default())); + mock.expect_record_worker_heartbeat().times(3).returning( + move |_namespace, worker_heartbeat| { + assert_eq!(1, worker_heartbeat.len()); + let heartbeat = worker_heartbeat[0].clone(); let host_info = heartbeat.host_info.clone().unwrap(); assert_eq!("test-identity", heartbeat.worker_identity); assert!(!heartbeat.worker_instance_key.is_empty()); @@ -193,38 +203,35 @@ mod tests { assert_eq!(host_info.process_id, std::process::id().to_string()); assert_eq!(heartbeat.sdk_name, "test-core"); assert_eq!(heartbeat.sdk_version, "0.0.0"); - assert_eq!(heartbeat.task_queue, "q"); assert!(heartbeat.heartbeat_time.is_some()); assert!(heartbeat.start_time.is_some()); + heartbeat_count_clone.fetch_add(1, Ordering::Relaxed); + Ok(RecordWorkerHeartbeatResponse {}) - }); + }, + ); let config = test_worker_cfg() .activity_task_poller_behavior(PollerBehavior::SimpleMaximum(1_usize)) .max_outstanding_activities(1_usize) - .heartbeat_interval(Duration::from_millis(200)) .build() .unwrap(); - let heartbeat_fn = Arc::new(OnceLock::new()); let client = Arc::new(mock); - let worker = worker::Worker::new(config, None, client, None, Some(heartbeat_fn.clone())); - heartbeat_fn.get().unwrap()(); - - // heartbeat timer fires once - advance_time(Duration::from_millis(300)).await; - // it hasn't been >90% of the interval since the last heartbeat, so no data should be returned here - assert_eq!(None, heartbeat_fn.get().unwrap()()); - // heartbeat timer fires once - advance_time(Duration::from_millis(300)).await; - + let worker = worker::Worker::new( + config, + None, + client.clone(), + None, + Some(Duration::from_millis(100)), + false, + ) + .unwrap(); + + tokio::time::sleep(Duration::from_millis(250)).await; worker.drain_activity_poller_and_shutdown().await; - } - async fn advance_time(dur: Duration) { - tokio::time::pause(); - tokio::time::advance(dur).await; - tokio::time::resume(); + assert_eq!(3, heartbeat_count.load(Ordering::Relaxed)); } } diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index 5e57a26e8..5428c067a 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -20,11 +20,12 @@ pub(crate) use activities::{ pub(crate) use wft_poller::WFTPollerShared; pub use workflow::LEGACY_QUERY_ID; +use crate::worker::heartbeat::{HeartbeatFn, SharedNamespaceWorker}; use crate::{ ActivityHeartbeat, CompleteActivityError, PollError, WorkerTrait, abstractions::{MeteredPermitDealer, PermitDealerContextData, dbg_panic}, errors::CompleteWfError, - pollers::{ActivityTaskOptions, BoxedActPoller, BoxedNexusPoller, LongPollBuffer}, + pollers::{BoxedActPoller, BoxedNexusPoller}, protosext::validate_activity_completion, telemetry::{ TelemetryInstance, @@ -36,32 +37,41 @@ use crate::{ worker::{ activities::{LACompleteAction, LocalActivityManager, NextPendingLAAction}, client::WorkerClient, - heartbeat::{HeartbeatFn, WorkerHeartbeatManager}, nexus::NexusManager, workflow::{ - LAReqSink, LocalResolution, WorkflowBasics, Workflows, wft_poller, - wft_poller::make_wft_poller, + LAReqSink, LocalResolution, WorkflowBasics, Workflows, wft_poller::make_wft_poller, }, }, }; +use crate::{ + pollers::{ActivityTaskOptions, LongPollBuffer}, + worker::workflow::wft_poller, +}; use activities::WorkerActivityTasks; +use anyhow::bail; use futures_util::{StreamExt, stream}; -use parking_lot::Mutex; +use gethostname::gethostname; +use parking_lot::{Mutex, RwLock}; use slot_provider::SlotProvider; use std::{ convert::TryInto, future, sync::{ - Arc, OnceLock, + Arc, atomic::{AtomicBool, Ordering}, }, time::Duration, }; -use temporal_client::{ConfiguredClient, TemporalServiceClientWithMetrics, WorkerKey}; +use temporal_client::{ClientWorker, HeartbeatCallback, Slot as SlotTrait}; +use temporal_client::{ + ConfiguredClient, SharedNamespaceWorkerTrait, TemporalServiceClientWithMetrics, +}; +use temporal_sdk_core_api::telemetry::metrics::TemporalMeter; use temporal_sdk_core_api::{ errors::{CompleteNexusError, WorkerValidationError}, worker::PollerBehavior, }; +use temporal_sdk_core_protos::temporal::api::worker::v1::{WorkerHeartbeat, WorkerHostInfo}; use temporal_sdk_core_protos::{ TaskToken, coresdk::{ @@ -80,7 +90,8 @@ use temporal_sdk_core_protos::{ use tokio::sync::{mpsc::unbounded_channel, watch}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; - +use tracing::Subscriber; +use uuid::Uuid; #[cfg(any(feature = "test-utilities", test))] use { crate::{ @@ -97,8 +108,8 @@ use { pub struct Worker { config: WorkerConfig, client: Arc, - /// Registration key to enable eager workflow start for this worker - worker_key: Mutex>, + /// Worker instance key, unique identifier for this worker + worker_instance_key: Uuid, /// Manages all workflows and WFT processing workflows: Workflows, /// Manages activity tasks for this worker/task queue @@ -118,8 +129,8 @@ pub struct Worker { local_activities_complete: Arc, /// Used to track all permits have been released all_permits_tracker: tokio::sync::Mutex, - /// Used to shutdown the worker heartbeat task - worker_heartbeat: Option, + /// Used to track worker client + client_worker_registrator: Arc, } struct AllPermitsTracker { @@ -136,6 +147,13 @@ impl AllPermitsTracker { } } +#[derive(Clone)] +pub(crate) struct WorkerTelemetry { + metric_meter: Option, + temporal_metric_meter: Option, + trace_subscriber: Option>, +} + #[async_trait::async_trait] impl WorkerTrait for Worker { async fn validate(&self) -> Result<(), WorkerValidationError> { @@ -231,10 +249,20 @@ impl WorkerTrait for Worker { ); } self.shutdown_token.cancel(); - // First, disable Eager Workflow Start - if let Some(key) = *self.worker_key.lock() { - self.client.workers().unregister(key); + // First, unregister worker from the client + if let Err(e) = self + .client + .workers() + .unregister_worker(self.worker_instance_key) + { + error!( + task_queue=%self.config.task_queue, + namespace=%self.config.namespace, + error=%e, + "Failed to unregister worker on shutdown", + ); } + // Second, we want to stop polling of both activity and workflow tasks if let Some(atm) = self.at_task_mgr.as_ref() { atm.initiate_shutdown(); @@ -272,8 +300,9 @@ impl Worker { sticky_queue_name: Option, client: Arc, telem_instance: Option<&TelemetryInstance>, - heartbeat_fn: Option>>, - ) -> Self { + worker_heartbeat_interval: Option, + shared_namespace_worker: bool, + ) -> Result { info!(task_queue=%config.task_queue, namespace=%config.namespace, "Initializing worker"); Self::new_with_pollers( @@ -282,25 +311,43 @@ impl Worker { client, TaskPollers::Real, telem_instance, - heartbeat_fn, + worker_heartbeat_interval, + shared_namespace_worker, ) } - /// Replace client and return a new client. For eager workflow purposes, this new client will - /// now apply to future eager start requests and the older client will not. - pub fn replace_client(&self, new_client: ConfiguredClient) { + /// Replace client and return a new client. + /// + /// For eager workflow purposes, this new client will now apply to future eager start requests + /// and the older client will not. Note, if this registration fails, the worker heartbeat will + /// also not be registered. + /// + /// For worker heartbeat, this will remove an existing shared worker if it is the last worker of + /// the old client and create a new nexus worker if it's the first client of the namespace on + /// the new client. + pub fn replace_client( + &self, + new_client: ConfiguredClient, + ) -> Result<(), anyhow::Error> { // Unregister worker from current client, register in new client at the end - let mut worker_key = self.worker_key.lock(); - let slot_provider = (*worker_key).and_then(|k| self.client.workers().unregister(k)); - self.client - .replace_client(super::init_worker_client(&self.config, new_client)); - *worker_key = - slot_provider.and_then(|slot_provider| self.client.workers().register(slot_provider)); + let client_worker = self + .client + .workers() + .unregister_worker(self.worker_instance_key)?; + let new_worker_client = super::init_worker_client( + self.config.namespace.clone(), + self.config.client_identity_override.clone(), + new_client, + ); + + self.client.replace_client(new_worker_client); + *self.client_worker_registrator.client.write() = self.client.clone(); + self.client.workers().register_worker(client_worker) } #[cfg(test)] pub(crate) fn new_test(config: WorkerConfig, client: impl WorkerClient + 'static) -> Self { - Self::new(config, None, Arc::new(client), None, None) + Self::new(config, None, Arc::new(client), None, None, false).unwrap() } pub(crate) fn new_with_pollers( @@ -309,16 +356,48 @@ impl Worker { client: Arc, task_pollers: TaskPollers, telem_instance: Option<&TelemetryInstance>, - heartbeat_fn: Option>>, - ) -> Self { - let (metrics, meter) = if let Some(ti) = telem_instance { + worker_heartbeat_interval: Option, + shared_namespace_worker: bool, + ) -> Result { + let worker_telemetry = telem_instance.map(|telem| WorkerTelemetry { + metric_meter: telem.get_metric_meter(), + temporal_metric_meter: telem.get_temporal_metric_meter(), + trace_subscriber: telem.trace_subscriber(), + }); + + Worker::new_with_pollers_inner( + config, + sticky_queue_name, + client, + task_pollers, + worker_telemetry, + worker_heartbeat_interval, + shared_namespace_worker, + ) + } + + pub(crate) fn new_with_pollers_inner( + config: WorkerConfig, + sticky_queue_name: Option, + client: Arc, + task_pollers: TaskPollers, + worker_telemetry: Option, + worker_heartbeat_interval: Option, + shared_namespace_worker: bool, + ) -> Result { + let (metrics, meter) = if let Some(wt) = worker_telemetry.as_ref() { ( - MetricsContext::top_level(config.namespace.clone(), config.task_queue.clone(), ti), - ti.get_metric_meter(), + MetricsContext::top_level_with_meter( + config.namespace.clone(), + config.task_queue.clone(), + wt.temporal_metric_meter.clone(), + ), + wt.metric_meter.clone(), ) } else { (MetricsContext::no_op(), None) }; + let tuner = config .tuner .as_ref() @@ -329,7 +408,7 @@ impl Worker { let shutdown_token = CancellationToken::new(); let slot_context_data = Arc::new(PermitDealerContextData { task_queue: config.task_queue.clone(), - worker_identity: client.get_identity(), + worker_identity: client.identity(), worker_deployment_version: config.computed_deployment_version(), }); let wft_slots = MeteredPermitDealer::new( @@ -408,6 +487,7 @@ impl Worker { nexus_slots.clone(), shutdown_token.child_token(), Some(move |np| np_metrics.record_num_pollers(np)), + shared_namespace_worker, )) as BoxedNexusPoller; #[cfg(any(feature = "test-utilities", test))] @@ -486,20 +566,33 @@ impl Worker { wft_slots.clone(), external_wft_tx, ); - let worker_key = Mutex::new(client.workers().register(Box::new(provider))); - let sdk_name_and_ver = client.sdk_name_and_version(); + let worker_instance_key = Uuid::new_v4(); - let worker_heartbeat = heartbeat_fn.map(|heartbeat_fn| { + let sdk_name_and_ver = client.sdk_name_and_version(); + let worker_heartbeat = worker_heartbeat_interval.map(|hb_interval| { WorkerHeartbeatManager::new( config.clone(), - client.get_identity(), - heartbeat_fn, - client.clone(), + worker_instance_key, + hb_interval, + worker_telemetry.clone(), ) }); - Self { - worker_key, + let client_worker_registrator = Arc::new(ClientWorkerRegistrator { + worker_instance_key, + slot_provider: provider, + heartbeat_manager: worker_heartbeat, + client: RwLock::new(client.clone()), + }); + + if !shared_namespace_worker { + client + .workers() + .register_worker(client_worker_registrator.clone())?; + } + + Ok(Self { + worker_instance_key, client: client.clone(), workflows: Workflows::new( WorkflowBasics { @@ -538,7 +631,9 @@ impl Worker { _ => Some(mgr.get_handle_for_workflows()), } }), - telem_instance, + worker_telemetry + .as_ref() + .and_then(|telem| telem.trace_subscriber.clone()), ), at_task_mgr, local_act_mgr, @@ -554,8 +649,8 @@ impl Worker { la_permits, }), nexus_mgr, - worker_heartbeat, - } + client_worker_registrator, + }) } /// Will shutdown the worker. Does not resolve until all outstanding workflow tasks have been @@ -599,9 +694,6 @@ impl Worker { dbg_panic!("Waiting for all slot permits to release took too long!"); } } - if let Some(heartbeat) = self.worker_heartbeat.as_ref() { - heartbeat.shutdown(); - } } /// Finish shutting down by consuming the background pollers and freeing all resources @@ -858,6 +950,124 @@ impl Worker { } } +struct ClientWorkerRegistrator { + worker_instance_key: Uuid, + slot_provider: SlotProvider, + heartbeat_manager: Option, + client: RwLock>, +} + +impl ClientWorker for ClientWorkerRegistrator { + fn namespace(&self) -> &str { + self.slot_provider.namespace() + } + fn task_queue(&self) -> &str { + self.slot_provider.task_queue() + } + + fn try_reserve_wft_slot(&self) -> Option> { + self.slot_provider.try_reserve_wft_slot() + } + + fn worker_instance_key(&self) -> Uuid { + self.worker_instance_key + } + + fn heartbeat_enabled(&self) -> bool { + self.heartbeat_manager.is_some() + } + + fn heartbeat_callback(&self) -> Option { + if let Some(hb_mgr) = self.heartbeat_manager.as_ref() { + let mut heartbeat_manager = hb_mgr.heartbeat_callback.lock(); + heartbeat_manager.take() + } else { + None + } + } + fn new_shared_namespace_worker( + &self, + ) -> Result, anyhow::Error> { + if let Some(ref hb_mgr) = self.heartbeat_manager { + Ok(Box::new(SharedNamespaceWorker::new( + self.client.read().clone(), + self.namespace().to_string(), + hb_mgr.heartbeat_interval, + hb_mgr.telemetry.clone(), + )?)) + } else { + bail!("Shared namespace worker creation never be called without a heartbeat manager"); + } + } + + fn register_callback(&self, callback: HeartbeatCallback) { + if let Some(hb_mgr) = self.heartbeat_manager.as_ref() { + hb_mgr.heartbeat_callback.lock().replace(callback); + } + } +} + +struct WorkerHeartbeatManager { + /// Heartbeat interval, defaults to 60s + heartbeat_interval: Duration, + /// Telemetry instance, needed to initialize [SharedNamespaceWorker] when replacing client + telemetry: Option, + /// Heartbeat callback + heartbeat_callback: Mutex WorkerHeartbeat + Send + Sync>>>, +} + +impl WorkerHeartbeatManager { + fn new( + config: WorkerConfig, + worker_instance_key: Uuid, + heartbeat_interval: Duration, + telemetry_instance: Option, + ) -> Self { + let worker_instance_key_clone = worker_instance_key.to_string(); + let task_queue = config.task_queue.clone(); + + // TODO: requires the metrics changes to get the rest of these fields + let worker_heartbeat_callback: HeartbeatFn = Box::new(move || { + WorkerHeartbeat { + worker_instance_key: worker_instance_key_clone.clone(), + host_info: Some(WorkerHostInfo { + host_name: gethostname().to_string_lossy().to_string(), + process_id: std::process::id().to_string(), + ..Default::default() + }), + task_queue: task_queue.clone(), + deployment_version: None, + + status: 0, + start_time: Some(std::time::SystemTime::now().into()), + workflow_task_slots_info: None, + activity_task_slots_info: None, + nexus_task_slots_info: None, + local_activity_slots_info: None, + workflow_poller_info: None, + workflow_sticky_poller_info: None, + activity_poller_info: None, + nexus_poller_info: None, + total_sticky_cache_hit: 0, + total_sticky_cache_miss: 0, + current_sticky_cache_size: 0, + plugins: vec![], + + // sdk_name, sdk_version, and worker_identity must be set by + // SharedNamespaceWorker because they rely on the client, and + // need to be pulled from the current client used by SharedNamespaceWorker + ..Default::default() + } + }); + + WorkerHeartbeatManager { + heartbeat_interval, + telemetry: telemetry_instance, + heartbeat_callback: Mutex::new(Some(worker_heartbeat_callback)), + } + } +} + pub(crate) struct PostActivateHookData<'a> { pub(crate) run_id: &'a str, pub(crate) replaying: bool, diff --git a/core/src/worker/slot_provider.rs b/core/src/worker/slot_provider.rs index 1b5fcba95..40e9fb00e 100644 --- a/core/src/worker/slot_provider.rs +++ b/core/src/worker/slot_provider.rs @@ -7,7 +7,7 @@ use crate::{ protosext::ValidPollWFTQResponse, worker::workflow::wft_poller::validate_wft, }; -use temporal_client::{Slot as SlotTrait, SlotProvider as SlotProviderTrait}; +use temporal_client::Slot as SlotTrait; use temporal_sdk_core_api::worker::WorkflowSlotKind; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse; use tokio::sync::mpsc::UnboundedSender; @@ -74,16 +74,13 @@ impl SlotProvider { external_wft_tx, } } -} - -impl SlotProviderTrait for SlotProvider { - fn namespace(&self) -> &str { + pub(super) fn namespace(&self) -> &str { &self.namespace } - fn task_queue(&self) -> &str { + pub(super) fn task_queue(&self) -> &str { &self.task_queue } - fn try_reserve_wft_slot(&self) -> Option> { + pub(super) fn try_reserve_wft_slot(&self) -> Option> { match self.wft_semaphore.try_acquire_owned().ok() { Some(permit) => Some(Box::new(Slot::new(permit, self.external_wft_tx.clone()))), None => None, diff --git a/core/src/worker/workflow/mod.rs b/core/src/worker/workflow/mod.rs index 33b9c6dac..24aa59f11 100644 --- a/core/src/worker/workflow/mod.rs +++ b/core/src/worker/workflow/mod.rs @@ -23,7 +23,7 @@ use crate::{ internal_flags::InternalFlags, pollers::TrackedPermittedTqResp, protosext::{ValidPollWFTQResponse, protocol_messages::IncomingProtocolMessage}, - telemetry::{TelemetryInstance, VecDisplayer, set_trace_subscriber_for_current_thread}, + telemetry::{VecDisplayer, set_trace_subscriber_for_current_thread}, worker::{ LocalActRequest, LocalActivityExecutionResult, LocalActivityResolution, PostActivateHookData, @@ -94,7 +94,7 @@ use tokio::{ }; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; -use tracing::Span; +use tracing::{Span, Subscriber}; /// Id used by server for "legacy" queries. IE: Queries that come in the `query` rather than /// `queries` field of a WFT, and are responded to on the separate `respond_query_task_completed` @@ -166,7 +166,7 @@ impl Workflows { local_act_mgr: Arc, heartbeat_timeout_rx: UnboundedReceiver, activity_tasks_handle: Option, - telem_instance: Option<&TelemetryInstance>, + tracing_sub: Option>, ) -> Self { let (local_tx, local_rx) = unbounded_channel(); let (fetch_tx, fetch_rx) = unbounded_channel(); @@ -187,7 +187,6 @@ impl Workflows { let (start_polling_tx, start_polling_rx) = oneshot::channel(); // We must spawn a task to constantly poll the activation stream, because otherwise // activation completions would not cause anything to happen until the next poll. - let tracing_sub = telem_instance.and_then(|ti| ti.trace_subscriber()); let processing_task = thread::Builder::new() .name("workflow-processing".to_string()) .spawn(move || { diff --git a/sdk-core-protos/src/history_builder.rs b/sdk-core-protos/src/history_builder.rs index 3dcc306c5..9aef52737 100644 --- a/sdk-core-protos/src/history_builder.rs +++ b/sdk-core-protos/src/history_builder.rs @@ -621,7 +621,7 @@ fn default_attribs(et: EventType) -> Result { EventType::WorkflowExecutionStarted => default_wes_attribs().into(), EventType::WorkflowTaskScheduled => WorkflowTaskScheduledEventAttributes::default().into(), EventType::TimerStarted => TimerStartedEventAttributes::default().into(), - _ => bail!("Don't know how to construct default attrs for {:?}", et), + _ => bail!("Don't know how to construct default attrs for {et:?}"), }) } diff --git a/sdk/src/interceptors.rs b/sdk/src/interceptors.rs index 84a7b56fd..5d014fb17 100644 --- a/sdk/src/interceptors.rs +++ b/sdk/src/interceptors.rs @@ -88,10 +88,7 @@ impl WorkerInterceptor for FailOnNondeterminismInterceptor { activation.eviction_reason(), Some(EvictionReason::Nondeterminism) ) { - bail!( - "Workflow is being evicted because of nondeterminism! {}", - activation - ); + bail!("Workflow is being evicted because of nondeterminism! {activation}"); } Ok(()) } diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index 90115715e..e3970b4e2 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -10,7 +10,7 @@ //! ```no_run //! use std::{str::FromStr, sync::Arc}; //! use temporal_sdk::{sdk_client_options, ActContext, Worker}; -//! use temporal_sdk_core::{init_worker, Url, CoreRuntime}; +//! use temporal_sdk_core::{init_worker, Url, CoreRuntime, RuntimeOptionsBuilder}; //! use temporal_sdk_core_api::{ //! worker::{WorkerConfigBuilder, WorkerVersioningStrategy}, //! telemetry::TelemetryOptionsBuilder @@ -20,10 +20,11 @@ //! async fn main() -> Result<(), Box> { //! let server_options = sdk_client_options(Url::from_str("http://localhost:7233")?).build()?; //! -//! let client = server_options.connect("default", None).await?; -//! //! let telemetry_options = TelemetryOptionsBuilder::default().build()?; -//! let runtime = CoreRuntime::new_assume_tokio(telemetry_options)?; +//! let runtime_options = RuntimeOptionsBuilder::default().telemetry_options(telemetry_options).build().unwrap(); +//! let runtime = CoreRuntime::new_assume_tokio(runtime_options)?; +//! +//! let client = server_options.connect("default", None).await?; //! //! let worker_config = WorkerConfigBuilder::default() //! .namespace("default") @@ -497,11 +498,7 @@ impl WorkflowHalf { // In all other cases, we want to error as the runtime could be in an inconsistent state // at this point. - bail!( - "Got activation {:?} for unknown workflow {}", - activation, - run_id - ); + bail!("Got activation {activation:?} for unknown workflow {run_id}"); }; Ok(res) diff --git a/sdk/src/workflow_future.rs b/sdk/src/workflow_future.rs index 9e1b6036e..648d58f82 100644 --- a/sdk/src/workflow_future.rs +++ b/sdk/src/workflow_future.rs @@ -135,7 +135,7 @@ impl WorkflowFuture { }; let unblocker = self.command_status.remove(&cmd_id); let _ = unblocker - .ok_or_else(|| anyhow!("Command {:?} not found to unblock!", cmd_id))? + .ok_or_else(|| anyhow!("Command {cmd_id:?} not found to unblock!"))? .unblocker .send(event); Ok(()) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index b17d07d31..35b2825c3 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -39,8 +39,8 @@ use temporal_sdk::{ }, }; use temporal_sdk_core::{ - ClientOptions, ClientOptionsBuilder, CoreRuntime, WorkerConfigBuilder, init_replay_worker, - init_worker, + ClientOptions, ClientOptionsBuilder, CoreRuntime, RuntimeOptions, RuntimeOptionsBuilder, + WorkerConfigBuilder, init_replay_worker, init_worker, replay::{HistoryForReplay, ReplayWorkerInput}, telemetry::{build_otlp_metric_exporter, start_prometheus_metric_exporter}, }; @@ -164,8 +164,12 @@ pub(crate) fn init_integ_telem() -> Option<&'static CoreRuntime> { } Some(INTEG_TESTS_RT.get_or_init(|| { let telemetry_options = get_integ_telem_options(); + let runtime_options = RuntimeOptionsBuilder::default() + .telemetry_options(telemetry_options) + .build() + .expect("Runtime options build cleanly"); let rt = - CoreRuntime::new_assume_tokio(telemetry_options).expect("Core runtime inits cleanly"); + CoreRuntime::new_assume_tokio(runtime_options).expect("Core runtime inits cleanly"); if let Some(sub) = rt.telemetry().trace_subscriber() { let _ = tracing::subscriber::set_global_default(sub); } @@ -314,8 +318,7 @@ impl CoreWfStarter { pub(crate) async fn worker(&mut self) -> TestWorker { let w = self.get_worker().await; - let tq = w.get_config().task_queue.clone(); - let mut w = TestWorker::new(w, tq); + let mut w = TestWorker::new(w); w.client = Some(self.get_client().await); w @@ -477,8 +480,11 @@ pub(crate) struct TestWorker { } impl TestWorker { /// Create a new test worker - pub(crate) fn new(core_worker: Arc, task_queue: impl Into) -> Self { - let inner = Worker::new_from_core(core_worker.clone(), task_queue); + pub(crate) fn new(core_worker: Arc) -> Self { + let inner = Worker::new_from_core( + core_worker.clone(), + core_worker.get_config().task_queue.clone(), + ); Self { inner, core_worker, @@ -807,6 +813,13 @@ pub(crate) fn get_integ_telem_options() -> TelemetryOptions { .unwrap() } +pub(crate) fn get_integ_runtime_options(telemopts: TelemetryOptions) -> RuntimeOptions { + RuntimeOptionsBuilder::default() + .telemetry_options(telemopts) + .build() + .unwrap() +} + #[async_trait::async_trait(?Send)] pub(crate) trait WorkflowHandleExt { async fn fetch_history_and_replay( @@ -932,10 +945,7 @@ pub(crate) fn mock_sdk_cfg( let mut mock = build_mock_pollers(poll_cfg); mock.worker_cfg(mutator); let core = mock_worker(mock); - TestWorker::new( - Arc::new(core), - temporal_sdk_core::test_help::TEST_Q.to_string(), - ) + TestWorker::new(Arc::new(core)) } #[derive(Default)] diff --git a/tests/global_metric_tests.rs b/tests/global_metric_tests.rs index 14e799195..822bf238c 100644 --- a/tests/global_metric_tests.rs +++ b/tests/global_metric_tests.rs @@ -2,6 +2,7 @@ #[allow(dead_code)] mod common; +use crate::common::get_integ_runtime_options; use common::CoreWfStarter; use parking_lot::Mutex; use std::{sync::Arc, time::Duration}; @@ -71,18 +72,16 @@ async fn otel_errors_logged_as_errors() { .unwrap(), ) .unwrap(); + let telemopts = TelemetryOptionsBuilder::default() + .metrics(Arc::new(exporter) as Arc) + // Importantly, _not_ using subscriber override, is using console. + .logging(Logger::Console { + filter: construct_filter_string(Level::INFO, Level::WARN), + }) + .build() + .unwrap(); - let rt = CoreRuntime::new_assume_tokio( - TelemetryOptionsBuilder::default() - .metrics(Arc::new(exporter) as Arc) - // Importantly, _not_ using subscriber override, is using console. - .logging(Logger::Console { - filter: construct_filter_string(Level::INFO, Level::WARN), - }) - .build() - .unwrap(), - ) - .unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("otel_errors_logged_as_errors", rt); let _worker = starter.get_worker().await; diff --git a/tests/heavy_tests.rs b/tests/heavy_tests.rs index f5dc9018d..64bb298fb 100644 --- a/tests/heavy_tests.rs +++ b/tests/heavy_tests.rs @@ -2,6 +2,7 @@ #[allow(dead_code)] mod common; +use crate::common::get_integ_runtime_options; use common::{ CoreWfStarter, init_integ_telem, prom_metrics, rand_6_chars, workflows::la_problem_workflow, }; @@ -194,7 +195,7 @@ async fn workflow_load() { // cause us to encounter the tracing span drop bug telemopts.logging = None; init_integ_telem(); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("workflow_load", rt); starter .worker_config diff --git a/tests/integ_tests/metrics_tests.rs b/tests/integ_tests/metrics_tests.rs index dd4068ac4..dc7caa812 100644 --- a/tests/integ_tests/metrics_tests.rs +++ b/tests/integ_tests/metrics_tests.rs @@ -1,3 +1,4 @@ +use crate::common::get_integ_runtime_options; use crate::{ common::{ ANY_PORT, CoreWfStarter, NAMESPACE, OTEL_URL_ENV_VAR, PROMETHEUS_QUERY_API, @@ -97,7 +98,7 @@ async fn prometheus_metrics_exported( }); } let (telemopts, addr, _aborter) = prom_metrics(Some(opts_builder.build().unwrap())); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let opts = get_integ_server_options(); let mut raw_client = opts .connect_no_namespace(rt.telemetry().get_temporal_metric_meter()) @@ -148,7 +149,7 @@ async fn prometheus_metrics_exported( async fn one_slot_worker_reports_available_slot() { let (telemopts, addr, _aborter) = prom_metrics(None); let tq = "one_slot_worker_tq"; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let worker_cfg = WorkerConfigBuilder::default() .namespace(NAMESPACE) @@ -401,7 +402,7 @@ async fn query_of_closed_workflow_doesnt_tick_terminal_metric( completion: workflow_command::Variant, ) { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("query_of_closed_workflow_doesnt_tick_terminal_metric", rt); // Disable cache to ensure replay happens completely @@ -523,8 +524,11 @@ async fn query_of_closed_workflow_doesnt_tick_terminal_metric( #[test] fn runtime_new() { - let mut rt = - CoreRuntime::new(get_integ_telem_options(), TokioRuntimeBuilder::default()).unwrap(); + let mut rt = CoreRuntime::new( + get_integ_runtime_options(get_integ_telem_options()), + TokioRuntimeBuilder::default(), + ) + .unwrap(); let handle = rt.tokio_handle(); let _rt = handle.enter(); let (telemopts, addr, _aborter) = prom_metrics(None); @@ -570,7 +574,7 @@ async fn latency_metrics( .build() .unwrap(), )); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("latency_metrics", rt); let worker = starter.get_worker().await; starter.start_wf().await; @@ -624,7 +628,7 @@ async fn latency_metrics( #[tokio::test] async fn request_fail_codes() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let opts = get_integ_server_options(); let mut client = opts .connect(NAMESPACE, rt.telemetry().get_temporal_metric_meter()) @@ -667,8 +671,8 @@ async fn request_fail_codes_otel() { let mut telemopts = TelemetryOptionsBuilder::default(); let exporter = Arc::new(exporter); telemopts.metrics(exporter as Arc); - - let rt = CoreRuntime::new_assume_tokio(telemopts.build().unwrap()).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts.build().unwrap())) + .unwrap(); let opts = get_integ_server_options(); let mut client = opts .connect(NAMESPACE, rt.telemetry().get_temporal_metric_meter()) @@ -718,7 +722,7 @@ async fn docker_metrics_with_prometheus( .metric_prefix(test_uid.clone()) .build() .unwrap(); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let test_name = "docker_metrics_with_prometheus"; let mut starter = CoreWfStarter::new_with_runtime(test_name, rt); let worker = starter.get_worker().await; @@ -758,8 +762,15 @@ async fn docker_metrics_with_prometheus( assert!(!data.is_empty(), "No metrics found for query: {test_uid}"); assert_eq!(data[0]["metric"]["exported_job"], "temporal-core-sdk"); assert_eq!(data[0]["metric"]["job"], "otel-collector"); + // Worker heartbeating nexus worker assert!( data[0]["metric"]["task_queue"] + .as_str() + .unwrap() + .starts_with("temporal-sys/worker-commands/default/") + ); + assert!( + data[1]["metric"]["task_queue"] .as_str() .unwrap() .starts_with(test_name) @@ -772,7 +783,7 @@ async fn docker_metrics_with_prometheus( #[tokio::test] async fn activity_metrics() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let wf_name = "activity_metrics"; let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); starter @@ -906,7 +917,7 @@ async fn activity_metrics() { #[tokio::test] async fn nexus_metrics() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let wf_name = "nexus_metrics"; let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); starter.worker_config.no_remote_activities(true); @@ -1083,7 +1094,7 @@ async fn nexus_metrics() { #[tokio::test] async fn evict_on_complete_does_not_count_as_forced_eviction() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let wf_name = "evict_on_complete_does_not_count_as_forced_eviction"; let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); starter.worker_config.no_remote_activities(true); @@ -1166,7 +1177,7 @@ where #[tokio::test] async fn metrics_available_from_custom_slot_supplier() { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("metrics_available_from_custom_slot_supplier", rt); starter.worker_config.no_remote_activities(true); diff --git a/tests/integ_tests/polling_tests.rs b/tests/integ_tests/polling_tests.rs index a69006e39..aaa69ec16 100644 --- a/tests/integ_tests/polling_tests.rs +++ b/tests/integ_tests/polling_tests.rs @@ -198,7 +198,9 @@ async fn switching_worker_client_changes_poll() { // Swap client, poll for next task, confirm it's second wf, and respond w/ empty info!("Replacing client and polling again"); - worker.replace_client(client2.get_client().inner().clone()); + worker + .replace_client(client2.get_client().inner().clone()) + .unwrap(); let act2 = worker.poll_workflow_activation().await.unwrap(); assert_eq!(wf2.run_id, act2.run_id); worker.complete_execution(&act2.run_id).await; diff --git a/tests/integ_tests/worker_tests.rs b/tests/integ_tests/worker_tests.rs index 728ba3222..920e2606d 100644 --- a/tests/integ_tests/worker_tests.rs +++ b/tests/integ_tests/worker_tests.rs @@ -1,3 +1,4 @@ +use crate::common::get_integ_runtime_options; use crate::{ common::{CoreWfStarter, get_integ_server_options, get_integ_telem_options, mock_sdk_cfg}, shared_tests, @@ -17,8 +18,8 @@ use temporal_sdk::{ActivityOptions, WfContext, interceptors::WorkerInterceptor}; use temporal_sdk_core::{ CoreRuntime, ResourceBasedTuner, ResourceSlotOptions, init_worker, test_help::{ - FakeWfResponses, MockPollCfg, ResponseType, TEST_Q, build_mock_pollers, - drain_pollers_and_shutdown, hist_to_poll_resp, mock_worker, mock_worker_client, + FakeWfResponses, MockPollCfg, ResponseType, build_mock_pollers, drain_pollers_and_shutdown, + hist_to_poll_resp, mock_worker, mock_worker_client, }, }; use temporal_sdk_core_api::{ @@ -61,7 +62,9 @@ use uuid::Uuid; #[tokio::test] async fn worker_validation_fails_on_nonexistent_namespace() { let opts = get_integ_server_options(); - let runtime = CoreRuntime::new_assume_tokio(get_integ_telem_options()).unwrap(); + let runtime = + CoreRuntime::new_assume_tokio(get_integ_runtime_options(get_integ_telem_options())) + .unwrap(); let retrying_client = opts .connect_no_namespace(runtime.telemetry().get_temporal_metric_meter()) .await @@ -318,7 +321,7 @@ async fn activity_tasks_from_completion_reserve_slots() { cfg.max_outstanding_activities = Some(2); }); let core = Arc::new(mock_worker(mock)); - let mut worker = crate::common::TestWorker::new(core.clone(), TEST_Q.to_string()); + let mut worker = crate::common::TestWorker::new(core.clone()); // First poll for activities twice, occupying both slots let at1 = core.poll_activity_task().await.unwrap(); diff --git a/tests/integ_tests/workflow_tests.rs b/tests/integ_tests/workflow_tests.rs index 6d10fbb71..263c942b0 100644 --- a/tests/integ_tests/workflow_tests.rs +++ b/tests/integ_tests/workflow_tests.rs @@ -18,6 +18,7 @@ mod stickyness; mod timers; mod upsert_search_attrs; +use crate::common::get_integ_runtime_options; use crate::{ common::{ CoreWfStarter, history_from_proto_binary, init_core_and_create_wf, @@ -67,7 +68,6 @@ use temporal_sdk_core_protos::{ test_utils::schedule_activity_cmd, }; use tokio::{join, sync::Notify, time::sleep}; - // TODO: We should get expected histories for these tests and confirm that the history at the end // matches. @@ -764,7 +764,7 @@ async fn nondeterminism_errors_fail_workflow_when_configured_to( #[values(true, false)] whole_worker: bool, ) { let (telemopts, addr, _aborter) = prom_metrics(None); - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let wf_name = "nondeterminism_errors_fail_workflow_when_configured_to"; let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); starter.worker_config.no_remote_activities(true); diff --git a/tests/main.rs b/tests/main.rs index bf05e2a1f..8a71f03ca 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -27,7 +27,8 @@ mod integ_tests { mod workflow_tests; use crate::common::{ - CoreWfStarter, get_integ_server_options, get_integ_telem_options, rand_6_chars, + CoreWfStarter, get_integ_runtime_options, get_integ_server_options, + get_integ_telem_options, rand_6_chars, }; use std::time::Duration; use temporal_client::{NamespacedClient, WorkflowService}; @@ -44,7 +45,9 @@ mod integ_tests { #[ignore] // Really a compile time check more than anything async fn lang_bridge_example() { let opts = get_integ_server_options(); - let runtime = CoreRuntime::new_assume_tokio(get_integ_telem_options()).unwrap(); + let runtime = + CoreRuntime::new_assume_tokio(get_integ_runtime_options(get_integ_telem_options())) + .unwrap(); let mut retrying_client = opts .connect_no_namespace(runtime.telemetry().get_temporal_metric_meter()) .await diff --git a/tests/manual_tests.rs b/tests/manual_tests.rs index 8f5ef4c5b..9588be3f3 100644 --- a/tests/manual_tests.rs +++ b/tests/manual_tests.rs @@ -5,6 +5,7 @@ #[allow(dead_code)] mod common; +use crate::common::get_integ_runtime_options; use common::{CoreWfStarter, prom_metrics, rand_6_chars}; use futures_util::{ StreamExt, @@ -41,7 +42,7 @@ async fn poller_load_spiky() { } else { prom_metrics(None) }; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("poller_load", rt); starter .worker_config @@ -200,7 +201,7 @@ async fn poller_load_sustained() { } else { prom_metrics(None) }; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("poller_load", rt); starter .worker_config @@ -291,7 +292,7 @@ async fn poller_load_spike_then_sustained() { } else { prom_metrics(None) }; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let mut starter = CoreWfStarter::new_with_runtime("poller_load", rt); starter .worker_config diff --git a/tests/workflow_replay_bench.rs b/tests/workflow_replay_bench.rs index 4200ebd3a..d80796b0a 100644 --- a/tests/workflow_replay_bench.rs +++ b/tests/workflow_replay_bench.rs @@ -5,7 +5,9 @@ #[allow(dead_code)] mod common; -use crate::common::{DONT_AUTO_INIT_INTEG_TELEM, prom_metrics, replay_sdk_worker}; +use crate::common::{ + DONT_AUTO_INIT_INTEG_TELEM, get_integ_runtime_options, prom_metrics, replay_sdk_worker, +}; use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; use futures_util::StreamExt; use std::{ @@ -80,7 +82,7 @@ pub fn bench_metrics(c: &mut Criterion) { let _tokio = tokio_runtime.enter(); let (mut telemopts, _addr, _aborter) = prom_metrics(None); telemopts.logging = None; - let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap(); + let rt = CoreRuntime::new_assume_tokio(get_integ_runtime_options(telemopts)).unwrap(); let meter = rt.telemetry().get_metric_meter().unwrap(); c.bench_function("Record with new attributes on each call", move |b| { From 4bc7383fa8dc8801c5acdfe3c57531c0414de5c6 Mon Sep 17 00:00:00 2001 From: Andrew Yuan Date: Fri, 17 Oct 2025 09:39:35 -0700 Subject: [PATCH 2/5] Worker heartbeat: New in-memory metrics mechism, plumb rest of heartbeat data (#1023) * plumb in memory metrics * simplify worker::new(), fix some heartbeat metrics, new test file * CounterImpl, final_heartbeat, more specific metric label dbg_panic msg, counter_with_in_mem and and_then() * Support in-mem metrics when metrics aren't configured * Move sys_info refresh to dedicated thread, use tuner's existing sys info * Format, AtomicCell * Fix unit test * Set dynamic config for WorkerHeartbeatsEnabled and ListWorkersEnabled, remove stale metric previously added * Should not expect heartbeat nexus worker in metrics for non-heartbeating integ test * recv_timeout instead of thread::sleep, use WorkflowService::list_workers directly, WithLabel API improvement * MetricAttributes::NoOp, add mechanism to ignore dupe workers for testing, more tests * More tests, sticky cache miss, plugins * Formatting, fix skip_client_worker_set_check * Cursor found a bug * Lower sleep time, add print for debugging * more prints * use semaphores for worker_heartbeat_failure_metrics * skip_client_worker_set_check for all integ workers * Can't use tokio semaphore in workflow code * use signal to test workflow_slots.last_interval_failure_tasks * Use Notify instead of semaphores, fix test flake * Use eventually() instead of a manual sleep * max_outstanding_workflow_tasks 2 --- .cargo/config.toml | 2 +- client/src/raw.rs | 18 + client/src/worker_registry/mod.rs | 63 +- core-api/Cargo.toml | 1 + core-api/src/lib.rs | 9 + core-api/src/telemetry/metrics.rs | 493 +++++++- core-api/src/worker.rs | 21 +- core-c-bridge/src/client.rs | 4 + core/src/abstractions.rs | 43 + core/src/core_tests/workers.rs | 4 +- core/src/core_tests/workflow_tasks.rs | 1 - core/src/lib.rs | 8 +- core/src/pollers/poll_buffer.rs | 35 +- core/src/replay/mod.rs | 2 +- core/src/telemetry/metrics.rs | 77 +- core/src/telemetry/mod.rs | 1 + core/src/worker/activities.rs | 4 + core/src/worker/client.rs | 103 +- core/src/worker/client/mocks.rs | 11 +- core/src/worker/heartbeat.rs | 50 +- core/src/worker/mod.rs | 344 ++++-- core/src/worker/tuner.rs | 19 +- core/src/worker/tuner/fixed_size.rs | 4 + core/src/worker/tuner/resource_based.rs | 119 +- core/src/worker/workflow/wft_poller.rs | 7 + .../api_upstream/openapi/openapiv2.json | 259 +++- .../api_upstream/openapi/openapiv3.yaml | 239 +++- .../temporal/api/common/v1/message.proto | 2 +- .../temporal/api/deployment/v1/message.proto | 6 + .../temporal/api/namespace/v1/message.proto | 8 +- .../workflowservice/v1/request_response.proto | 62 +- .../api/workflowservice/v1/service.proto | 36 +- tests/common/mod.rs | 11 +- tests/integ_tests/metrics_tests.rs | 7 - tests/integ_tests/worker_heartbeat_tests.rs | 1039 +++++++++++++++++ tests/main.rs | 1 + tests/runner.rs | 4 + 37 files changed, 2783 insertions(+), 334 deletions(-) create mode 100644 tests/integ_tests/worker_heartbeat_tests.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index 4d1a8e5e1..17cfb5135 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,6 +1,6 @@ [env] # This temporarily overrides the version of the CLI used for integration tests, locally and in CI -CLI_VERSION_OVERRIDE = "v1.4.1-cloud-v1-29-0-139-2.0" +#CLI_VERSION_OVERRIDE = "v1.4.1-cloud-v1-29-0-139-2.0" [alias] # Not sure why --all-features doesn't work diff --git a/client/src/raw.rs b/client/src/raw.rs index 92e6a7956..26d08203d 100644 --- a/client/src/raw.rs +++ b/client/src/raw.rs @@ -1345,6 +1345,15 @@ proxier! { r.extensions_mut().insert(labels); } ); + ( + describe_worker, + DescribeWorkerRequest, + DescribeWorkerResponse, + |r| { + let labels = namespaced_request!(r); + r.extensions_mut().insert(labels); + } + ); ( record_worker_heartbeat, RecordWorkerHeartbeatRequest, @@ -1382,6 +1391,15 @@ proxier! { r.extensions_mut().insert(labels); } ); + ( + set_worker_deployment_manager, + SetWorkerDeploymentManagerRequest, + SetWorkerDeploymentManagerResponse, + |r| { + let labels = namespaced_request!(r); + r.extensions_mut().insert(labels); + } + ); } proxier! { diff --git a/client/src/worker_registry/mod.rs b/client/src/worker_registry/mod.rs index f10b128ce..6112eb11a 100644 --- a/client/src/worker_registry/mod.rs +++ b/client/src/worker_registry/mod.rs @@ -76,12 +76,13 @@ impl ClientWorkerSetImpl { fn register( &mut self, worker: Arc, + skip_client_worker_set_check: bool, ) -> Result<(), anyhow::Error> { let slot_key = SlotKey::new( worker.namespace().to_string(), worker.task_queue().to_string(), ); - if self.slot_providers.contains_key(&slot_key) { + if self.slot_providers.contains_key(&slot_key) && !skip_client_worker_set_check { bail!( "Registration of multiple workers on the same namespace and task queue for the same client not allowed: {slot_key:?}, worker_instance_key: {:?}.", worker.worker_instance_key() @@ -133,14 +134,8 @@ impl ClientWorkerSetImpl { if let Some(w) = self.shared_worker.get_mut(worker.namespace()) { let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key()); - if let Some(cb) = callback { - if is_empty { - self.shared_worker.remove(worker.namespace()); - } - - // To maintain single ownership of the callback, we must re-register the callback - // back to the ClientWorker - worker.register_callback(cb); + if callback.is_some() && is_empty { + self.shared_worker.remove(worker.namespace()); } } @@ -212,6 +207,17 @@ impl ClientWorkerSet { .try_reserve_wft_slot(namespace, task_queue) } + /// Register a local worker that can provide WFT processing slots and potentially worker heartbeating. + pub fn register_worker( + &self, + worker: Arc, + skip_client_worker_set_check: bool, + ) -> Result<(), anyhow::Error> { + self.worker_manager + .write() + .register(worker, skip_client_worker_set_check) + } + /// Unregisters a local worker, typically when that worker starts shutdown. pub fn unregister_worker( &self, @@ -220,14 +226,6 @@ impl ClientWorkerSet { self.worker_manager.write().unregister(worker_instance_key) } - /// Register a local worker that can provide WFT processing slots and potentially worker heartbeating. - pub fn register_worker( - &self, - worker: Arc, - ) -> Result<(), anyhow::Error> { - self.worker_manager.write().register(worker) - } - /// Returns the worker grouping key, which is unique for each worker. pub fn worker_grouping_key(&self) -> Uuid { self.worker_grouping_key @@ -256,7 +254,7 @@ impl std::fmt::Debug for ClientWorkerSet { } /// Contains a worker heartbeat callback, wrapped for mocking -pub type HeartbeatCallback = Box WorkerHeartbeat + Send + Sync>; +pub type HeartbeatCallback = Arc WorkerHeartbeat + Send + Sync>; /// Represents a complete worker that can handle both slot management /// and worker heartbeat functionality. @@ -276,7 +274,7 @@ pub trait ClientWorker: Send + Sync { fn try_reserve_wft_slot(&self) -> Option>; /// Unique identifier for this worker instance. - /// This must be stable across the worker's lifetime but unique per instance. + /// This must be stable across the worker's lifetime and unique per instance. fn worker_instance_key(&self) -> Uuid; /// Indicates if worker heartbeating is enabled for this client worker. @@ -289,9 +287,6 @@ pub trait ClientWorker: Send + Sync { fn new_shared_namespace_worker( &self, ) -> Result, anyhow::Error>; - - /// Registers a worker heartbeat callback, typically when a worker is unregistered from a client - fn register_callback(&self, callback: HeartbeatCallback); } #[cfg(test)] @@ -350,7 +345,7 @@ mod tests { new_mock_provider(namespace, "bar_q".to_string(), false, false, false); let worker_instance_key = mock_provider.worker_instance_key(); - let result = manager.register_worker(Arc::new(mock_provider)); + let result = manager.register_worker(Arc::new(mock_provider), false); if result.is_ok() { successful_registrations += 1; worker_keys.push(worker_instance_key); @@ -453,7 +448,7 @@ mod tests { if heartbeat_enabled { mock_provider .expect_heartbeat_callback() - .returning(|| Some(Box::new(WorkerHeartbeat::default))); + .returning(|| Some(Arc::new(WorkerHeartbeat::default))); let namespace_clone = namespace.clone(); mock_provider @@ -463,8 +458,6 @@ mod tests { namespace_clone.clone(), ))) }); - - mock_provider.expect_register_callback().returning(|_| {}); } mock_provider @@ -489,10 +482,10 @@ mod tests { Uuid::new_v4(), ); - manager.register_worker(Arc::new(worker1)).unwrap(); + manager.register_worker(Arc::new(worker1), false).unwrap(); // second worker register should fail due to duplicate namespace+task_queue - let result = manager.register_worker(Arc::new(worker2)); + let result = manager.register_worker(Arc::new(worker2), false); assert!(result.is_err()); assert!( result @@ -528,8 +521,8 @@ mod tests { Uuid::new_v4(), ); - manager.register_worker(Arc::new(worker1)).unwrap(); - manager.register_worker(Arc::new(worker2)).unwrap(); + manager.register_worker(Arc::new(worker1), false).unwrap(); + manager.register_worker(Arc::new(worker2), false).unwrap(); assert_eq!(2, manager.num_providers()); assert_eq!(manager.num_heartbeat_workers(), 2); @@ -558,8 +551,8 @@ mod tests { Uuid::new_v4(), ); - manager.register_worker(Arc::new(worker1)).unwrap(); - manager.register_worker(Arc::new(worker2)).unwrap(); + manager.register_worker(Arc::new(worker1), false).unwrap(); + manager.register_worker(Arc::new(worker2), false).unwrap(); assert_eq!(2, manager.num_providers()); assert_eq!(manager.num_heartbeat_workers(), 2); @@ -590,8 +583,10 @@ mod tests { let worker_instance_key1 = worker1.worker_instance_key(); let worker_instance_key2 = worker2.worker_instance_key(); - manager.register_worker(Arc::new(worker1)).unwrap(); - manager.register_worker(Arc::new(worker2)).unwrap(); + assert_ne!(worker_instance_key1, worker_instance_key2); + + manager.register_worker(Arc::new(worker1), false).unwrap(); + manager.register_worker(Arc::new(worker2), false).unwrap(); // Verify initial state: 2 slot providers, 2 heartbeat workers, 1 shared worker assert_eq!(2, manager.num_providers()); diff --git a/core-api/Cargo.toml b/core-api/Cargo.toml index ab2640dca..83042665a 100644 --- a/core-api/Cargo.toml +++ b/core-api/Cargo.toml @@ -31,6 +31,7 @@ tonic = { workspace = true } tracing = "0.1" tracing-core = "0.1" url = "2.5" +uuid = { version = "1.18.1", features = ["v4"] } [dependencies.temporal-sdk-core-protos] path = "../sdk-core-protos" diff --git a/core-api/src/lib.rs b/core-api/src/lib.rs index ca65ccae9..511c9383f 100644 --- a/core-api/src/lib.rs +++ b/core-api/src/lib.rs @@ -19,6 +19,7 @@ use temporal_sdk_core_protos::coresdk::{ workflow_activation::WorkflowActivation, workflow_completion::WorkflowActivationCompletion, }; +use uuid::Uuid; /// This trait is the primary way by which language specific SDKs interact with the core SDK. /// It represents one worker, which has a (potentially shared) client for connecting to the service @@ -138,6 +139,10 @@ pub trait Worker: Send + Sync { /// This should be called only after [Worker::shutdown] has resolved and/or both polling /// functions have returned `ShutDown` errors. async fn finalize_shutdown(self); + + /// Unique identifier for this worker instance. + /// This must be stable across the worker's lifetime and unique per instance. + fn worker_instance_key(&self) -> Uuid; } #[async_trait::async_trait] @@ -205,6 +210,10 @@ where async fn finalize_shutdown(self) { panic!("Can't finalize shutdown on Arc'd worker") } + + fn worker_instance_key(&self) -> Uuid { + (**self).worker_instance_key() + } } macro_rules! dbg_panic { diff --git a/core-api/src/telemetry/metrics.rs b/core-api/src/telemetry/metrics.rs index 407603f8d..bb276f9bd 100644 --- a/core-api/src/telemetry/metrics.rs +++ b/core-api/src/telemetry/metrics.rs @@ -1,4 +1,5 @@ use crate::dbg_panic; +use std::sync::atomic::{AtomicU64, Ordering}; use std::{ any::Any, borrow::Cow, @@ -26,6 +27,18 @@ pub trait CoreMeter: Send + Sync + Debug { attribs: NewAttributes, ) -> MetricAttributes; fn counter(&self, params: MetricParameters) -> Counter; + + /// Create a counter with in-memory tracking for worker heartbeating reporting + fn counter_with_in_memory( + &self, + params: MetricParameters, + in_memory_counter: HeartbeatMetricType, + ) -> Counter { + let primary_counter = self.counter(params); + + Counter::new_with_in_memory(primary_counter.primary.metric.clone(), in_memory_counter) + } + fn histogram(&self, params: MetricParameters) -> Histogram; fn histogram_f64(&self, params: MetricParameters) -> HistogramF64; /// Create a histogram which records Durations. Implementations should choose to emit in @@ -33,10 +46,217 @@ pub trait CoreMeter: Send + Sync + Debug { /// [MetricParameters::unit] should be overwritten by implementations to be `ms` or `s` /// accordingly. fn histogram_duration(&self, params: MetricParameters) -> HistogramDuration; + + /// Create a histogram duration with in-memory tracking for worker heartbeating reporting + fn histogram_duration_with_in_memory( + &self, + params: MetricParameters, + in_memory_hist: HeartbeatMetricType, + ) -> HistogramDuration { + let primary_hist = self.histogram_duration(params); + + HistogramDuration::new_with_in_memory(primary_hist.primary.metric.clone(), in_memory_hist) + } fn gauge(&self, params: MetricParameters) -> Gauge; + + /// Create a gauge with in-memory tracking for worker heartbeating reporting + fn gauge_with_in_memory( + &self, + params: MetricParameters, + in_memory_metrics: HeartbeatMetricType, + ) -> Gauge { + let primary_gauge = self.gauge(params.clone()); + Gauge::new_with_in_memory(primary_gauge.primary.metric.clone(), in_memory_metrics) + } + fn gauge_f64(&self, params: MetricParameters) -> GaugeF64; } +/// Provides a generic way to record metrics in memory. +/// This can be done either with individual metrics or more fine-grained metrics +/// that vary by a set of labels for the same metric. +#[derive(Clone, Debug)] +pub enum HeartbeatMetricType { + Individual(Arc), + WithLabel { + label_key: String, + metrics: HashMap>, + }, +} + +impl HeartbeatMetricType { + fn record_counter(&self, delta: u64) { + match self { + HeartbeatMetricType::Individual(metric) => { + metric.fetch_add(delta, Ordering::Relaxed); + } + HeartbeatMetricType::WithLabel { .. } => { + dbg_panic!("Counter does not support in-memory metric with labels"); + } + } + } + + fn record_histogram_observation(&self) { + match self { + HeartbeatMetricType::Individual(metric) => { + metric.fetch_add(1, Ordering::Relaxed); + } + HeartbeatMetricType::WithLabel { .. } => { + dbg_panic!("Histogram does not support in-memory metric with labels"); + } + } + } + + fn record_gauge(&self, value: u64, attributes: &MetricAttributes) { + match self { + HeartbeatMetricType::Individual(metric) => { + metric.store(value, Ordering::Relaxed); + } + HeartbeatMetricType::WithLabel { label_key, metrics } => { + if let Some(metric) = label_value_from_attributes(attributes, label_key.as_str()) + .and_then(|label_value| metrics.get(label_value.as_str())) + { + metric.store(value, Ordering::Relaxed) + } + } + } + } +} + +fn label_value_from_attributes(attributes: &MetricAttributes, key: &str) -> Option { + match attributes { + MetricAttributes::Prometheus { labels } => labels.as_prom_labels().get(key).cloned(), + #[cfg(feature = "otel_impls")] + MetricAttributes::OTel { kvs } => kvs + .iter() + .find(|kv| kv.key.as_str() == key) + .map(|kv| kv.value.to_string()), + MetricAttributes::NoOp(labels) => labels.get(key).cloned(), + _ => None, + } +} + +#[derive(Default, Debug)] +pub struct NumPollersMetric { + pub wft_current_pollers: Arc, + pub sticky_wft_current_pollers: Arc, + pub activity_current_pollers: Arc, + pub nexus_current_pollers: Arc, +} + +impl NumPollersMetric { + pub fn as_map(&self) -> HashMap> { + HashMap::from([ + ( + "workflow_task".to_string(), + self.wft_current_pollers.clone(), + ), + ( + "sticky_workflow_task".to_string(), + self.sticky_wft_current_pollers.clone(), + ), + ( + "activity_task".to_string(), + self.activity_current_pollers.clone(), + ), + ("nexus_task".to_string(), self.nexus_current_pollers.clone()), + ]) + } +} + +#[derive(Default, Debug)] +pub struct SlotMetrics { + pub workflow_worker: Arc, + pub activity_worker: Arc, + pub nexus_worker: Arc, + pub local_activity_worker: Arc, +} + +impl SlotMetrics { + pub fn as_map(&self) -> HashMap> { + HashMap::from([ + ("WorkflowWorker".to_string(), self.workflow_worker.clone()), + ("ActivityWorker".to_string(), self.activity_worker.clone()), + ("NexusWorker".to_string(), self.nexus_worker.clone()), + ( + "LocalActivityWorker".to_string(), + self.local_activity_worker.clone(), + ), + ]) + } +} + +#[derive(Default, Debug)] +pub struct WorkerHeartbeatMetrics { + pub sticky_cache_size: Arc, + pub total_sticky_cache_hit: Arc, + pub total_sticky_cache_miss: Arc, + pub num_pollers: NumPollersMetric, + pub worker_task_slots_used: SlotMetrics, + pub worker_task_slots_available: SlotMetrics, + pub workflow_task_execution_failed: Arc, + pub activity_execution_failed: Arc, + pub nexus_task_execution_failed: Arc, + pub local_activity_execution_failed: Arc, + pub activity_execution_latency: Arc, + pub local_activity_execution_latency: Arc, + pub workflow_task_execution_latency: Arc, + pub nexus_task_execution_latency: Arc, +} + +impl WorkerHeartbeatMetrics { + pub fn get_metric(&self, name: &str) -> Option { + match name { + "sticky_cache_size" => Some(HeartbeatMetricType::Individual( + self.sticky_cache_size.clone(), + )), + "sticky_cache_hit" => Some(HeartbeatMetricType::Individual( + self.total_sticky_cache_hit.clone(), + )), + "sticky_cache_miss" => Some(HeartbeatMetricType::Individual( + self.total_sticky_cache_miss.clone(), + )), + "num_pollers" => Some(HeartbeatMetricType::WithLabel { + label_key: "poller_type".to_string(), + metrics: self.num_pollers.as_map(), + }), + "worker_task_slots_used" => Some(HeartbeatMetricType::WithLabel { + label_key: "worker_type".to_string(), + metrics: self.worker_task_slots_used.as_map(), + }), + "worker_task_slots_available" => Some(HeartbeatMetricType::WithLabel { + label_key: "worker_type".to_string(), + metrics: self.worker_task_slots_available.as_map(), + }), + "workflow_task_execution_failed" => Some(HeartbeatMetricType::Individual( + self.workflow_task_execution_failed.clone(), + )), + "activity_execution_failed" => Some(HeartbeatMetricType::Individual( + self.activity_execution_failed.clone(), + )), + "nexus_task_execution_failed" => Some(HeartbeatMetricType::Individual( + self.nexus_task_execution_failed.clone(), + )), + "local_activity_execution_failed" => Some(HeartbeatMetricType::Individual( + self.local_activity_execution_failed.clone(), + )), + "activity_execution_latency" => Some(HeartbeatMetricType::Individual( + self.activity_execution_latency.clone(), + )), + "local_activity_execution_latency" => Some(HeartbeatMetricType::Individual( + self.local_activity_execution_latency.clone(), + )), + "workflow_task_execution_latency" => Some(HeartbeatMetricType::Individual( + self.workflow_task_execution_latency.clone(), + )), + "nexus_task_execution_latency" => Some(HeartbeatMetricType::Individual( + self.nexus_task_execution_latency.clone(), + )), + _ => None, + } + } +} + #[derive(Debug, Clone, derive_builder::Builder)] pub struct MetricParameters { /// The name for the new metric/instrument @@ -124,6 +344,7 @@ pub enum MetricAttributes { }, Buffer(BufferAttributes), Dynamic(Arc), + NoOp(Arc>), Empty, } @@ -155,6 +376,16 @@ where } } +impl From for HashMap { + fn from(value: NewAttributes) -> Self { + value + .attributes + .into_iter() + .map(|kv| (kv.key, kv.value.to_string())) + .collect() + } +} + /// A K/V pair that can be used to label a specific recording of a metric #[derive(Clone, Debug, PartialEq)] pub struct MetricKeyValue { @@ -227,43 +458,79 @@ impl LazyBoundMetric { pub trait CounterBase: Send + Sync { fn adds(&self, value: u64); } -pub type Counter = LazyBoundMetric< + +pub type CounterImpl = LazyBoundMetric< Arc> + Send + Sync>, Arc, >; + +#[derive(Clone)] +pub struct Counter { + primary: CounterImpl, + in_memory: Option, +} impl Counter { pub fn new(inner: Arc> + Send + Sync>) -> Self { Self { - metric: inner, - attributes: MetricAttributes::Empty, - bound_cache: OnceLock::new(), + primary: LazyBoundMetric { + metric: inner, + attributes: MetricAttributes::Empty, + bound_cache: OnceLock::new(), + }, + in_memory: None, + } + } + + pub fn new_with_in_memory( + primary: Arc> + Send + Sync>, + in_memory: HeartbeatMetricType, + ) -> Self { + Self { + primary: LazyBoundMetric { + metric: primary, + attributes: MetricAttributes::Empty, + bound_cache: OnceLock::new(), + }, + in_memory: Some(in_memory), } } + pub fn add(&self, value: u64, attributes: &MetricAttributes) { - match self.metric.with_attributes(attributes) { - Ok(base) => { - base.adds(value); - } + match self.primary.metric.with_attributes(attributes) { + Ok(base) => base.adds(value), Err(e) => { - dbg_panic!("Failed to initialize metric, will drop values: {e:?}",); + dbg_panic!("Failed to initialize primary metric, will drop values: {e:?}"); } } + + if let Some(ref in_mem) = self.in_memory { + in_mem.record_counter(value); + } + } + + pub fn update_attributes(&mut self, new_attributes: MetricAttributes) { + self.primary.update_attributes(new_attributes.clone()); } } impl CounterBase for Counter { fn adds(&self, value: u64) { // TODO: Replace all of these with below when stable // https://doc.rust-lang.org/std/sync/struct.OnceLock.html#method.get_or_try_init - let bound = self.bound_cache.get_or_init(|| { - self.metric - .with_attributes(&self.attributes) + let bound = self.primary.bound_cache.get_or_init(|| { + self.primary + .metric + .with_attributes(&self.primary.attributes) .map(Into::into) .unwrap_or_else(|e| { - dbg_panic!("Failed to initialize metric, will drop values: {e:?}"); + dbg_panic!("Failed to initialize primary metric, will drop values: {e:?}"); Arc::new(NoOpInstrument) as Arc }) }); bound.adds(value); + + if let Some(ref in_mem) = self.in_memory { + in_mem.record_counter(value); + } } } impl MetricAttributable for Counter { @@ -271,10 +538,15 @@ impl MetricAttributable for Counter { &self, attributes: &MetricAttributes, ) -> Result> { - Ok(Self { - metric: self.metric.clone(), + let primary = LazyBoundMetric { + metric: self.primary.metric.clone(), attributes: attributes.clone(), bound_cache: OnceLock::new(), + }; + + Ok(Counter { + primary, + in_memory: self.in_memory.clone(), }) } } @@ -390,22 +662,45 @@ impl MetricAttributable for HistogramF64 { pub trait HistogramDurationBase: Send + Sync { fn records(&self, value: Duration); } -pub type HistogramDuration = LazyBoundMetric< + +pub type HistogramDurationImpl = LazyBoundMetric< Arc> + Send + Sync>, Arc, >; + +#[derive(Clone)] +pub struct HistogramDuration { + primary: HistogramDurationImpl, + in_memory: Option, +} impl HistogramDuration { pub fn new( inner: Arc> + Send + Sync>, ) -> Self { Self { - metric: inner, - attributes: MetricAttributes::Empty, - bound_cache: OnceLock::new(), + primary: LazyBoundMetric { + metric: inner, + attributes: MetricAttributes::Empty, + bound_cache: OnceLock::new(), + }, + in_memory: None, + } + } + pub fn new_with_in_memory( + primary: Arc> + Send + Sync>, + in_memory: HeartbeatMetricType, + ) -> Self { + Self { + primary: LazyBoundMetric { + metric: primary, + attributes: MetricAttributes::Empty, + bound_cache: OnceLock::new(), + }, + in_memory: Some(in_memory), } } pub fn record(&self, value: Duration, attributes: &MetricAttributes) { - match self.metric.with_attributes(attributes) { + match self.primary.metric.with_attributes(attributes) { Ok(base) => { base.records(value); } @@ -413,13 +708,22 @@ impl HistogramDuration { dbg_panic!("Failed to initialize metric, will drop values: {e:?}",); } } + + if let Some(ref in_mem) = self.in_memory { + in_mem.record_histogram_observation(); + } + } + + pub fn update_attributes(&mut self, new_attributes: MetricAttributes) { + self.primary.update_attributes(new_attributes.clone()); } } impl HistogramDurationBase for HistogramDuration { fn records(&self, value: Duration) { - let bound = self.bound_cache.get_or_init(|| { - self.metric - .with_attributes(&self.attributes) + let bound = self.primary.bound_cache.get_or_init(|| { + self.primary + .metric + .with_attributes(&self.primary.attributes) .map(Into::into) .unwrap_or_else(|e| { dbg_panic!("Failed to initialize metric, will drop values: {e:?}"); @@ -427,6 +731,10 @@ impl HistogramDurationBase for HistogramDuration { }) }); bound.records(value); + + if let Some(ref in_mem) = self.in_memory { + in_mem.record_histogram_observation(); + } } } impl MetricAttributable for HistogramDuration { @@ -434,10 +742,15 @@ impl MetricAttributable for HistogramDuration { &self, attributes: &MetricAttributes, ) -> Result> { - Ok(Self { - metric: self.metric.clone(), + let primary = LazyBoundMetric { + metric: self.primary.metric.clone(), attributes: attributes.clone(), bound_cache: OnceLock::new(), + }; + + Ok(HistogramDuration { + primary, + in_memory: self.in_memory.clone(), }) } } @@ -445,41 +758,77 @@ impl MetricAttributable for HistogramDuration { pub trait GaugeBase: Send + Sync { fn records(&self, value: u64); } -pub type Gauge = LazyBoundMetric< + +pub type GaugeImpl = LazyBoundMetric< Arc> + Send + Sync>, Arc, >; + +#[derive(Clone)] +pub struct Gauge { + primary: GaugeImpl, + in_memory: Option, +} impl Gauge { pub fn new(inner: Arc> + Send + Sync>) -> Self { Self { - metric: inner, - attributes: MetricAttributes::Empty, - bound_cache: OnceLock::new(), + primary: LazyBoundMetric { + metric: inner, + attributes: MetricAttributes::Empty, + bound_cache: OnceLock::new(), + }, + in_memory: None, + } + } + + pub fn new_with_in_memory( + primary: Arc> + Send + Sync>, + in_memory: HeartbeatMetricType, + ) -> Self { + Self { + primary: LazyBoundMetric { + metric: primary, + attributes: MetricAttributes::Empty, + bound_cache: OnceLock::new(), + }, + in_memory: Some(in_memory), } } + pub fn record(&self, value: u64, attributes: &MetricAttributes) { - match self.metric.with_attributes(attributes) { - Ok(base) => { - base.records(value); - } + match self.primary.metric.with_attributes(attributes) { + Ok(base) => base.records(value), Err(e) => { - dbg_panic!("Failed to initialize metric, will drop values: {e:?}",); + dbg_panic!("Failed to initialize primary metric, will drop values: {e:?}"); } } + + if let Some(ref in_mem) = self.in_memory { + in_mem.record_gauge(value, attributes); + } + } + + pub fn update_attributes(&mut self, new_attributes: MetricAttributes) { + self.primary.update_attributes(new_attributes.clone()); } } impl GaugeBase for Gauge { fn records(&self, value: u64) { - let bound = self.bound_cache.get_or_init(|| { - self.metric - .with_attributes(&self.attributes) + let bound = self.primary.bound_cache.get_or_init(|| { + self.primary + .metric + .with_attributes(&self.primary.attributes) .map(Into::into) .unwrap_or_else(|e| { - dbg_panic!("Failed to initialize metric, will drop values: {e:?}"); + dbg_panic!("Failed to initialize primary metric, will drop values: {e:?}"); Arc::new(NoOpInstrument) as Arc }) }); bound.records(value); + + if let Some(ref in_mem) = self.in_memory { + in_mem.record_gauge(value, &self.primary.attributes); + } } } impl MetricAttributable for Gauge { @@ -487,10 +836,15 @@ impl MetricAttributable for Gauge { &self, attributes: &MetricAttributes, ) -> Result> { - Ok(Self { - metric: self.metric.clone(), + let primary = LazyBoundMetric { + metric: self.primary.metric.clone(), attributes: attributes.clone(), bound_cache: OnceLock::new(), + }; + + Ok(Gauge { + primary, + in_memory: self.in_memory.clone(), }) } } @@ -633,12 +987,23 @@ impl LazyRef { #[derive(Debug)] pub struct NoOpCoreMeter; impl CoreMeter for NoOpCoreMeter { - fn new_attributes(&self, _: NewAttributes) -> MetricAttributes { - MetricAttributes::Dynamic(Arc::new(NoOpAttributes)) + fn new_attributes(&self, attribs: NewAttributes) -> MetricAttributes { + MetricAttributes::NoOp(Arc::new(attribs.into())) } - fn extend_attributes(&self, existing: MetricAttributes, _: NewAttributes) -> MetricAttributes { - existing + fn extend_attributes( + &self, + existing: MetricAttributes, + attribs: NewAttributes, + ) -> MetricAttributes { + if let MetricAttributes::NoOp(labels) = existing { + let mut labels = (*labels).clone(); + labels.extend::>(attribs.into()); + MetricAttributes::NoOp(Arc::new(labels)) + } else { + dbg_panic!("Must use NoOp attributes with a NoOp metric implementation"); + existing + } } fn counter(&self, _: MetricParameters) -> Counter { @@ -701,11 +1066,41 @@ impl_no_op!(HistogramDurationBase, Duration); impl_no_op!(GaugeBase, u64); impl_no_op!(GaugeF64Base, f64); -#[derive(Debug, Clone)] -pub struct NoOpAttributes; -impl CustomMetricAttributes for NoOpAttributes { - fn as_any(self: Arc) -> Arc { - self as Arc +#[cfg(test)] +mod tests { + use super::*; + use std::{ + collections::HashMap, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, + }; + + #[test] + fn in_memory_attributes_provide_label_values() { + let meter = NoOpCoreMeter; + let base_attrs = meter.new_attributes(NewAttributes::default()); + let attrs = meter.extend_attributes( + base_attrs, + NewAttributes::from(vec![MetricKeyValue::new("poller_type", "workflow_task")]), + ); + + let value = Arc::new(AtomicU64::new(0)); + let mut metrics = HashMap::new(); + metrics.insert("workflow_task".to_string(), value.clone()); + let heartbeat_metric = HeartbeatMetricType::WithLabel { + label_key: "poller_type".to_string(), + metrics, + }; + + heartbeat_metric.record_gauge(3, &attrs); + + assert_eq!(value.load(Ordering::Relaxed), 3); + assert_eq!( + label_value_from_attributes(&attrs, "poller_type").as_deref(), + Some("workflow_task") + ); } } diff --git a/core-api/src/worker.rs b/core-api/src/worker.rs index d92efeec0..1a96a6645 100644 --- a/core-api/src/worker.rs +++ b/core-api/src/worker.rs @@ -11,6 +11,7 @@ use temporal_sdk_core_protos::{ coresdk::{ActivitySlotInfo, LocalActivitySlotInfo, NexusSlotInfo, WorkflowSlotInfo}, temporal, temporal::api::enums::v1::VersioningBehavior, + temporal::api::worker::v1::PluginInfo, }; /// Defines per-worker configuration options @@ -141,19 +142,19 @@ pub struct WorkerConfig { /// Mutually exclusive with `tuner` #[builder(setter(into, strip_option), default)] pub max_outstanding_workflow_tasks: Option, - /// The maximum number of activity tasks that will ever be given to this worker concurrently + /// The maximum number of activity tasks that will ever be given to this worker concurrently. /// /// Mutually exclusive with `tuner` #[builder(setter(into, strip_option), default)] pub max_outstanding_activities: Option, /// The maximum number of local activity tasks that will ever be given to this worker - /// concurrently + /// concurrently. /// /// Mutually exclusive with `tuner` #[builder(setter(into, strip_option), default)] pub max_outstanding_local_activities: Option, /// The maximum number of nexus tasks that will ever be given to this worker - /// concurrently + /// concurrently. /// /// Mutually exclusive with `tuner` #[builder(setter(into, strip_option), default)] @@ -161,6 +162,14 @@ pub struct WorkerConfig { /// A versioning strategy for this worker. pub versioning_strategy: WorkerVersioningStrategy, + + /// List of plugins used by lang. + #[builder(default)] + pub plugins: Vec, + + /// Skips the single worker+client+namespace+task_queue check + #[builder(default = "false")] + pub skip_client_worker_set_check: bool, } impl WorkerConfig { @@ -357,6 +366,12 @@ pub trait SlotSupplier { fn available_slots(&self) -> Option { None } + + /// Returns a human-friendly identifier describing this supplier implementation for + /// diagnostics and telemetry. + fn slot_supplier_kind(&self) -> String { + "Custom".to_string() + } } pub trait SlotReservationContext: Send + Sync { diff --git a/core-c-bridge/src/client.rs b/core-c-bridge/src/client.rs index ccdd660fd..28ba7799d 100644 --- a/core-c-bridge/src/client.rs +++ b/core-c-bridge/src/client.rs @@ -571,6 +571,7 @@ async fn call_workflow_service( "DescribeNamespace" => rpc_call!(client, call, describe_namespace), "DescribeSchedule" => rpc_call!(client, call, describe_schedule), "DescribeTaskQueue" => rpc_call!(client, call, describe_task_queue), + "DescribeWorker" => rpc_call!(client, call, describe_worker), "DescribeWorkerDeployment" => rpc_call!(client, call, describe_worker_deployment), "DescribeWorkerDeploymentVersion" => { rpc_call!(client, call, describe_worker_deployment_version) @@ -651,6 +652,9 @@ async fn call_workflow_service( "SetWorkerDeploymentCurrentVersion" => { rpc_call!(client, call, set_worker_deployment_current_version) } + "SetWorkerDeploymentManager" => { + rpc_call!(client, call, set_worker_deployment_manager) + } "SetWorkerDeploymentRampingVersion" => { rpc_call!(client, call, set_worker_deployment_ramping_version) } diff --git a/core/src/abstractions.rs b/core/src/abstractions.rs index d4b86cb35..0d5a53206 100644 --- a/core/src/abstractions.rs +++ b/core/src/abstractions.rs @@ -25,6 +25,7 @@ use tokio_util::sync::CancellationToken; #[derive(Clone)] pub(crate) struct MeteredPermitDealer { supplier: Arc + Send + Sync>, + slot_supplier_kind: SlotSupplierKind, /// The number of permit owners who have acquired a permit, but are not yet meaningfully using /// that permit. This is useful for giving a more semantically accurate count of used task /// slots, since we typically wait for a permit first before polling, but that slot isn't used @@ -54,6 +55,35 @@ pub(crate) struct PermitDealerContextData { pub(crate) worker_deployment_version: Option, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) enum SlotSupplierKind { + Fixed, + ResourceBased, + Custom(String), +} + +impl SlotSupplierKind { + fn from_label(label: &str) -> Self { + if label == "Fixed" { + SlotSupplierKind::Fixed + } else if label == "ResourceBased" { + SlotSupplierKind::ResourceBased + } else { + SlotSupplierKind::Custom(label.to_string()) + } + } +} + +impl std::fmt::Display for SlotSupplierKind { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SlotSupplierKind::Fixed => f.write_str("Fixed"), + SlotSupplierKind::ResourceBased => f.write_str("ResourceBased"), + SlotSupplierKind::Custom(name) => f.write_str(name.as_str()), + } + } +} + impl MeteredPermitDealer where SK: SlotKind + 'static, @@ -65,8 +95,11 @@ where context_data: Arc, meter: Option, ) -> Self { + let supplier_kind_label = supplier.slot_supplier_kind(); + let slot_supplier_kind = SlotSupplierKind::from_label(supplier_kind_label.as_ref()); Self { supplier, + slot_supplier_kind, unused_claimants: Arc::new(AtomicUsize::new(0)), extant_permits: watch::channel(0), metrics_ctx, @@ -81,6 +114,10 @@ where self.supplier.available_slots() } + pub(crate) fn slot_supplier_kind(&self) -> &SlotSupplierKind { + &self.slot_supplier_kind + } + #[cfg(test)] pub(crate) fn unused_permits(&self) -> Option { self.available_permits() @@ -492,4 +529,10 @@ pub(crate) mod tests { // Now it'll proceed acquire_fut.await; } + + #[test] + fn captures_slot_supplier_kind() { + let dealer = fixed_size_permit_dealer::(1); + assert_eq!(*dealer.slot_supplier_kind(), SlotSupplierKind::Fixed); + } } diff --git a/core/src/core_tests/workers.rs b/core/src/core_tests/workers.rs index f5288a442..68314c082 100644 --- a/core/src/core_tests/workers.rs +++ b/core/src/core_tests/workers.rs @@ -321,12 +321,12 @@ async fn worker_shutdown_api(#[case] use_cache: bool, #[case] api_success: bool) if api_success { mock.expect_shutdown_worker() .times(1) - .returning(|_| Ok(ShutdownWorkerResponse {})); + .returning(|_, _| Ok(ShutdownWorkerResponse {})); } else { // worker.shutdown() should succeed even if shutdown_worker fails mock.expect_shutdown_worker() .times(1) - .returning(|_| Err(tonic::Status::unavailable("fake shutdown error"))); + .returning(|_, _| Err(tonic::Status::unavailable("fake shutdown error"))); } } else { mock.expect_shutdown_worker().times(0); diff --git a/core/src/core_tests/workflow_tasks.rs b/core/src/core_tests/workflow_tasks.rs index e5550938b..8d5df2b7a 100644 --- a/core/src/core_tests/workflow_tasks.rs +++ b/core/src/core_tests/workflow_tasks.rs @@ -2996,7 +2996,6 @@ async fn both_normal_and_sticky_pollers_poll_concurrently() { Arc::new(mock_client), None, None, - false, ) .unwrap(); diff --git a/core/src/lib.rs b/core/src/lib.rs index 3995b7562..a73eb43ad 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -41,10 +41,9 @@ pub use temporal_sdk_core_protos as protos; pub use temporal_sdk_core_protos::TaskToken; pub use url::Url; pub use worker::{ - FixedSizeSlotSupplier, RealSysInfo, ResourceBasedSlotsOptions, - ResourceBasedSlotsOptionsBuilder, ResourceBasedTuner, ResourceSlotOptions, SlotSupplierOptions, - TunerBuilder, TunerHolder, TunerHolderOptions, TunerHolderOptionsBuilder, Worker, WorkerConfig, - WorkerConfigBuilder, + FixedSizeSlotSupplier, ResourceBasedSlotsOptions, ResourceBasedSlotsOptionsBuilder, + ResourceBasedTuner, ResourceSlotOptions, SlotSupplierOptions, TunerBuilder, TunerHolder, + TunerHolderOptions, TunerHolderOptionsBuilder, Worker, WorkerConfig, WorkerConfigBuilder, }; /// Expose [WorkerClient] symbols @@ -123,7 +122,6 @@ where client_bag.clone(), Some(&runtime.telemetry), runtime.heartbeat_interval, - false, ) } diff --git a/core/src/pollers/poll_buffer.rs b/core/src/pollers/poll_buffer.rs index 7bc4311fb..786b195e4 100644 --- a/core/src/pollers/poll_buffer.rs +++ b/core/src/pollers/poll_buffer.rs @@ -6,8 +6,10 @@ use crate::{ client::{PollActivityOptions, PollOptions, PollWorkflowOptions, WorkerClient}, }, }; +use crossbeam_utils::atomic::AtomicCell; use futures_util::{FutureExt, StreamExt, future::BoxFuture}; use governor::{Quota, RateLimiter}; +use std::time::SystemTime; use std::{ cmp, fmt::Debug, @@ -74,9 +76,15 @@ impl LongPollBuffer { shutdown: CancellationToken, num_pollers_handler: Option, options: WorkflowTaskOptions, + last_successful_poll_time: Arc>>, ) -> Self { let is_sticky = sticky_queue.is_some(); - let poll_scaler = PollScaler::new(poller_behavior, num_pollers_handler, shutdown.clone()); + let poll_scaler = PollScaler::new( + poller_behavior, + num_pollers_handler, + shutdown.clone(), + last_successful_poll_time, + ); if let Some(wftps) = options.wft_poller_shared.as_ref() { if is_sticky { wftps.set_sticky_active(poll_scaler.active_rx.clone()); @@ -136,6 +144,7 @@ impl LongPollBuffer { } impl LongPollBuffer { + #[allow(clippy::too_many_arguments)] pub(crate) fn new_activity_task( client: Arc, task_queue: String, @@ -144,6 +153,7 @@ impl LongPollBuffer { shutdown: CancellationToken, num_pollers_handler: Option, options: ActivityTaskOptions, + last_successful_poll_time: Arc>>, ) -> Self { let pre_permit_delay = options .max_worker_acts_per_second @@ -183,7 +193,12 @@ impl LongPollBuffer { } }; - let poll_scaler = PollScaler::new(poller_behavior, num_pollers_handler, shutdown.clone()); + let poll_scaler = PollScaler::new( + poller_behavior, + num_pollers_handler, + shutdown.clone(), + last_successful_poll_time, + ); Self::new( poll_fn, permit_dealer, @@ -196,6 +211,7 @@ impl LongPollBuffer { } impl LongPollBuffer { + #[allow(clippy::too_many_arguments)] pub(crate) fn new_nexus_task( client: Arc, task_queue: String, @@ -203,6 +219,7 @@ impl LongPollBuffer { permit_dealer: MeteredPermitDealer, shutdown: CancellationToken, num_pollers_handler: Option, + last_successful_poll_time: Arc>>, send_heartbeat: bool, ) -> Self { let no_retry = if matches!(poller_behavior, PollerBehavior::Autoscaling { .. }) { @@ -232,7 +249,12 @@ impl LongPollBuffer { poll_fn, permit_dealer, shutdown.clone(), - PollScaler::new(poller_behavior, num_pollers_handler, shutdown), + PollScaler::new( + poller_behavior, + num_pollers_handler, + shutdown, + last_successful_poll_time, + ), None:: BoxFuture<'static, ()>>, None::, ) @@ -417,6 +439,7 @@ where behavior: PollerBehavior, num_pollers_handler: Option, shutdown: CancellationToken, + last_successful_poll_time: Arc>>, ) -> Self { let (active_tx, active_rx) = watch::channel(0); let num_pollers_handler = num_pollers_handler.map(Arc::new); @@ -437,6 +460,7 @@ where ingested_this_period: Default::default(), ingested_last_period: Default::default(), scale_up_allowed: AtomicBool::new(true), + last_successful_poll_time, }); let rhc = report_handle.clone(); let ingestor_task = if behavior.is_autoscaling() { @@ -499,6 +523,7 @@ struct PollScalerReportHandle { ingested_this_period: AtomicUsize, ingested_last_period: AtomicUsize, scale_up_allowed: AtomicBool, + last_successful_poll_time: Arc>>, } impl PollScalerReportHandle { @@ -506,6 +531,8 @@ impl PollScalerReportHandle { fn poll_result(&self, res: &Result) -> bool { match res { Ok(res) => { + self.last_successful_poll_time + .store(Some(SystemTime::now())); if let PollerBehavior::SimpleMaximum(_) = self.behavior { // We don't do auto-scaling with the simple max return true; @@ -739,6 +766,7 @@ mod tests { WorkflowTaskOptions { wft_poller_shared: Some(Arc::new(WFTPollerShared::new(Some(10)))), }, + Arc::new(AtomicCell::new(None)), ); // Poll a bunch of times, "interrupting" it each time, we should only actually have polled @@ -794,6 +822,7 @@ mod tests { WorkflowTaskOptions { wft_poller_shared: Some(Arc::new(WFTPollerShared::new(Some(1)))), }, + Arc::new(AtomicCell::new(None)), ); // Should not see error, unwraps should get empty response diff --git a/core/src/replay/mod.rs b/core/src/replay/mod.rs index 1e4990000..03f0003be 100644 --- a/core/src/replay/mod.rs +++ b/core/src/replay/mod.rs @@ -114,7 +114,7 @@ where hist_allow_tx.send("Failed".to_string()).unwrap(); async move { Ok(RespondWorkflowTaskFailedResponse::default()) }.boxed() }); - let mut worker = Worker::new(self.config, None, Arc::new(client), None, None, false)?; + let mut worker = Worker::new(self.config, None, Arc::new(client), None, None)?; worker.set_post_activate_hook(post_activate); shutdown_tok(worker.shutdown_token()); Ok(worker) diff --git a/core/src/telemetry/metrics.rs b/core/src/telemetry/metrics.rs index fe6aec8e1..d39bf02b0 100644 --- a/core/src/telemetry/metrics.rs +++ b/core/src/telemetry/metrics.rs @@ -13,7 +13,7 @@ use temporal_sdk_core_api::telemetry::metrics::{ GaugeF64, GaugeF64Base, Histogram, HistogramBase, HistogramDuration, HistogramDurationBase, HistogramF64, HistogramF64Base, LazyBufferInstrument, MetricAttributable, MetricAttributes, MetricCallBufferer, MetricEvent, MetricKeyValue, MetricKind, MetricParameters, MetricUpdateVal, - NewAttributes, NoOpCoreMeter, TemporalMeter, + NewAttributes, NoOpCoreMeter, TemporalMeter, WorkerHeartbeatMetrics, }; use temporal_sdk_core_protos::temporal::api::{ enums::v1::WorkflowTaskFailedCause, failure::v1::Failure, @@ -25,6 +25,7 @@ pub(crate) struct MetricsContext { meter: Arc, kvs: MetricAttributes, instruments: Arc, + in_memory_metrics: Option>, } #[derive(Clone)] @@ -70,11 +71,13 @@ impl MetricsContext { pub(crate) fn no_op() -> Self { let meter = Arc::new(NoOpCoreMeter); let kvs = meter.new_attributes(Default::default()); - let instruments = Arc::new(Instruments::new(meter.as_ref())); + let in_memory_metrics = Some(Arc::new(WorkerHeartbeatMetrics::default())); + let instruments = Arc::new(Instruments::new(meter.as_ref(), in_memory_metrics.clone())); Self { kvs, instruments, meter, + in_memory_metrics, } } @@ -95,12 +98,14 @@ impl MetricsContext { .push(MetricKeyValue::new(KEY_NAMESPACE, namespace)); meter.default_attribs.attributes.push(task_queue(tq)); let kvs = meter.inner.new_attributes(meter.default_attribs); - let mut instruments = Instruments::new(meter.inner.as_ref()); + let in_memory_metrics = Some(Arc::new(WorkerHeartbeatMetrics::default())); + let mut instruments = Instruments::new(meter.inner.as_ref(), in_memory_metrics.clone()); instruments.update_attributes(&kvs); Self { kvs, instruments: Arc::new(instruments), meter: meter.inner, + in_memory_metrics, } } else { Self::no_op() @@ -121,9 +126,14 @@ impl MetricsContext { instruments: Arc::new(instruments), kvs, meter: self.meter.clone(), + in_memory_metrics: self.in_memory_metrics.clone(), } } + pub(crate) fn in_memory_meter(&self) -> Option> { + self.in_memory_metrics.clone() + } + /// A workflow task queue poll succeeded pub(crate) fn wf_tq_poll_ok(&self) { self.instruments.wf_task_queue_poll_succeed_counter.adds(1); @@ -299,7 +309,31 @@ impl MetricsContext { } impl Instruments { - fn new(meter: &dyn CoreMeter) -> Self { + fn new(meter: &dyn CoreMeter, in_memory: Option>) -> Self { + let counter_with_in_mem = |params: MetricParameters| -> Counter { + in_memory + .clone() + .and_then(|in_mem| in_mem.get_metric(¶ms.name)) + .map(|metric| meter.counter_with_in_memory(params.clone(), metric)) + .unwrap_or_else(|| meter.counter(params)) + }; + + let gauge_with_in_mem = |params: MetricParameters| -> Gauge { + in_memory + .clone() + .and_then(|in_mem| in_mem.get_metric(¶ms.name)) + .map(|metric| meter.gauge_with_in_memory(params.clone(), metric)) + .unwrap_or_else(|| meter.gauge(params)) + }; + + let histogram_with_in_mem = |params: MetricParameters| -> HistogramDuration { + in_memory + .clone() + .and_then(|in_mem| in_mem.get_metric(¶ms.name)) + .map(|metric| meter.histogram_duration_with_in_memory(params.clone(), metric)) + .unwrap_or_else(|| meter.histogram_duration(params)) + }; + Self { wf_completed_counter: meter.counter(MetricParameters { name: "workflow_completed".into(), @@ -331,12 +365,12 @@ impl Instruments { description: "Count of workflow task queue poll timeouts (no new task)".into(), unit: "".into(), }), - wf_task_queue_poll_succeed_counter: meter.counter(MetricParameters { + wf_task_queue_poll_succeed_counter: counter_with_in_mem(MetricParameters { name: "workflow_task_queue_poll_succeed".into(), description: "Count of workflow task queue poll successes".into(), unit: "".into(), }), - wf_task_execution_failure_counter: meter.counter(MetricParameters { + wf_task_execution_failure_counter: counter_with_in_mem(MetricParameters { name: "workflow_task_execution_failed".into(), description: "Count of workflow task execution failures".into(), unit: "".into(), @@ -351,7 +385,7 @@ impl Instruments { unit: "duration".into(), description: "Histogram of workflow task replay latencies".into(), }), - wf_task_execution_latency: meter.histogram_duration(MetricParameters { + wf_task_execution_latency: histogram_with_in_mem(MetricParameters { name: WORKFLOW_TASK_EXECUTION_LATENCY_HISTOGRAM_NAME.into(), unit: "duration".into(), description: "Histogram of workflow task execution (not replay) latencies".into(), @@ -361,12 +395,12 @@ impl Instruments { description: "Count of activity task queue poll timeouts (no new task)".into(), unit: "".into(), }), - act_task_received_counter: meter.counter(MetricParameters { + act_task_received_counter: counter_with_in_mem(MetricParameters { name: "activity_task_received".into(), description: "Count of activity task queue poll successes".into(), unit: "".into(), }), - act_execution_failed: meter.counter(MetricParameters { + act_execution_failed: counter_with_in_mem(MetricParameters { name: "activity_execution_failed".into(), description: "Count of activity task execution failures".into(), unit: "".into(), @@ -376,7 +410,7 @@ impl Instruments { unit: "duration".into(), description: "Histogram of activity schedule-to-start latencies".into(), }), - act_exec_latency: meter.histogram_duration(MetricParameters { + act_exec_latency: histogram_with_in_mem(MetricParameters { name: ACTIVITY_EXEC_LATENCY_HISTOGRAM_NAME.into(), unit: "duration".into(), description: "Histogram of activity execution latencies".into(), @@ -397,7 +431,7 @@ impl Instruments { description: "Count of local activity executions that failed".into(), unit: "".into(), }), - la_exec_latency: meter.histogram_duration(MetricParameters { + la_exec_latency: histogram_with_in_mem(MetricParameters { name: "local_activity_execution_latency".into(), unit: "duration".into(), description: "Histogram of local activity execution latencies".into(), @@ -409,7 +443,7 @@ impl Instruments { "Histogram of local activity execution latencies for successful local activities" .into(), }), - la_total: meter.counter(MetricParameters { + la_total: counter_with_in_mem(MetricParameters { name: "local_activity_total".into(), description: "Count of local activities executed".into(), unit: "".into(), @@ -429,12 +463,12 @@ impl Instruments { unit: "duration".into(), description: "Histogram of nexus task end-to-end latencies".into(), }), - nexus_task_execution_latency: meter.histogram_duration(MetricParameters { + nexus_task_execution_latency: histogram_with_in_mem(MetricParameters { name: "nexus_task_execution_latency".into(), unit: "duration".into(), description: "Histogram of nexus task execution latencies".into(), }), - nexus_task_execution_failed: meter.counter(MetricParameters { + nexus_task_execution_failed: counter_with_in_mem(MetricParameters { name: "nexus_task_execution_failed".into(), description: "Count of nexus task execution failures".into(), unit: "".into(), @@ -445,35 +479,34 @@ impl Instruments { description: "Count of the number of initialized workers".into(), unit: "".into(), }), - num_pollers: meter.gauge(MetricParameters { + num_pollers: gauge_with_in_mem(MetricParameters { name: NUM_POLLERS_NAME.into(), description: "Current number of active pollers per queue type".into(), unit: "".into(), }), - task_slots_available: meter.gauge(MetricParameters { + task_slots_available: gauge_with_in_mem(MetricParameters { name: TASK_SLOTS_AVAILABLE_NAME.into(), description: "Current number of available slots per task type".into(), unit: "".into(), }), - task_slots_used: meter.gauge(MetricParameters { + task_slots_used: gauge_with_in_mem(MetricParameters { name: TASK_SLOTS_USED_NAME.into(), description: "Current number of used slots per task type".into(), unit: "".into(), }), - sticky_cache_hit: meter.counter(MetricParameters { + sticky_cache_hit: counter_with_in_mem(MetricParameters { name: "sticky_cache_hit".into(), description: "Count of times the workflow cache was used for a new workflow task" .into(), unit: "".into(), }), - sticky_cache_miss: meter.counter(MetricParameters { + sticky_cache_miss: counter_with_in_mem(MetricParameters { name: "sticky_cache_miss".into(), description: - "Count of times the workflow cache was missing a workflow for a sticky task" - .into(), + "Count of times the workflow cache was missing a workflow for a sticky task".into(), unit: "".into(), }), - sticky_cache_size: meter.gauge(MetricParameters { + sticky_cache_size: gauge_with_in_mem(MetricParameters { name: STICKY_CACHE_SIZE_NAME.into(), description: "Current number of cached workflows".into(), unit: "".into(), diff --git a/core/src/telemetry/mod.rs b/core/src/telemetry/mod.rs index 4f4536938..94457f697 100644 --- a/core/src/telemetry/mod.rs +++ b/core/src/telemetry/mod.rs @@ -39,6 +39,7 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; +pub(crate) use temporal_sdk_core_api::telemetry::metrics::WorkerHeartbeatMetrics; use temporal_sdk_core_api::telemetry::{ CoreLog, CoreTelemetry, Logger, TelemetryOptions, TelemetryOptionsBuilder, metrics::{CoreMeter, MetricKeyValue, NewAttributes, TemporalMeter}, diff --git a/core/src/worker/activities.rs b/core/src/worker/activities.rs index 505d5c840..a4edb4f2c 100644 --- a/core/src/worker/activities.rs +++ b/core/src/worker/activities.rs @@ -728,6 +728,7 @@ mod tests { prost_dur, worker::client::mocks::mock_worker_client, }; + use crossbeam_utils::atomic::AtomicCell; use temporal_sdk_core_api::worker::PollerBehavior; use temporal_sdk_core_protos::coresdk::activity_result::ActivityExecutionResult; @@ -773,6 +774,7 @@ mod tests { max_worker_acts_per_second: Some(2.0), max_tps: None, }, + Arc::new(AtomicCell::new(None)), ); let atm = WorkerActivityTasks::new( sem.clone(), @@ -864,6 +866,7 @@ mod tests { max_worker_acts_per_second: None, max_tps: None, }, + Arc::new(AtomicCell::new(None)), ); let atm = WorkerActivityTasks::new( sem.clone(), @@ -937,6 +940,7 @@ mod tests { max_worker_acts_per_second: None, max_tps: None, }, + Arc::new(AtomicCell::new(None)), ); let atm = WorkerActivityTasks::new( sem.clone(), diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs index 5d773330e..578976d0e 100644 --- a/core/src/worker/client.rs +++ b/core/src/worker/client.rs @@ -2,13 +2,18 @@ pub(crate) mod mocks; use crate::protosext::legacy_query_failure; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; +use prost_types::Duration as PbDuration; +use std::collections::HashMap; +use std::time::SystemTime; use std::{sync::Arc, time::Duration}; use temporal_client::{ Client, ClientWorkerSet, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching, RetryClient, WorkflowService, }; use temporal_sdk_core_api::worker::WorkerVersioningStrategy; +use temporal_sdk_core_protos::temporal::api::enums::v1::WorkerStatus; +use temporal_sdk_core_protos::temporal::api::worker::v1::WorkerSlotsInfo; use temporal_sdk_core_protos::{ TaskToken, coresdk::{workflow_commands::QueryResult, workflow_completion}, @@ -48,6 +53,7 @@ pub(crate) struct WorkerClientBag { namespace: String, identity: String, worker_versioning_strategy: WorkerVersioningStrategy, + worker_heartbeat_map: Arc>>, } impl WorkerClientBag { @@ -62,6 +68,7 @@ impl WorkerClientBag { namespace, identity, worker_versioning_strategy, + worker_heartbeat_map: Arc::new(Mutex::new(HashMap::new())), } } @@ -211,7 +218,11 @@ pub trait WorkerClient: Sync + Send { /// Describe the namespace async fn describe_namespace(&self) -> Result; /// Shutdown the worker - async fn shutdown_worker(&self, sticky_task_queue: String) -> Result; + async fn shutdown_worker( + &self, + sticky_task_queue: String, + final_heartbeat: Option, + ) -> Result; /// Record a worker heartbeat async fn record_worker_heartbeat( &self, @@ -233,6 +244,9 @@ pub trait WorkerClient: Sync + Send { fn identity(&self) -> String; /// Get worker grouping key fn worker_grouping_key(&self) -> Uuid; + /// Sets the client-reliant fields for WorkerHeartbeat. This also updates client-level tracking + /// of heartbeat fields, like last heartbeat timestamp. + fn set_heartbeat_client_fields(&self, heartbeat: &mut WorkerHeartbeat); } /// Configuration options shared by workflow, activity, and Nexus polling calls @@ -640,13 +654,22 @@ impl WorkerClient for WorkerClientBag { .into_inner()) } - async fn shutdown_worker(&self, sticky_task_queue: String) -> Result { + async fn shutdown_worker( + &self, + sticky_task_queue: String, + final_heartbeat: Option, + ) -> Result { + let mut final_heartbeat = final_heartbeat; + if let Some(w) = final_heartbeat.as_mut() { + w.status = WorkerStatus::Shutdown.into(); + self.set_heartbeat_client_fields(w); + } let request = ShutdownWorkerRequest { namespace: self.namespace.clone(), identity: self.identity.clone(), sticky_task_queue, reason: "graceful shutdown".to_string(), - worker_heartbeat: None, + worker_heartbeat: final_heartbeat, }; Ok( @@ -708,6 +731,50 @@ impl WorkerClient for WorkerClientBag { .get_client() .worker_grouping_key() } + + fn set_heartbeat_client_fields(&self, heartbeat: &mut WorkerHeartbeat) { + if let Some(host_info) = heartbeat.host_info.as_mut() { + host_info.process_key = self.worker_grouping_key().to_string(); + } + heartbeat.worker_identity = self.identity(); + let sdk_name_and_ver = self.sdk_name_and_version(); + heartbeat.sdk_name = sdk_name_and_ver.0; + heartbeat.sdk_version = sdk_name_and_ver.1; + + let now = SystemTime::now(); + heartbeat.heartbeat_time = Some(now.into()); + let mut heartbeat_map = self.worker_heartbeat_map.lock(); + let client_heartbeat_data = heartbeat_map + .entry(heartbeat.worker_instance_key.clone()) + .or_default(); + let elapsed_since_last_heartbeat = + client_heartbeat_data.last_heartbeat_time.map(|hb_time| { + let dur = now.duration_since(hb_time).unwrap_or(Duration::ZERO); + PbDuration { + seconds: dur.as_secs() as i64, + nanos: dur.subsec_nanos() as i32, + } + }); + heartbeat.elapsed_since_last_heartbeat = elapsed_since_last_heartbeat; + client_heartbeat_data.last_heartbeat_time = Some(now); + + update_slots( + &mut heartbeat.workflow_task_slots_info, + &mut client_heartbeat_data.workflow_task_slots_info, + ); + update_slots( + &mut heartbeat.activity_task_slots_info, + &mut client_heartbeat_data.activity_task_slots_info, + ); + update_slots( + &mut heartbeat.nexus_task_slots_info, + &mut client_heartbeat_data.nexus_task_slots_info, + ); + update_slots( + &mut heartbeat.local_activity_slots_info, + &mut client_heartbeat_data.local_activity_slots_info, + ); + } } impl NamespacedClient for WorkerClientBag { @@ -745,3 +812,31 @@ pub struct WorkflowTaskCompletion { /// Versioning behavior of the workflow, if any. pub versioning_behavior: VersioningBehavior, } + +#[derive(Clone, Default)] +struct SlotsInfo { + total_processed_tasks: i32, + total_failed_tasks: i32, +} + +#[derive(Clone, Default)] +struct ClientHeartbeatData { + last_heartbeat_time: Option, + + workflow_task_slots_info: SlotsInfo, + activity_task_slots_info: SlotsInfo, + nexus_task_slots_info: SlotsInfo, + local_activity_slots_info: SlotsInfo, +} + +fn update_slots(slots_info: &mut Option, client_heartbeat_data: &mut SlotsInfo) { + if let Some(wft_slot_info) = slots_info.as_mut() { + wft_slot_info.last_interval_processed_tasks = + wft_slot_info.total_processed_tasks - client_heartbeat_data.total_processed_tasks; + wft_slot_info.last_interval_failure_tasks = + wft_slot_info.total_failed_tasks - client_heartbeat_data.total_failed_tasks; + + client_heartbeat_data.total_processed_tasks = wft_slot_info.total_processed_tasks; + client_heartbeat_data.total_failed_tasks = wft_slot_info.total_failed_tasks; + } +} diff --git a/core/src/worker/client/mocks.rs b/core/src/worker/client/mocks.rs index 93984c364..b86334e4b 100644 --- a/core/src/worker/client/mocks.rs +++ b/core/src/worker/client/mocks.rs @@ -30,12 +30,18 @@ pub fn mock_worker_client() -> MockWorkerClient { .returning(|| DEFAULT_WORKERS_REGISTRY.clone()); r.expect_is_mock().returning(|| true); r.expect_shutdown_worker() - .returning(|_| Ok(ShutdownWorkerResponse {})); + .returning(|_, _| Ok(ShutdownWorkerResponse {})); r.expect_sdk_name_and_version() .returning(|| ("test-core".to_string(), "0.0.0".to_string())); r.expect_identity() .returning(|| "test-identity".to_string()); r.expect_worker_grouping_key().returning(Uuid::new_v4); + r.expect_set_heartbeat_client_fields().returning(|hb| { + hb.sdk_name = "test-core".to_string(); + hb.sdk_version = "0.0.0".to_string(); + hb.worker_identity = "test-identity".to_string(); + hb.heartbeat_time = Some(SystemTime::now().into()); + }); r } @@ -148,7 +154,7 @@ mockall::mock! { impl Future> + Send + 'b where 'a: 'b, Self: 'b; - fn shutdown_worker<'a, 'b>(&self, sticky_task_queue: String) -> impl Future> + Send + 'b + fn shutdown_worker<'a, 'b>(&self, sticky_task_queue: String, worker_heartbeat: Option) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; fn record_worker_heartbeat<'a, 'b>( @@ -164,5 +170,6 @@ mockall::mock! { fn sdk_name_and_version(&self) -> (String, String); fn identity(&self) -> String; fn worker_grouping_key(&self) -> Uuid; + fn set_heartbeat_client_fields(&self, heartbeat: &mut WorkerHeartbeat); } } diff --git a/core/src/worker/heartbeat.rs b/core/src/worker/heartbeat.rs index 88774647f..7ec2f7aa5 100644 --- a/core/src/worker/heartbeat.rs +++ b/core/src/worker/heartbeat.rs @@ -1,12 +1,9 @@ use crate::WorkerClient; use crate::worker::{TaskPollers, WorkerTelemetry}; use parking_lot::Mutex; -use prost_types::Duration as PbDuration; use std::collections::HashMap; -use std::{ - sync::Arc, - time::{Duration, SystemTime}, -}; +use std::sync::Arc; +use std::time::Duration; use temporal_client::SharedNamespaceWorkerTrait; use temporal_sdk_core_api::worker::{ PollerBehavior, WorkerConfigBuilder, WorkerVersioningStrategy, @@ -17,7 +14,7 @@ use tokio_util::sync::CancellationToken; use uuid::Uuid; /// Callback used to collect heartbeat data from each worker at the time of heartbeat -pub(crate) type HeartbeatFn = Box WorkerHeartbeat + Send + Sync>; +pub(crate) type HeartbeatFn = Arc WorkerHeartbeat + Send + Sync>; /// SharedNamespaceWorker is responsible for polling nexus-delivered worker commands and sending /// worker heartbeats to the server. This invokes callbacks on all workers in the same process that @@ -49,7 +46,7 @@ impl SharedNamespaceWorker { .nexus_task_poller_behavior(PollerBehavior::SimpleMaximum(1_usize)) .build() .expect("all required fields should be implemented"); - let worker = crate::worker::Worker::new_with_pollers_inner( + let worker = crate::worker::Worker::new_with_pollers( config, None, client.clone(), @@ -59,8 +56,6 @@ impl SharedNamespaceWorker { true, )?; - let last_heartbeat_time_map = Mutex::new(HashMap::new()); - let reset_notify = Arc::new(Notify::new()); let cancel = CancellationToken::new(); let cancel_clone = cancel.clone(); @@ -77,34 +72,13 @@ impl SharedNamespaceWorker { tokio::select! { _ = ticker.tick() => { let mut hb_to_send = Vec::new(); - for (instance_key, heartbeat_callback) in heartbeat_map_clone.lock().iter() { + for (_instance_key, heartbeat_callback) in heartbeat_map_clone.lock().iter() { let mut heartbeat = heartbeat_callback(); - let mut last_heartbeat_time_map = last_heartbeat_time_map.lock(); - let now = SystemTime::now(); - let elapsed_since_last_heartbeat = last_heartbeat_time_map.get(instance_key).cloned().map( - |hb_time| { - let dur = now.duration_since(hb_time).unwrap_or(Duration::ZERO); - PbDuration { - seconds: dur.as_secs() as i64, - nanos: dur.subsec_nanos() as i32, - } - } - ); - - heartbeat.elapsed_since_last_heartbeat = elapsed_since_last_heartbeat; - heartbeat.heartbeat_time = Some(now.into()); - // All of these heartbeat details rely on a client. To avoid circular // dependencies, this must be populated from within SharedNamespaceWorker // to get info from the current client - heartbeat.worker_identity = client_clone.identity(); - let sdk_name_and_ver = client_clone.sdk_name_and_version(); - heartbeat.sdk_name = sdk_name_and_ver.0; - heartbeat.sdk_version = sdk_name_and_ver.1; - + client_clone.set_heartbeat_client_fields(&mut heartbeat); hb_to_send.push(heartbeat); - - last_heartbeat_time_map.insert(*instance_key, now); } if let Err(e) = client_clone.record_worker_heartbeat(namespace_clone.clone(), hb_to_send).await { if matches!(e.code(), tonic::Code::Unimplemented) { @@ -137,19 +111,12 @@ impl SharedNamespaceWorkerTrait for SharedNamespaceWorker { self.namespace.clone() } - fn register_callback( - &self, - worker_instance_key: Uuid, - heartbeat_callback: Box WorkerHeartbeat + Send + Sync>, - ) { + fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatFn) { self.heartbeat_map .lock() .insert(worker_instance_key, heartbeat_callback); } - fn unregister_callback( - &self, - worker_instance_key: Uuid, - ) -> (Option WorkerHeartbeat + Send + Sync>>, bool) { + fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option, bool) { let mut heartbeat_map = self.heartbeat_map.lock(); let heartbeat_callback = heartbeat_map.remove(&worker_instance_key); if heartbeat_map.is_empty() { @@ -225,7 +192,6 @@ mod tests { client.clone(), None, Some(Duration::from_millis(100)), - false, ) .unwrap(); diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index 5428c067a..f3dbf276a 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -1,6 +1,6 @@ mod activities; pub(crate) mod client; -mod heartbeat; +pub(crate) mod heartbeat; mod nexus; mod slot_provider; pub(crate) mod tuner; @@ -8,10 +8,11 @@ mod workflow; pub use temporal_sdk_core_api::worker::{WorkerConfig, WorkerConfigBuilder}; pub use tuner::{ - FixedSizeSlotSupplier, RealSysInfo, ResourceBasedSlotsOptions, - ResourceBasedSlotsOptionsBuilder, ResourceBasedTuner, ResourceSlotOptions, SlotSupplierOptions, - TunerBuilder, TunerHolder, TunerHolderOptions, TunerHolderOptionsBuilder, + FixedSizeSlotSupplier, ResourceBasedSlotsOptions, ResourceBasedSlotsOptionsBuilder, + ResourceBasedTuner, ResourceSlotOptions, SlotSupplierOptions, TunerBuilder, TunerHolder, + TunerHolderOptions, TunerHolderOptionsBuilder, }; +pub(crate) use tuner::{RealSysInfo, SystemResourceInfo}; pub(crate) use activities::{ ExecutingLAId, LocalActRequest, LocalActivityExecutionResult, LocalActivityResolution, @@ -20,6 +21,7 @@ pub(crate) use activities::{ pub(crate) use wft_poller::WFTPollerShared; pub use workflow::LEGACY_QUERY_ID; +use crate::telemetry::WorkerHeartbeatMetrics; use crate::worker::heartbeat::{HeartbeatFn, SharedNamespaceWorker}; use crate::{ ActivityHeartbeat, CompleteActivityError, PollError, WorkerTrait, @@ -49,10 +51,13 @@ use crate::{ }; use activities::WorkerActivityTasks; use anyhow::bail; +use crossbeam_utils::atomic::AtomicCell; use futures_util::{StreamExt, stream}; use gethostname::gethostname; use parking_lot::{Mutex, RwLock}; use slot_provider::SlotProvider; +use std::sync::atomic::AtomicU64; +use std::time::SystemTime; use std::{ convert::TryInto, future, @@ -67,11 +72,18 @@ use temporal_client::{ ConfiguredClient, SharedNamespaceWorkerTrait, TemporalServiceClientWithMetrics, }; use temporal_sdk_core_api::telemetry::metrics::TemporalMeter; +use temporal_sdk_core_api::worker::{ + ActivitySlotKind, LocalActivitySlotKind, NexusSlotKind, SlotKind, WorkflowSlotKind, +}; use temporal_sdk_core_api::{ errors::{CompleteNexusError, WorkerValidationError}, worker::PollerBehavior, }; -use temporal_sdk_core_protos::temporal::api::worker::v1::{WorkerHeartbeat, WorkerHostInfo}; +use temporal_sdk_core_protos::temporal::api::deployment; +use temporal_sdk_core_protos::temporal::api::enums::v1::WorkerStatus; +use temporal_sdk_core_protos::temporal::api::worker::v1::{ + WorkerHeartbeat, WorkerHostInfo, WorkerPollerInfo, WorkerSlotsInfo, +}; use temporal_sdk_core_protos::{ TaskToken, coresdk::{ @@ -131,6 +143,8 @@ pub struct Worker { all_permits_tracker: tokio::sync::Mutex, /// Used to track worker client client_worker_registrator: Arc, + /// Status of the worker + status: Arc>, } struct AllPermitsTracker { @@ -249,18 +263,15 @@ impl WorkerTrait for Worker { ); } self.shutdown_token.cancel(); - // First, unregister worker from the client - if let Err(e) = self - .client - .workers() - .unregister_worker(self.worker_instance_key) { - error!( - task_queue=%self.config.task_queue, - namespace=%self.config.namespace, - error=%e, - "Failed to unregister worker on shutdown", - ); + *self.status.lock() = WorkerStatus::ShuttingDown; + } + // First, unregister worker from the client + if !self.client_worker_registrator.shared_namespace_worker { + let _res = self + .client + .workers() + .unregister_worker(self.worker_instance_key); } // Second, we want to stop polling of both activity and workflow tasks @@ -288,6 +299,10 @@ impl WorkerTrait for Worker { async fn finalize_shutdown(self) { self.finalize_shutdown().await } + + fn worker_instance_key(&self) -> Uuid { + self.worker_instance_key + } } impl Worker { @@ -301,18 +316,23 @@ impl Worker { client: Arc, telem_instance: Option<&TelemetryInstance>, worker_heartbeat_interval: Option, - shared_namespace_worker: bool, ) -> Result { info!(task_queue=%config.task_queue, namespace=%config.namespace, "Initializing worker"); + let worker_telemetry = telem_instance.map(|telem| WorkerTelemetry { + metric_meter: telem.get_metric_meter(), + temporal_metric_meter: telem.get_temporal_metric_meter(), + trace_subscriber: telem.trace_subscriber(), + }); + Self::new_with_pollers( config, sticky_queue_name, client, TaskPollers::Real, - telem_instance, + worker_telemetry, worker_heartbeat_interval, - shared_namespace_worker, + false, ) } @@ -334,6 +354,7 @@ impl Worker { .client .workers() .unregister_worker(self.worker_instance_key)?; + let new_worker_client = super::init_worker_client( self.config.namespace.clone(), self.config.client_identity_override.clone(), @@ -342,41 +363,17 @@ impl Worker { self.client.replace_client(new_worker_client); *self.client_worker_registrator.client.write() = self.client.clone(); - self.client.workers().register_worker(client_worker) + self.client + .workers() + .register_worker(client_worker, self.config.skip_client_worker_set_check) } #[cfg(test)] pub(crate) fn new_test(config: WorkerConfig, client: impl WorkerClient + 'static) -> Self { - Self::new(config, None, Arc::new(client), None, None, false).unwrap() + Self::new(config, None, Arc::new(client), None, None).unwrap() } pub(crate) fn new_with_pollers( - config: WorkerConfig, - sticky_queue_name: Option, - client: Arc, - task_pollers: TaskPollers, - telem_instance: Option<&TelemetryInstance>, - worker_heartbeat_interval: Option, - shared_namespace_worker: bool, - ) -> Result { - let worker_telemetry = telem_instance.map(|telem| WorkerTelemetry { - metric_meter: telem.get_metric_meter(), - temporal_metric_meter: telem.get_temporal_metric_meter(), - trace_subscriber: telem.trace_subscriber(), - }); - - Worker::new_with_pollers_inner( - config, - sticky_queue_name, - client, - task_pollers, - worker_telemetry, - worker_heartbeat_interval, - shared_namespace_worker, - ) - } - - pub(crate) fn new_with_pollers_inner( config: WorkerConfig, sticky_queue_name: Option, client: Arc, @@ -398,11 +395,13 @@ impl Worker { (MetricsContext::no_op(), None) }; - let tuner = config - .tuner - .as_ref() - .cloned() - .unwrap_or_else(|| Arc::new(TunerBuilder::from_config(&config).build())); + let mut sys_info = None; + let tuner = config.tuner.as_ref().cloned().unwrap_or_else(|| { + let mut tuner_builder = TunerBuilder::from_config(&config); + sys_info = tuner_builder.get_sys_info(); + Arc::new(tuner_builder.build()) + }); + let sys_info = sys_info.unwrap_or_else(|| Arc::new(RealSysInfo::new())); metrics.worker_registered(); let shutdown_token = CancellationToken::new(); @@ -434,6 +433,12 @@ impl Worker { ); let act_permits = act_slots.get_extant_count_rcv(); let (external_wft_tx, external_wft_rx) = unbounded_channel(); + + let wf_last_suc_poll_time = Arc::new(AtomicCell::new(None)); + let wf_sticky_last_suc_poll_time = Arc::new(AtomicCell::new(None)); + let act_last_suc_poll_time = Arc::new(AtomicCell::new(None)); + let nexus_last_suc_poll_time = Arc::new(AtomicCell::new(None)); + let nexus_slots = MeteredPermitDealer::new( tuner.nexus_task_slot_supplier(), metrics.with_new_attrs([nexus_worker_type()]), @@ -450,6 +455,8 @@ impl Worker { &metrics, &shutdown_token, &wft_slots, + wf_last_suc_poll_time.clone(), + wf_sticky_last_suc_poll_time.clone(), ); let wft_stream = if !client.is_mock() { // Some replay tests combine a mock client with real pollers, @@ -475,11 +482,13 @@ impl Worker { max_worker_acts_per_second: config.max_worker_activities_per_second, max_tps: config.max_task_queue_activities_per_second, }, + act_last_suc_poll_time.clone(), ); Some(Box::from(ap) as BoxedActPoller) }; let np_metrics = metrics.with_new_attrs([nexus_poller()]); + let nexus_poll_buffer = Box::new(LongPollBuffer::new_nexus_task( client.clone(), config.task_queue.clone(), @@ -487,6 +496,7 @@ impl Worker { nexus_slots.clone(), shutdown_token.child_token(), Some(move |np| np_metrics.record_num_pollers(np)), + nexus_last_suc_poll_time.clone(), shared_namespace_worker, )) as BoxedNexusPoller; @@ -531,13 +541,13 @@ impl Worker { let la_permits = la_permit_dealer.get_extant_count_rcv(); let local_act_mgr = Arc::new(LocalActivityManager::new( config.namespace.clone(), - la_permit_dealer, + la_permit_dealer.clone(), hb_tx, metrics.clone(), )); let at_task_mgr = act_poller.map(|ap| { WorkerActivityTasks::new( - act_slots, + act_slots.clone(), ap, client.clone(), metrics.clone(), @@ -548,7 +558,7 @@ impl Worker { ) }); let poll_on_non_local_activities = at_task_mgr.is_some(); - if !poll_on_non_local_activities { + if !poll_on_non_local_activities && !shared_namespace_worker { info!("Activity polling is disabled for this worker"); }; let la_sink = LAReqSink::new(local_act_mgr.clone()); @@ -567,14 +577,29 @@ impl Worker { external_wft_tx, ); let worker_instance_key = Uuid::new_v4(); + let worker_status = Arc::new(Mutex::new(WorkerStatus::Running)); let sdk_name_and_ver = client.sdk_name_and_version(); let worker_heartbeat = worker_heartbeat_interval.map(|hb_interval| { + let hb_metrics = HeartbeatMetrics { + in_mem_metrics: metrics.in_memory_meter(), + wft_slots: wft_slots.clone(), + act_slots, + nexus_slots, + la_slots: la_permit_dealer, + wf_last_suc_poll_time, + wf_sticky_last_suc_poll_time, + act_last_suc_poll_time, + nexus_last_suc_poll_time, + status: worker_status.clone(), + sys_info, + }; WorkerHeartbeatManager::new( config.clone(), worker_instance_key, hb_interval, worker_telemetry.clone(), + hb_metrics, ) }); @@ -583,12 +608,14 @@ impl Worker { slot_provider: provider, heartbeat_manager: worker_heartbeat, client: RwLock::new(client.clone()), + shared_namespace_worker, }); if !shared_namespace_worker { - client - .workers() - .register_worker(client_worker_registrator.clone())?; + client.workers().register_worker( + client_worker_registrator.clone(), + config.skip_client_worker_set_check, + )?; } Ok(Self { @@ -650,6 +677,7 @@ impl Worker { }), nexus_mgr, client_worker_registrator, + status: worker_status, }) } @@ -658,8 +686,14 @@ impl Worker { async fn shutdown(&self) { self.initiate_shutdown(); if let Some(name) = self.workflows.get_sticky_queue_name() { + let heartbeat = self + .client_worker_registrator + .heartbeat_manager + .as_ref() + .map(|hm| hm.heartbeat_callback.clone()()); + // This is a best effort call and we can still shutdown the worker if it fails - match self.client.shutdown_worker(name).await { + match self.client.shutdown_worker(name, heartbeat).await { Err(err) if !matches!( err.code(), @@ -955,6 +989,7 @@ struct ClientWorkerRegistrator { slot_provider: SlotProvider, heartbeat_manager: Option, client: RwLock>, + shared_namespace_worker: bool, } impl ClientWorker for ClientWorkerRegistrator { @@ -979,12 +1014,12 @@ impl ClientWorker for ClientWorkerRegistrator { fn heartbeat_callback(&self) -> Option { if let Some(hb_mgr) = self.heartbeat_manager.as_ref() { - let mut heartbeat_manager = hb_mgr.heartbeat_callback.lock(); - heartbeat_manager.take() + Some(hb_mgr.heartbeat_callback.clone()) } else { None } } + fn new_shared_namespace_worker( &self, ) -> Result, anyhow::Error> { @@ -999,12 +1034,20 @@ impl ClientWorker for ClientWorkerRegistrator { bail!("Shared namespace worker creation never be called without a heartbeat manager"); } } +} - fn register_callback(&self, callback: HeartbeatCallback) { - if let Some(hb_mgr) = self.heartbeat_manager.as_ref() { - hb_mgr.heartbeat_callback.lock().replace(callback); - } - } +struct HeartbeatMetrics { + in_mem_metrics: Option>, + wft_slots: MeteredPermitDealer, + act_slots: MeteredPermitDealer, + nexus_slots: MeteredPermitDealer, + la_slots: MeteredPermitDealer, + wf_last_suc_poll_time: Arc>>, + wf_sticky_last_suc_poll_time: Arc>>, + act_last_suc_poll_time: Arc>>, + nexus_last_suc_poll_time: Arc>>, + status: Arc>, + sys_info: Arc, } struct WorkerHeartbeatManager { @@ -1013,7 +1056,7 @@ struct WorkerHeartbeatManager { /// Telemetry instance, needed to initialize [SharedNamespaceWorker] when replacing client telemetry: Option, /// Heartbeat callback - heartbeat_callback: Mutex WorkerHeartbeat + Send + Sync>>>, + heartbeat_callback: Arc WorkerHeartbeat + Send + Sync>, } impl WorkerHeartbeatManager { @@ -1022,48 +1065,136 @@ impl WorkerHeartbeatManager { worker_instance_key: Uuid, heartbeat_interval: Duration, telemetry_instance: Option, + heartbeat_manager_metrics: HeartbeatMetrics, ) -> Self { - let worker_instance_key_clone = worker_instance_key.to_string(); - let task_queue = config.task_queue.clone(); + let start_time = Some(SystemTime::now().into()); + let worker_heartbeat_callback: HeartbeatFn = Arc::new(move || { + let deployment_version = config.computed_deployment_version().map(|dv| { + deployment::v1::WorkerDeploymentVersion { + deployment_name: dv.deployment_name, + build_id: dv.build_id, + } + }); - // TODO: requires the metrics changes to get the rest of these fields - let worker_heartbeat_callback: HeartbeatFn = Box::new(move || { - WorkerHeartbeat { - worker_instance_key: worker_instance_key_clone.clone(), + let mut worker_heartbeat = WorkerHeartbeat { + worker_instance_key: worker_instance_key.to_string(), host_info: Some(WorkerHostInfo { host_name: gethostname().to_string_lossy().to_string(), process_id: std::process::id().to_string(), - ..Default::default() + current_host_cpu_usage: heartbeat_manager_metrics.sys_info.used_cpu_percent() + as f32, + current_host_mem_usage: heartbeat_manager_metrics.sys_info.used_mem_percent() + as f32, + + // Set by SharedNamespaceWorker because it relies on the client + process_key: String::new(), }), - task_queue: task_queue.clone(), - deployment_version: None, - - status: 0, - start_time: Some(std::time::SystemTime::now().into()), - workflow_task_slots_info: None, - activity_task_slots_info: None, - nexus_task_slots_info: None, - local_activity_slots_info: None, - workflow_poller_info: None, - workflow_sticky_poller_info: None, - activity_poller_info: None, - nexus_poller_info: None, - total_sticky_cache_hit: 0, - total_sticky_cache_miss: 0, - current_sticky_cache_size: 0, - plugins: vec![], - - // sdk_name, sdk_version, and worker_identity must be set by + task_queue: config.task_queue.clone(), + deployment_version, + + status: (*heartbeat_manager_metrics.status.lock()) as i32, + start_time, + plugins: config.plugins.clone(), + + // Some Metrics dependent fields are set below, and + // some fields like sdk_name, sdk_version, and worker_identity, must be set by // SharedNamespaceWorker because they rely on the client, and // need to be pulled from the current client used by SharedNamespaceWorker ..Default::default() + }; + + if let Some(in_mem) = heartbeat_manager_metrics.in_mem_metrics.as_ref() { + worker_heartbeat.total_sticky_cache_hit = + in_mem.total_sticky_cache_hit.load(Ordering::Relaxed) as i32; + worker_heartbeat.total_sticky_cache_miss = + in_mem.total_sticky_cache_miss.load(Ordering::Relaxed) as i32; + worker_heartbeat.current_sticky_cache_size = + in_mem.sticky_cache_size.load(Ordering::Relaxed) as i32; + + worker_heartbeat.workflow_poller_info = Some(WorkerPollerInfo { + current_pollers: in_mem + .num_pollers + .wft_current_pollers + .load(Ordering::Relaxed) as i32, + last_successful_poll_time: heartbeat_manager_metrics + .wf_last_suc_poll_time + .load() + .map(|time| time.into()), + is_autoscaling: config.workflow_task_poller_behavior.is_autoscaling(), + }); + worker_heartbeat.workflow_sticky_poller_info = Some(WorkerPollerInfo { + current_pollers: in_mem + .num_pollers + .sticky_wft_current_pollers + .load(Ordering::Relaxed) as i32, + last_successful_poll_time: heartbeat_manager_metrics + .wf_sticky_last_suc_poll_time + .load() + .map(|time| time.into()), + is_autoscaling: config.workflow_task_poller_behavior.is_autoscaling(), + }); + worker_heartbeat.activity_poller_info = Some(WorkerPollerInfo { + current_pollers: in_mem + .num_pollers + .activity_current_pollers + .load(Ordering::Relaxed) as i32, + last_successful_poll_time: heartbeat_manager_metrics + .act_last_suc_poll_time + .load() + .map(|time| time.into()), + is_autoscaling: config.activity_task_poller_behavior.is_autoscaling(), + }); + worker_heartbeat.nexus_poller_info = Some(WorkerPollerInfo { + current_pollers: in_mem + .num_pollers + .nexus_current_pollers + .load(Ordering::Relaxed) as i32, + last_successful_poll_time: heartbeat_manager_metrics + .nexus_last_suc_poll_time + .load() + .map(|time| time.into()), + is_autoscaling: config.nexus_task_poller_behavior.is_autoscaling(), + }); + + worker_heartbeat.workflow_task_slots_info = make_slots_info( + &heartbeat_manager_metrics.wft_slots, + in_mem.worker_task_slots_available.workflow_worker.clone(), + in_mem.worker_task_slots_used.workflow_worker.clone(), + in_mem.workflow_task_execution_latency.clone(), + in_mem.workflow_task_execution_failed.clone(), + ); + worker_heartbeat.activity_task_slots_info = make_slots_info( + &heartbeat_manager_metrics.act_slots, + in_mem.worker_task_slots_available.activity_worker.clone(), + in_mem.worker_task_slots_used.activity_worker.clone(), + in_mem.activity_execution_latency.clone(), + in_mem.activity_execution_failed.clone(), + ); + worker_heartbeat.nexus_task_slots_info = make_slots_info( + &heartbeat_manager_metrics.nexus_slots, + in_mem.worker_task_slots_available.nexus_worker.clone(), + in_mem.worker_task_slots_used.nexus_worker.clone(), + in_mem.nexus_task_execution_latency.clone(), + in_mem.nexus_task_execution_failed.clone(), + ); + worker_heartbeat.local_activity_slots_info = make_slots_info( + &heartbeat_manager_metrics.la_slots, + in_mem + .worker_task_slots_available + .local_activity_worker + .clone(), + in_mem.worker_task_slots_used.local_activity_worker.clone(), + in_mem.local_activity_execution_latency.clone(), + in_mem.local_activity_execution_failed.clone(), + ); } + worker_heartbeat }); WorkerHeartbeatManager { heartbeat_interval, telemetry: telemetry_instance, - heartbeat_callback: Mutex::new(Some(worker_heartbeat_callback)), + heartbeat_callback: worker_heartbeat_callback, } } } @@ -1105,6 +1236,31 @@ fn wft_poller_behavior(config: &WorkerConfig, is_sticky: bool) -> PollerBehavior } } +fn make_slots_info( + dealer: &MeteredPermitDealer, + slots_available: Arc, + slots_used: Arc, + total_processed: Arc, + total_failed: Arc, +) -> Option +where + SK: SlotKind + 'static, +{ + Some(WorkerSlotsInfo { + current_available_slots: i32::try_from(slots_available.load(Ordering::Relaxed)) + .unwrap_or(-1), + current_used_slots: i32::try_from(slots_used.load(Ordering::Relaxed)).unwrap_or(-1), + slot_supplier_kind: dealer.slot_supplier_kind().to_string(), + total_processed_tasks: i32::try_from(total_processed.load(Ordering::Relaxed)) + .unwrap_or(i32::MIN), + total_failed_tasks: i32::try_from(total_failed.load(Ordering::Relaxed)).unwrap_or(i32::MIN), + + // Filled in by heartbeat later + last_interval_processed_tasks: 0, + last_interval_failure_tasks: 0, + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/core/src/worker/tuner.rs b/core/src/worker/tuner.rs index ed592a6f7..0a7cadcc9 100644 --- a/core/src/worker/tuner.rs +++ b/core/src/worker/tuner.rs @@ -3,10 +3,12 @@ mod resource_based; pub use fixed_size::FixedSizeSlotSupplier; pub use resource_based::{ - RealSysInfo, ResourceBasedSlotsOptions, ResourceBasedSlotsOptionsBuilder, ResourceBasedTuner, + ResourceBasedSlotsOptions, ResourceBasedSlotsOptionsBuilder, ResourceBasedTuner, ResourceSlotOptions, }; +pub(crate) use resource_based::{RealSysInfo, SystemResourceInfo}; + use std::sync::Arc; use temporal_sdk_core_api::worker::{ ActivitySlotKind, LocalActivitySlotKind, NexusSlotKind, SlotKind, SlotSupplier, WorkerConfig, @@ -126,6 +128,9 @@ impl TunerHolderOptions { } None => {} } + if let Some(tuner) = rb_tuner { + builder.sys_info(tuner.sys_info()); + } Ok(builder.build()) } } @@ -187,6 +192,7 @@ pub struct TunerBuilder { local_activity_slot_supplier: Option + Send + Sync>>, nexus_slot_supplier: Option + Send + Sync>>, + sys_info: Option>, } impl TunerBuilder { @@ -243,6 +249,17 @@ impl TunerBuilder { self } + /// Sets a field that implements [SystemResourceInfo] + pub fn sys_info(&mut self, sys_info: Arc) -> &mut Self { + self.sys_info = Some(sys_info); + self + } + + /// Gets the field that implements [SystemResourceInfo] + pub fn get_sys_info(&self) -> Option> { + self.sys_info.clone() + } + /// Build a [WorkerTuner] from the configured slot suppliers pub fn build(&mut self) -> TunerHolder { TunerHolder { diff --git a/core/src/worker/tuner/fixed_size.rs b/core/src/worker/tuner/fixed_size.rs index aa737dc8b..e1bf53d6e 100644 --- a/core/src/worker/tuner/fixed_size.rs +++ b/core/src/worker/tuner/fixed_size.rs @@ -50,4 +50,8 @@ where fn available_slots(&self) -> Option { Some(self.sem.available_permits()) } + + fn slot_supplier_kind(&self) -> String { + "Fixed".to_string() + } } diff --git a/core/src/worker/tuner/resource_based.rs b/core/src/worker/tuner/resource_based.rs index 173418413..88606add3 100644 --- a/core/src/worker/tuner/resource_based.rs +++ b/core/src/worker/tuner/resource_based.rs @@ -1,11 +1,13 @@ use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; +use std::sync::mpsc; use std::{ marker::PhantomData, sync::{ Arc, OnceLock, atomic::{AtomicU64, AtomicUsize, Ordering}, }, + thread, time::{Duration, Instant}, }; use temporal_sdk_core_api::{ @@ -31,6 +33,8 @@ pub struct ResourceBasedTuner { act_opts: Option, la_opts: Option, nexus_opts: Option, + + sys_info: Arc, } impl ResourceBasedTuner { @@ -42,25 +46,28 @@ impl ResourceBasedTuner { .target_cpu_usage(target_cpu_usage) .build() .expect("default resource based slot options can't fail to build"); - let controller = ResourceController::new_with_sysinfo(opts, RealSysInfo::new()); + let controller = ResourceController::new_with_sysinfo(opts, Arc::new(RealSysInfo::new())); Self::new_from_controller(controller) } /// Create an instance using the fully configurable set of PID controller options pub fn new_from_options(options: ResourceBasedSlotsOptions) -> Self { - let controller = ResourceController::new_with_sysinfo(options, RealSysInfo::new()); + let controller = + ResourceController::new_with_sysinfo(options, Arc::new(RealSysInfo::new())); Self::new_from_controller(controller) } } impl ResourceBasedTuner { fn new_from_controller(controller: ResourceController) -> Self { + let sys_info = controller.sys_info_supplier.clone(); Self { slots: Arc::new(controller), wf_opts: None, act_opts: None, la_opts: None, nexus_opts: None, + sys_info, } } @@ -87,6 +94,11 @@ impl ResourceBasedTuner { self.nexus_opts = Some(opts); self } + + /// Get sys info + pub fn sys_info(&self) -> Arc { + self.sys_info.clone() + } } const DEFAULT_WF_SLOT_OPTS: ResourceSlotOptions = ResourceSlotOptions { @@ -121,7 +133,7 @@ pub struct ResourceSlotOptions { struct ResourceController { options: ResourceBasedSlotsOptions, - sys_info_supplier: MI, + sys_info_supplier: Arc, metrics: OnceLock>, pids: Mutex, last_metric_vals: Arc>, @@ -314,6 +326,10 @@ where } } } + + fn slot_supplier_kind(&self) -> String { + "ResourceBased".to_string() + } } impl ResourceBasedSlotsForType @@ -421,7 +437,7 @@ impl ResourceController { Arc::new(ResourceBasedSlotsForType::new(self.clone(), opts)) } - fn new_with_sysinfo(options: ResourceBasedSlotsOptions, sys_info: MI) -> Self { + fn new_with_sysinfo(options: ResourceBasedSlotsOptions, sys_info: Arc) -> Self { Self { pids: Mutex::new(PidControllers::new(&options)), options, @@ -474,37 +490,14 @@ impl ResourceController { /// Implements [SystemResourceInfo] using the [sysinfo] crate #[derive(Debug)] -pub struct RealSysInfo { +struct RealSysInfoInner { sys: Mutex, total_mem: AtomicU64, cur_mem_usage: AtomicU64, cur_cpu_usage: AtomicU64, - last_refresh: AtomicCell, } -impl RealSysInfo { - fn new() -> Self { - let mut sys = sysinfo::System::new(); - sys.refresh_memory(); - let total_mem = sys.total_memory(); - let s = Self { - sys: Mutex::new(sys), - last_refresh: AtomicCell::new(Instant::now()), - cur_mem_usage: AtomicU64::new(0), - cur_cpu_usage: AtomicU64::new(0), - total_mem: AtomicU64::new(total_mem), - }; - s.refresh(); - s - } - - fn refresh_if_needed(&self) { - // This is all quite expensive and meaningfully slows everything down if it's allowed to - // happen more often. A better approach than a lock would be needed to go faster. - if (Instant::now() - self.last_refresh.load()) > Duration::from_millis(100) { - self.refresh(); - } - } +impl RealSysInfoInner { fn refresh(&self) { let mut lock = self.sys.lock(); lock.refresh_memory(); @@ -522,25 +515,73 @@ impl RealSysInfo { self.cur_mem_usage.store(mem, Ordering::Release); } self.cur_cpu_usage.store(cpu.to_bits(), Ordering::Release); - self.last_refresh.store(Instant::now()); + } +} + +/// Tracks host resource usage by refreshing metrics on a background thread. +pub struct RealSysInfo { + inner: Arc, + shutdown_tx: mpsc::Sender<()>, + shutdown_handle: Mutex>>, +} + +impl RealSysInfo { + pub(crate) fn new() -> Self { + let mut sys = sysinfo::System::new(); + sys.refresh_memory(); + let total_mem = sys.total_memory(); + let inner = Arc::new(RealSysInfoInner { + sys: Mutex::new(sys), + cur_mem_usage: AtomicU64::new(0), + cur_cpu_usage: AtomicU64::new(0), + total_mem: AtomicU64::new(total_mem), + }); + inner.refresh(); + + let thread_clone = inner.clone(); + let (tx, rx) = mpsc::channel::<()>(); + let handle = thread::Builder::new() + .name("temporal-real-sysinfo".to_string()) + .spawn(move || { + const REFRESH_INTERVAL: Duration = Duration::from_millis(100); + loop { + thread_clone.refresh(); + let r = rx.recv_timeout(REFRESH_INTERVAL); + if matches!(r, Err(mpsc::RecvTimeoutError::Disconnected)) || r.is_ok() { + return; + } + } + }) + .expect("failed to spawn RealSysInfo refresh thread"); + + Self { + inner, + shutdown_tx: tx, + shutdown_handle: Mutex::new(Some(handle)), + } } } impl SystemResourceInfo for RealSysInfo { fn total_mem(&self) -> u64 { - self.total_mem.load(Ordering::Acquire) + self.inner.total_mem.load(Ordering::Acquire) } fn used_mem(&self) -> u64 { - // TODO: This should really happen on a background thread since it's getting called from - // the async reserve - self.refresh_if_needed(); - self.cur_mem_usage.load(Ordering::Acquire) + self.inner.cur_mem_usage.load(Ordering::Acquire) } fn used_cpu_percent(&self) -> f64 { - self.refresh_if_needed(); - f64::from_bits(self.cur_cpu_usage.load(Ordering::Acquire)) + f64::from_bits(self.inner.cur_cpu_usage.load(Ordering::Acquire)) + } +} + +impl Drop for RealSysInfo { + fn drop(&mut self) { + let _res = self.shutdown_tx.send(()); + if let Some(handle) = self.shutdown_handle.lock().take() { + let _ = handle.join(); + } } } @@ -558,9 +599,9 @@ mod tests { used: Arc, } impl FakeMIS { - fn new() -> (Self, Arc) { + fn new() -> (Arc, Arc) { let used = Arc::new(AtomicU64::new(0)); - (Self { used: used.clone() }, used) + (Arc::new(Self { used: used.clone() }), used) } } impl SystemResourceInfo for FakeMIS { diff --git a/core/src/worker/workflow/wft_poller.rs b/core/src/worker/workflow/wft_poller.rs index 3cc2da579..0a00ad179 100644 --- a/core/src/worker/workflow/wft_poller.rs +++ b/core/src/worker/workflow/wft_poller.rs @@ -6,13 +6,16 @@ use crate::{ telemetry::metrics::{workflow_poller, workflow_sticky_poller}, worker::{client::WorkerClient, wft_poller_behavior}, }; +use crossbeam_utils::atomic::AtomicCell; use futures_util::{Stream, stream}; use std::sync::{Arc, OnceLock}; +use std::time::SystemTime; use temporal_sdk_core_api::worker::{WorkerConfig, WorkflowSlotKind}; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse; use tokio::sync::watch; use tokio_util::sync::CancellationToken; +#[allow(clippy::too_many_arguments)] pub(crate) fn make_wft_poller( config: &WorkerConfig, sticky_queue_name: &Option, @@ -20,6 +23,8 @@ pub(crate) fn make_wft_poller( metrics: &MetricsContext, shutdown_token: &CancellationToken, wft_slots: &MeteredPermitDealer, + last_successful_poll_time: Arc>>, + sticky_last_successful_poll_time: Arc>>, ) -> impl Stream< Item = Result< ( @@ -52,6 +57,7 @@ pub(crate) fn make_wft_poller( WorkflowTaskOptions { wft_poller_shared: wft_poller_shared.clone(), }, + last_successful_poll_time, ); let sticky_queue_poller = sticky_queue_name.as_ref().map(|sqn| { let sticky_metrics = metrics.with_new_attrs([workflow_sticky_poller()]); @@ -66,6 +72,7 @@ pub(crate) fn make_wft_poller( sticky_metrics.record_num_pollers(np); }), WorkflowTaskOptions { wft_poller_shared }, + sticky_last_successful_poll_time, ) }); let wf_task_poll_buffer = Box::new(WorkflowTaskPoller::new( diff --git a/sdk-core-protos/protos/api_upstream/openapi/openapiv2.json b/sdk-core-protos/protos/api_upstream/openapi/openapiv2.json index 8591cb0be..cfed16ffd 100644 --- a/sdk-core-protos/protos/api_upstream/openapi/openapiv2.json +++ b/sdk-core-protos/protos/api_upstream/openapi/openapiv2.json @@ -2130,6 +2130,51 @@ ] } }, + "/api/v1/namespaces/{namespace}/worker-deployments/{deploymentName}/set-manager": { + "post": { + "summary": "Set/unset the ManagerIdentity of a Worker Deployment.\nExperimental. This API might significantly change or be removed in a future release.", + "operationId": "SetWorkerDeploymentManager2", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1SetWorkerDeploymentManagerResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "namespace", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "deploymentName", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/WorkflowServiceSetWorkerDeploymentManagerBody" + } + } + ], + "tags": [ + "WorkflowService" + ] + } + }, "/api/v1/namespaces/{namespace}/worker-deployments/{deploymentName}/set-ramping-version": { "post": { "summary": "Set/unset the Ramping Version of a Worker Deployment and its ramp percentage. Can be used for\ngradual ramp to unversioned workers too.\nExperimental. This API might significantly change or be removed in a future release.", @@ -2296,6 +2341,45 @@ ] } }, + "/api/v1/namespaces/{namespace}/workers/describe/{workerInstanceKey}": { + "get": { + "summary": "DescribeWorker returns information about the specified worker.", + "operationId": "DescribeWorker2", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1DescribeWorkerResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "namespace", + "description": "Namespace this worker belongs to.", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "workerInstanceKey", + "description": "Worker instance key to describe.", + "in": "path", + "required": true, + "type": "string" + } + ], + "tags": [ + "WorkflowService" + ] + } + }, "/api/v1/namespaces/{namespace}/workers/fetch-config": { "post": { "summary": "FetchWorkerConfig returns the worker configuration for a specific worker.", @@ -2822,7 +2906,7 @@ }, "/api/v1/namespaces/{namespace}/workflows/{execution.workflowId}/history-reverse": { "get": { - "summary": "GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse \norder (starting from last event). Fails with`NotFound` if the specified workflow execution is \nunknown to the service.", + "summary": "GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse\norder (starting from last event). Fails with`NotFound` if the specified workflow execution is\nunknown to the service.", "operationId": "GetWorkflowExecutionHistoryReverse2", "responses": { "200": { @@ -5871,6 +5955,51 @@ ] } }, + "/namespaces/{namespace}/worker-deployments/{deploymentName}/set-manager": { + "post": { + "summary": "Set/unset the ManagerIdentity of a Worker Deployment.\nExperimental. This API might significantly change or be removed in a future release.", + "operationId": "SetWorkerDeploymentManager", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1SetWorkerDeploymentManagerResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "namespace", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "deploymentName", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/WorkflowServiceSetWorkerDeploymentManagerBody" + } + } + ], + "tags": [ + "WorkflowService" + ] + } + }, "/namespaces/{namespace}/worker-deployments/{deploymentName}/set-ramping-version": { "post": { "summary": "Set/unset the Ramping Version of a Worker Deployment and its ramp percentage. Can be used for\ngradual ramp to unversioned workers too.\nExperimental. This API might significantly change or be removed in a future release.", @@ -6037,6 +6166,45 @@ ] } }, + "/namespaces/{namespace}/workers/describe/{workerInstanceKey}": { + "get": { + "summary": "DescribeWorker returns information about the specified worker.", + "operationId": "DescribeWorker", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1DescribeWorkerResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "namespace", + "description": "Namespace this worker belongs to.", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "workerInstanceKey", + "description": "Worker instance key to describe.", + "in": "path", + "required": true, + "type": "string" + } + ], + "tags": [ + "WorkflowService" + ] + } + }, "/namespaces/{namespace}/workers/fetch-config": { "post": { "summary": "FetchWorkerConfig returns the worker configuration for a specific worker.", @@ -6563,7 +6731,7 @@ }, "/namespaces/{namespace}/workflows/{execution.workflowId}/history-reverse": { "get": { - "summary": "GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse \norder (starting from last event). Fails with`NotFound` if the specified workflow execution is \nunknown to the service.", + "summary": "GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse\norder (starting from last event). Fails with`NotFound` if the specified workflow execution is\nunknown to the service.", "operationId": "GetWorkflowExecutionHistoryReverse", "responses": { "200": { @@ -7861,7 +8029,7 @@ }, "type": { "type": "string", - "description": "Pause all running activities of this type." + "description": "Pause all running activities of this type.\nNote: Experimental - the behavior of pause by activity type might change in a future release." }, "reason": { "type": "string", @@ -8314,10 +8482,37 @@ "ignoreMissingTaskQueues": { "type": "boolean", "description": "Optional. By default this request would be rejected if not all the expected Task Queues are\nbeing polled by the new Version, to protect against accidental removal of Task Queues, or\nworker health issues. Pass `true` here to bypass this protection.\nThe set of expected Task Queues is the set of all the Task Queues that were ever poller by\nthe existing Current Version of the Deployment, with the following exclusions:\n - Task Queues that are not used anymore (inferred by having empty backlog and a task\n add_rate of 0.)\n - Task Queues that are moved to another Worker Deployment (inferred by the Task Queue\n having a different Current Version than the Current Version of this deployment.)\nWARNING: Do not set this flag unless you are sure that the missing task queue pollers are not\nneeded. If the request is unexpectedly rejected due to missing pollers, then that means the\npollers have not reached to the server yet. Only set this if you expect those pollers to\nnever arrive." + }, + "allowNoPollers": { + "type": "boolean", + "description": "Optional. By default this request will be rejected if no pollers have been seen for the proposed\nCurrent Version, in order to protect users from routing tasks to pollers that do not exist, leading\nto possible timeouts. Pass `true` here to bypass this protection." } }, "description": "Set/unset the Current Version of a Worker Deployment." }, + "WorkflowServiceSetWorkerDeploymentManagerBody": { + "type": "object", + "properties": { + "managerIdentity": { + "type": "string", + "description": "Arbitrary value for `manager_identity`.\nEmpty will unset the field." + }, + "self": { + "type": "boolean", + "description": "True will set `manager_identity` to `identity`." + }, + "conflictToken": { + "type": "string", + "format": "byte", + "description": "Optional. This can be the value of conflict_token from a Describe, or another Worker\nDeployment API. Passing a non-nil conflict token will cause this request to fail if the\nDeployment's configuration has been modified between the API call that generated the\ntoken and this one." + }, + "identity": { + "type": "string", + "description": "Required. The identity of the client who initiated this request." + } + }, + "description": "Update the ManagerIdentity of a Worker Deployment." + }, "WorkflowServiceSetWorkerDeploymentRampingVersionBody": { "type": "object", "properties": { @@ -8346,6 +8541,10 @@ "ignoreMissingTaskQueues": { "type": "boolean", "description": "Optional. By default this request would be rejected if not all the expected Task Queues are\nbeing polled by the new Version, to protect against accidental removal of Task Queues, or\nworker health issues. Pass `true` here to bypass this protection.\nThe set of expected Task Queues equals to all the Task Queues ever polled from the existing\nCurrent Version of the Deployment, with the following exclusions:\n - Task Queues that are not used anymore (inferred by having empty backlog and a task\n add_rate of 0.)\n - Task Queues that are moved to another Worker Deployment (inferred by the Task Queue\n having a different Current Version than the Current Version of this deployment.)\nWARNING: Do not set this flag unless you are sure that the missing task queue poller are not\nneeded. If the request is unexpectedly rejected due to missing pollers, then that means the\npollers have not reached to the server yet. Only set this if you expect those pollers to\nnever arrive.\nNote: this check only happens when the ramping version is about to change, not every time\nthat the percentage changes. Also note that the check is against the deployment's Current\nVersion, not the previous Ramping Version." + }, + "allowNoPollers": { + "type": "boolean", + "description": "Optional. By default this request will be rejected if no pollers have been seen for the proposed\nCurrent Version, in order to protect users from routing tasks to pollers that do not exist, leading\nto possible timeouts. Pass `true` here to bypass this protection." } }, "description": "Set/unset the Ramping Version of a Worker Deployment and its ramp percentage." @@ -8643,6 +8842,10 @@ "priority": { "$ref": "#/definitions/v1Priority", "title": "Priority metadata" + }, + "eagerWorkerDeploymentOptions": { + "$ref": "#/definitions/v1WorkerDeploymentOptions", + "description": "Deployment Options of the worker who will process the eager task. Passed when `request_eager_execution=true`." } } }, @@ -10875,6 +11078,14 @@ } } }, + "v1DescribeWorkerResponse": { + "type": "object", + "properties": { + "workerInfo": { + "$ref": "#/definitions/v1WorkerInfo" + } + } + }, "v1DescribeWorkflowExecutionResponse": { "type": "object", "properties": { @@ -11198,6 +11409,14 @@ }, "visibilityStore": { "type": "string" + }, + "initialFailoverVersion": { + "type": "string", + "format": "int64" + }, + "failoverVersionIncrement": { + "type": "string", + "format": "int64" } }, "description": "GetClusterInfoResponse contains information about Temporal cluster." @@ -12167,6 +12386,14 @@ "asyncUpdate": { "type": "boolean", "title": "True if the namespace supports async update" + }, + "workerHeartbeats": { + "type": "boolean", + "title": "True if the namespace supports worker heartbeats" + }, + "reportedProblemsSearchAttribute": { + "type": "boolean", + "title": "True if the namespace supports reported problems search attribute" } }, "description": "Namespace capability details. Should contain what features are enabled in a namespace." @@ -13154,7 +13381,7 @@ "priorityKey": { "type": "integer", "format": "int32", - "description": "Priority key is a positive integer from 1 to n, where smaller integers\ncorrespond to higher priorities (tasks run sooner). In general, tasks in\na queue should be processed in close to priority order, although small\ndeviations are possible.\n\nThe maximum priority value (minimum priority) is determined by server\nconfiguration, and defaults to 5.\n\nIf priority is not present (or zero), then the effective priority will be\nthe default priority, which is is calculated by (min+max)/2. With the\ndefault max of 5, and min of 1, that comes out to 3." + "description": "Priority key is a positive integer from 1 to n, where smaller integers\ncorrespond to higher priorities (tasks run sooner). In general, tasks in\na queue should be processed in close to priority order, although small\ndeviations are possible.\n\nThe maximum priority value (minimum priority) is determined by server\nconfiguration, and defaults to 5.\n\nIf priority is not present (or zero), then the effective priority will be\nthe default priority, which is calculated by (min+max)/2. With the\ndefault max of 5, and min of 1, that comes out to 3." }, "fairnessKey": { "type": "string", @@ -14343,6 +14570,20 @@ } } }, + "v1SetWorkerDeploymentManagerResponse": { + "type": "object", + "properties": { + "conflictToken": { + "type": "string", + "format": "byte", + "description": "This value is returned so that it can be optionally passed to APIs\nthat write to the Worker Deployment state to ensure that the state\ndid not change between this API call and a future write." + }, + "previousManagerIdentity": { + "type": "string", + "description": "What the `manager_identity` field was before this change." + } + } + }, "v1SetWorkerDeploymentRampingVersionResponse": { "type": "object", "properties": { @@ -14891,6 +15132,10 @@ "priority": { "$ref": "#/definitions/v1Priority", "title": "Priority metadata" + }, + "eagerWorkerDeploymentOptions": { + "$ref": "#/definitions/v1WorkerDeploymentOptions", + "description": "Deployment Options of the worker who will process the eager task. Passed when `request_eager_execution=true`." } } }, @@ -15453,7 +15698,7 @@ "additionalProperties": { "type": "string" }, - "description": "A key-value map for any customized purpose.\nIf data already exists on the namespace, \nthis will merge with the existing key values." + "description": "A key-value map for any customized purpose.\nIf data already exists on the namespace,\nthis will merge with the existing key values." }, "state": { "$ref": "#/definitions/v1NamespaceState", @@ -15812,6 +16057,10 @@ "lastModifierIdentity": { "type": "string", "description": "Identity of the last client who modified the configuration of this Deployment. Set to the\n`identity` value sent by APIs such as `SetWorkerDeploymentCurrentVersion` and\n`SetWorkerDeploymentRampingVersion`." + }, + "managerIdentity": { + "type": "string", + "description": "Identity of the client that has the exclusive right to make changes to this Worker Deployment.\nEmpty by default.\nIf this is set, clients whose identity does not match `manager_identity` will not be able to make changes\nto this Worker Deployment. They can either set their own identity as the manager or unset the field to proceed." } }, "description": "A Worker Deployment (Deployment, for short) represents all workers serving \na shared set of Task Queues. Typically, a Deployment represents one service or \napplication.\nA Deployment contains multiple Deployment Versions, each representing a different \nversion of workers. (see documentation of WorkerDeploymentVersionInfo)\nDeployment records are created in Temporal server automatically when their\nfirst poller arrives to the server.\nExperimental. Worker Deployments are experimental and might significantly change in the future." diff --git a/sdk-core-protos/protos/api_upstream/openapi/openapiv3.yaml b/sdk-core-protos/protos/api_upstream/openapi/openapiv3.yaml index 7d587b366..88f8737d0 100644 --- a/sdk-core-protos/protos/api_upstream/openapi/openapiv3.yaml +++ b/sdk-core-protos/protos/api_upstream/openapi/openapiv3.yaml @@ -1914,6 +1914,44 @@ paths: application/json: schema: $ref: '#/components/schemas/Status' + /api/v1/namespaces/{namespace}/worker-deployments/{deploymentName}/set-manager: + post: + tags: + - WorkflowService + description: |- + Set/unset the ManagerIdentity of a Worker Deployment. + Experimental. This API might significantly change or be removed in a future release. + operationId: SetWorkerDeploymentManager + parameters: + - name: namespace + in: path + required: true + schema: + type: string + - name: deploymentName + in: path + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/SetWorkerDeploymentManagerRequest' + required: true + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/SetWorkerDeploymentManagerResponse' + default: + description: Default error response + content: + application/json: + schema: + $ref: '#/components/schemas/Status' /api/v1/namespaces/{namespace}/worker-deployments/{deploymentName}/set-ramping-version: post: tags: @@ -2087,6 +2125,38 @@ paths: application/json: schema: $ref: '#/components/schemas/Status' + /api/v1/namespaces/{namespace}/workers/describe/{workerInstanceKey}: + get: + tags: + - WorkflowService + description: DescribeWorker returns information about the specified worker. + operationId: DescribeWorker + parameters: + - name: namespace + in: path + description: Namespace this worker belongs to. + required: true + schema: + type: string + - name: workerInstanceKey + in: path + description: Worker instance key to describe. + required: true + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/DescribeWorkerResponse' + default: + description: Default error response + content: + application/json: + schema: + $ref: '#/components/schemas/Status' /api/v1/namespaces/{namespace}/workers/fetch-config: post: tags: @@ -2540,7 +2610,10 @@ paths: get: tags: - WorkflowService - description: "GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse \n order (starting from last event). Fails with`NotFound` if the specified workflow execution is \n unknown to the service." + description: |- + GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse + order (starting from last event). Fails with`NotFound` if the specified workflow execution is + unknown to the service. operationId: GetWorkflowExecutionHistoryReverse parameters: - name: namespace @@ -5265,6 +5338,44 @@ paths: application/json: schema: $ref: '#/components/schemas/Status' + /namespaces/{namespace}/worker-deployments/{deploymentName}/set-manager: + post: + tags: + - WorkflowService + description: |- + Set/unset the ManagerIdentity of a Worker Deployment. + Experimental. This API might significantly change or be removed in a future release. + operationId: SetWorkerDeploymentManager + parameters: + - name: namespace + in: path + required: true + schema: + type: string + - name: deploymentName + in: path + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/SetWorkerDeploymentManagerRequest' + required: true + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/SetWorkerDeploymentManagerResponse' + default: + description: Default error response + content: + application/json: + schema: + $ref: '#/components/schemas/Status' /namespaces/{namespace}/worker-deployments/{deploymentName}/set-ramping-version: post: tags: @@ -5438,6 +5549,38 @@ paths: application/json: schema: $ref: '#/components/schemas/Status' + /namespaces/{namespace}/workers/describe/{workerInstanceKey}: + get: + tags: + - WorkflowService + description: DescribeWorker returns information about the specified worker. + operationId: DescribeWorker + parameters: + - name: namespace + in: path + description: Namespace this worker belongs to. + required: true + schema: + type: string + - name: workerInstanceKey + in: path + description: Worker instance key to describe. + required: true + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/DescribeWorkerResponse' + default: + description: Default error response + content: + application/json: + schema: + $ref: '#/components/schemas/Status' /namespaces/{namespace}/workers/fetch-config: post: tags: @@ -5891,7 +6034,10 @@ paths: get: tags: - WorkflowService - description: "GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse \n order (starting from last event). Fails with`NotFound` if the specified workflow execution is \n unknown to the service." + description: |- + GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse + order (starting from last event). Fails with`NotFound` if the specified workflow execution is + unknown to the service. operationId: GetWorkflowExecutionHistoryReverse parameters: - name: namespace @@ -7921,6 +8067,11 @@ components: Only set if `report_task_queue_stats` is set to true in the request. (-- api-linter: core::0140::prepositions=disabled aip.dev/not-precedent: "by" is used to clarify the key. --) + DescribeWorkerResponse: + type: object + properties: + workerInfo: + $ref: '#/components/schemas/WorkerInfo' DescribeWorkflowExecutionResponse: type: object properties: @@ -8242,6 +8393,10 @@ components: type: string visibilityStore: type: string + initialFailoverVersion: + type: string + failoverVersionIncrement: + type: string description: GetClusterInfoResponse contains information about Temporal cluster. GetCurrentDeploymentResponse: type: object @@ -9055,6 +9210,12 @@ components: asyncUpdate: type: boolean description: True if the namespace supports async update + workerHeartbeats: + type: boolean + description: True if the namespace supports worker heartbeats + reportedProblemsSearchAttribute: + type: boolean + description: True if the namespace supports reported problems search attribute description: Namespace capability details. Should contain what features are enabled in a namespace. NamespaceReplicationConfig: type: object @@ -9461,7 +9622,9 @@ components: description: Only the activity with this ID will be paused. type: type: string - description: Pause all running activities of this type. + description: |- + Pause all running activities of this type. + Note: Experimental - the behavior of pause by activity type might change in a future release. reason: type: string description: Reason to pause the activity. @@ -9939,7 +10102,7 @@ components: configuration, and defaults to 5. If priority is not present (or zero), then the effective priority will be - the default priority, which is is calculated by (min+max)/2. With the + the default priority, which is calculated by (min+max)/2. With the default max of 5, and min of 1, that comes out to 3. format: int32 fairnessKey: @@ -11347,6 +11510,12 @@ components: needed. If the request is unexpectedly rejected due to missing pollers, then that means the pollers have not reached to the server yet. Only set this if you expect those pollers to never arrive. + allowNoPollers: + type: boolean + description: |- + Optional. By default this request will be rejected if no pollers have been seen for the proposed + Current Version, in order to protect users from routing tasks to pollers that do not exist, leading + to possible timeouts. Pass `true` here to bypass this protection. description: Set/unset the Current Version of a Worker Deployment. SetWorkerDeploymentCurrentVersionResponse: type: object @@ -11365,6 +11534,46 @@ components: allOf: - $ref: '#/components/schemas/WorkerDeploymentVersion' description: The version that was current before executing this operation. + SetWorkerDeploymentManagerRequest: + type: object + properties: + namespace: + type: string + deploymentName: + type: string + managerIdentity: + type: string + description: |- + Arbitrary value for `manager_identity`. + Empty will unset the field. + self: + type: boolean + description: True will set `manager_identity` to `identity`. + conflictToken: + type: string + description: |- + Optional. This can be the value of conflict_token from a Describe, or another Worker + Deployment API. Passing a non-nil conflict token will cause this request to fail if the + Deployment's configuration has been modified between the API call that generated the + token and this one. + format: bytes + identity: + type: string + description: Required. The identity of the client who initiated this request. + description: Update the ManagerIdentity of a Worker Deployment. + SetWorkerDeploymentManagerResponse: + type: object + properties: + conflictToken: + type: string + description: |- + This value is returned so that it can be optionally passed to APIs + that write to the Worker Deployment state to ensure that the state + did not change between this API call and a future write. + format: bytes + previousManagerIdentity: + type: string + description: What the `manager_identity` field was before this change. SetWorkerDeploymentRampingVersionRequest: type: object properties: @@ -11415,6 +11624,12 @@ components: Note: this check only happens when the ramping version is about to change, not every time that the percentage changes. Also note that the check is against the deployment's Current Version, not the previous Ramping Version. + allowNoPollers: + type: boolean + description: |- + Optional. By default this request will be rejected if no pollers have been seen for the proposed + Current Version, in order to protect users from routing tasks to pollers that do not exist, leading + to possible timeouts. Pass `true` here to bypass this protection. description: Set/unset the Ramping Version of a Worker Deployment and its ramp percentage. SetWorkerDeploymentRampingVersionResponse: type: object @@ -11960,6 +12175,10 @@ components: allOf: - $ref: '#/components/schemas/Priority' description: Priority metadata + eagerWorkerDeploymentOptions: + allOf: + - $ref: '#/components/schemas/WorkerDeploymentOptions' + description: Deployment Options of the worker who will process the eager task. Passed when `request_eager_execution=true`. StartWorkflowExecutionResponse: type: object properties: @@ -12551,7 +12770,10 @@ components: type: object additionalProperties: type: string - description: "A key-value map for any customized purpose.\n If data already exists on the namespace, \n this will merge with the existing key values." + description: |- + A key-value map for any customized purpose. + If data already exists on the namespace, + this will merge with the existing key values. state: enum: - NAMESPACE_STATE_UNSPECIFIED @@ -13093,6 +13315,13 @@ components: Identity of the last client who modified the configuration of this Deployment. Set to the `identity` value sent by APIs such as `SetWorkerDeploymentCurrentVersion` and `SetWorkerDeploymentRampingVersion`. + managerIdentity: + type: string + description: |- + Identity of the client that has the exclusive right to make changes to this Worker Deployment. + Empty by default. + If this is set, clients whose identity does not match `manager_identity` will not be able to make changes + to this Worker Deployment. They can either set their own identity as the manager or unset the field to proceed. description: "A Worker Deployment (Deployment, for short) represents all workers serving \n a shared set of Task Queues. Typically, a Deployment represents one service or \n application.\n A Deployment contains multiple Deployment Versions, each representing a different \n version of workers. (see documentation of WorkerDeploymentVersionInfo)\n Deployment records are created in Temporal server automatically when their\n first poller arrives to the server.\n Experimental. Worker Deployments are experimental and might significantly change in the future." WorkerDeploymentInfo_WorkerDeploymentVersionSummary: type: object diff --git a/sdk-core-protos/protos/api_upstream/temporal/api/common/v1/message.proto b/sdk-core-protos/protos/api_upstream/temporal/api/common/v1/message.proto index 51acfaa2e..838f5fefc 100644 --- a/sdk-core-protos/protos/api_upstream/temporal/api/common/v1/message.proto +++ b/sdk-core-protos/protos/api_upstream/temporal/api/common/v1/message.proto @@ -280,7 +280,7 @@ message Priority { // configuration, and defaults to 5. // // If priority is not present (or zero), then the effective priority will be - // the default priority, which is is calculated by (min+max)/2. With the + // the default priority, which is calculated by (min+max)/2. With the // default max of 5, and min of 1, that comes out to 3. int32 priority_key = 1; diff --git a/sdk-core-protos/protos/api_upstream/temporal/api/deployment/v1/message.proto b/sdk-core-protos/protos/api_upstream/temporal/api/deployment/v1/message.proto index 14b4205c5..8f6685a5d 100644 --- a/sdk-core-protos/protos/api_upstream/temporal/api/deployment/v1/message.proto +++ b/sdk-core-protos/protos/api_upstream/temporal/api/deployment/v1/message.proto @@ -195,6 +195,12 @@ message WorkerDeploymentInfo { // `SetWorkerDeploymentRampingVersion`. string last_modifier_identity = 5; + // Identity of the client that has the exclusive right to make changes to this Worker Deployment. + // Empty by default. + // If this is set, clients whose identity does not match `manager_identity` will not be able to make changes + // to this Worker Deployment. They can either set their own identity as the manager or unset the field to proceed. + string manager_identity = 6; + message WorkerDeploymentVersionSummary { // Deprecated. Use `deployment_version`. string version = 1 [deprecated = true]; diff --git a/sdk-core-protos/protos/api_upstream/temporal/api/namespace/v1/message.proto b/sdk-core-protos/protos/api_upstream/temporal/api/namespace/v1/message.proto index 405cd53c9..79c44cb05 100644 --- a/sdk-core-protos/protos/api_upstream/temporal/api/namespace/v1/message.proto +++ b/sdk-core-protos/protos/api_upstream/temporal/api/namespace/v1/message.proto @@ -34,6 +34,10 @@ message NamespaceInfo { bool sync_update = 2; // True if the namespace supports async update bool async_update = 3; + // True if the namespace supports worker heartbeats + bool worker_heartbeats = 4; + // True if the namespace supports reported problems search attribute + bool reported_problems_search_attribute = 5; } // Whether scheduled workflows are supported on this namespace. This is only needed @@ -68,8 +72,8 @@ message UpdateNamespaceInfo { string description = 1; string owner_email = 2; // A key-value map for any customized purpose. - // If data already exists on the namespace, - // this will merge with the existing key values. + // If data already exists on the namespace, + // this will merge with the existing key values. map data = 3; // New namespace state, server will reject if transition is not allowed. // Allowed transitions are: diff --git a/sdk-core-protos/protos/api_upstream/temporal/api/workflowservice/v1/request_response.proto b/sdk-core-protos/protos/api_upstream/temporal/api/workflowservice/v1/request_response.proto index 5059575dc..37ad083c4 100644 --- a/sdk-core-protos/protos/api_upstream/temporal/api/workflowservice/v1/request_response.proto +++ b/sdk-core-protos/protos/api_upstream/temporal/api/workflowservice/v1/request_response.proto @@ -194,6 +194,8 @@ message StartWorkflowExecutionRequest { temporal.api.workflow.v1.OnConflictOptions on_conflict_options = 26; // Priority metadata temporal.api.common.v1.Priority priority = 27; + // Deployment Options of the worker who will process the eager task. Passed when `request_eager_execution=true`. + temporal.api.deployment.v1.WorkerDeploymentOptions eager_worker_deployment_options = 28; } message StartWorkflowExecutionResponse { @@ -1157,6 +1159,8 @@ message GetClusterInfoResponse { int32 history_shard_count = 6; string persistence_store = 7; string visibility_store = 8; + int64 initial_failover_version = 9; + int64 failover_version_increment = 10; } message GetSystemInfoRequest { @@ -1938,6 +1942,7 @@ message PauseActivityRequest { // Only the activity with this ID will be paused. string id = 4; // Pause all running activities of this type. + // Note: Experimental - the behavior of pause by activity type might change in a future release. string type = 5; } @@ -2163,6 +2168,10 @@ message SetWorkerDeploymentCurrentVersionRequest { // pollers have not reached to the server yet. Only set this if you expect those pollers to // never arrive. bool ignore_missing_task_queues = 6; + // Optional. By default this request will be rejected if no pollers have been seen for the proposed + // Current Version, in order to protect users from routing tasks to pollers that do not exist, leading + // to possible timeouts. Pass `true` here to bypass this protection. + bool allow_no_pollers = 9; } message SetWorkerDeploymentCurrentVersionResponse { @@ -2215,6 +2224,10 @@ message SetWorkerDeploymentRampingVersionRequest { // that the percentage changes. Also note that the check is against the deployment's Current // Version, not the previous Ramping Version. bool ignore_missing_task_queues = 7; + // Optional. By default this request will be rejected if no pollers have been seen for the proposed + // Current Version, in order to protect users from routing tasks to pollers that do not exist, leading + // to possible timeouts. Pass `true` here to bypass this protection. + bool allow_no_pollers = 10; } message SetWorkerDeploymentRampingVersionResponse { @@ -2248,8 +2261,8 @@ message ListWorkerDeploymentsResponse { google.protobuf.Timestamp create_time = 2; temporal.api.deployment.v1.RoutingConfig routing_config = 3; // Summary of the version that was added most recently in the Worker Deployment. - temporal.api.deployment.v1.WorkerDeploymentInfo.WorkerDeploymentVersionSummary latest_version_summary = 4; - // Summary of the current version of the Worker Deployment. + temporal.api.deployment.v1.WorkerDeploymentInfo.WorkerDeploymentVersionSummary latest_version_summary = 4; + // Summary of the current version of the Worker Deployment. temporal.api.deployment.v1.WorkerDeploymentInfo.WorkerDeploymentVersionSummary current_version_summary = 5; // Summary of the ramping version of the Worker Deployment. temporal.api.deployment.v1.WorkerDeploymentInfo.WorkerDeploymentVersionSummary ramping_version_summary = 6; @@ -2309,6 +2322,39 @@ message UpdateWorkerDeploymentVersionMetadataResponse { temporal.api.deployment.v1.VersionMetadata metadata = 1; } +// Update the ManagerIdentity of a Worker Deployment. +message SetWorkerDeploymentManagerRequest { + string namespace = 1; + string deployment_name = 2; + + oneof new_manager_identity { + // Arbitrary value for `manager_identity`. + // Empty will unset the field. + string manager_identity = 3; + + // True will set `manager_identity` to `identity`. + bool self = 4; + } + + // Optional. This can be the value of conflict_token from a Describe, or another Worker + // Deployment API. Passing a non-nil conflict token will cause this request to fail if the + // Deployment's configuration has been modified between the API call that generated the + // token and this one. + bytes conflict_token = 5; + + // Required. The identity of the client who initiated this request. + string identity = 6; +} + +message SetWorkerDeploymentManagerResponse { + // This value is returned so that it can be optionally passed to APIs + // that write to the Worker Deployment state to ensure that the state + // did not change between this API call and a future write. + bytes conflict_token = 1; + + // What the `manager_identity` field was before this change. + string previous_manager_identity = 2; +} // Returns the Current Deployment of a deployment series. // [cleanup-wv-pre-release] Pre-release deployment APIs, clean up later @@ -2537,3 +2583,15 @@ message UpdateWorkerConfigResponse { // Once we support sending update to a multiple workers - it will be converted into a batch job, and job id will be returned. } } + +message DescribeWorkerRequest { + // Namespace this worker belongs to. + string namespace = 1; + + // Worker instance key to describe. + string worker_instance_key = 2; +} + +message DescribeWorkerResponse { + temporal.api.worker.v1.WorkerInfo worker_info = 1; +} diff --git a/sdk-core-protos/protos/api_upstream/temporal/api/workflowservice/v1/service.proto b/sdk-core-protos/protos/api_upstream/temporal/api/workflowservice/v1/service.proto index cc74230af..dc33b84ef 100644 --- a/sdk-core-protos/protos/api_upstream/temporal/api/workflowservice/v1/service.proto +++ b/sdk-core-protos/protos/api_upstream/temporal/api/workflowservice/v1/service.proto @@ -133,9 +133,9 @@ service WorkflowService { } }; } - - // GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse - // order (starting from last event). Fails with`NotFound` if the specified workflow execution is + + // GetWorkflowExecutionHistoryReverse returns the history of specified workflow execution in reverse + // order (starting from last event). Fails with`NotFound` if the specified workflow execution is // unknown to the service. rpc GetWorkflowExecutionHistoryReverse (GetWorkflowExecutionHistoryReverseRequest) returns (GetWorkflowExecutionHistoryReverseResponse) { option (google.api.http) = { @@ -458,7 +458,8 @@ service WorkflowService { }; } - // ScanWorkflowExecutions is a visibility API to list large amount of workflow executions in a specific namespace without order. + // ScanWorkflowExecutions _was_ a visibility API to list large amount of workflow executions in a specific namespace without order. + // It has since been deprecated in favor of `ListWorkflowExecutions` and rewritten to use `ListWorkflowExecutions` internally. // // Deprecated: Replaced with `ListWorkflowExecutions`. // (-- api-linter: core::0127::http-annotation=disabled @@ -669,8 +670,8 @@ service WorkflowService { // members are compatible with one another. // // A single build id may be mapped to multiple task queues using this API for cases where a single process hosts - // multiple workers. - // + // multiple workers. + // // To query which workers can be retired, use the `GetWorkerTaskReachability` API. // // NOTE: The number of task queues mapped to a single build id is limited by the `limit.taskQueuesPerBuildId` @@ -923,6 +924,19 @@ service WorkflowService { }; } + // Set/unset the ManagerIdentity of a Worker Deployment. + // Experimental. This API might significantly change or be removed in a future release. + rpc SetWorkerDeploymentManager (SetWorkerDeploymentManagerRequest) returns (SetWorkerDeploymentManagerResponse) { + option (google.api.http) = { + post: "/namespaces/{namespace}/worker-deployments/{deployment_name}/set-manager" + body: "*" + additional_bindings { + post: "/api/v1/namespaces/{namespace}/worker-deployments/{deployment_name}/set-manager" + body: "*" + } + }; + } + // Invokes the specified Update function on user Workflow code. rpc UpdateWorkflowExecution(UpdateWorkflowExecutionRequest) returns (UpdateWorkflowExecutionResponse) { option (google.api.http) = { @@ -1235,4 +1249,14 @@ service WorkflowService { } }; } + + // DescribeWorker returns information about the specified worker. + rpc DescribeWorker (DescribeWorkerRequest) returns (DescribeWorkerResponse) { + option (google.api.http) = { + get: "/namespaces/{namespace}/workers/describe/{worker_instance_key}" + additional_bindings { + get: "/api/v1/namespaces/{namespace}/workers/describe/{worker_instance_key}" + } + }; + } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 35b2825c3..94785b9e6 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -38,6 +38,7 @@ use temporal_sdk::{ WorkerInterceptor, }, }; +pub(crate) use temporal_sdk_core::test_help::NAMESPACE; use temporal_sdk_core::{ ClientOptions, ClientOptionsBuilder, CoreRuntime, RuntimeOptions, RuntimeOptionsBuilder, WorkerConfigBuilder, init_replay_worker, init_worker, @@ -67,8 +68,7 @@ use temporal_sdk_core_protos::{ use tokio::{sync::OnceCell, task::AbortHandle}; use tracing::{debug, warn}; use url::Url; - -pub(crate) use temporal_sdk_core::test_help::NAMESPACE; +use uuid::Uuid; /// The env var used to specify where the integ tests should point pub(crate) const INTEG_SERVER_TARGET_ENV_VAR: &str = "TEMPORAL_SERVICE_ADDRESS"; pub(crate) const INTEG_NAMESPACE_ENV_VAR: &str = "TEMPORAL_NAMESPACE"; @@ -101,7 +101,8 @@ pub(crate) fn integ_worker_config(tq: &str) -> WorkerConfigBuilder { .max_outstanding_workflow_tasks(100_usize) .versioning_strategy(WorkerVersioningStrategy::None { build_id: "test_build_id".to_owned(), - }); + }) + .skip_client_worker_set_check(true); b } @@ -498,6 +499,10 @@ impl TestWorker { &mut self.inner } + pub(crate) fn worker_instance_key(&self) -> Uuid { + self.core_worker.worker_instance_key() + } + // TODO: Maybe trait-ify? pub(crate) fn register_wf>( &mut self, diff --git a/tests/integ_tests/metrics_tests.rs b/tests/integ_tests/metrics_tests.rs index dc7caa812..901e7c178 100644 --- a/tests/integ_tests/metrics_tests.rs +++ b/tests/integ_tests/metrics_tests.rs @@ -762,15 +762,8 @@ async fn docker_metrics_with_prometheus( assert!(!data.is_empty(), "No metrics found for query: {test_uid}"); assert_eq!(data[0]["metric"]["exported_job"], "temporal-core-sdk"); assert_eq!(data[0]["metric"]["job"], "otel-collector"); - // Worker heartbeating nexus worker assert!( data[0]["metric"]["task_queue"] - .as_str() - .unwrap() - .starts_with("temporal-sys/worker-commands/default/") - ); - assert!( - data[1]["metric"]["task_queue"] .as_str() .unwrap() .starts_with(test_name) diff --git a/tests/integ_tests/worker_heartbeat_tests.rs b/tests/integ_tests/worker_heartbeat_tests.rs new file mode 100644 index 000000000..1184cc505 --- /dev/null +++ b/tests/integ_tests/worker_heartbeat_tests.rs @@ -0,0 +1,1039 @@ +use crate::common::{ANY_PORT, CoreWfStarter, eventually, get_integ_telem_options}; +use anyhow::anyhow; +use crossbeam_utils::atomic::AtomicCell; +use futures_util::StreamExt; +use prost_types::Duration as PbDuration; +use prost_types::Timestamp; +use std::collections::HashSet; +use std::env; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use temporal_client::{ + Client, NamespacedClient, RetryClient, WfClientExt, WorkflowClientTrait, WorkflowService, +}; +use temporal_sdk::{ActContext, ActivityOptions, WfContext}; +use temporal_sdk_core::telemetry::{build_otlp_metric_exporter, start_prometheus_metric_exporter}; +use temporal_sdk_core::{ + CoreRuntime, ResourceBasedTuner, ResourceSlotOptions, RuntimeOptionsBuilder, +}; +use temporal_sdk_core_api::telemetry::{ + OtelCollectorOptionsBuilder, PrometheusExporterOptionsBuilder, TelemetryOptionsBuilder, +}; +use temporal_sdk_core_api::worker::PollerBehavior; +use temporal_sdk_core_protos::coresdk::{AsJsonPayloadExt, FromJsonPayloadExt}; +use temporal_sdk_core_protos::prost_dur; +use temporal_sdk_core_protos::temporal::api::common::v1::RetryPolicy; +use temporal_sdk_core_protos::temporal::api::enums::v1::WorkerStatus; +use temporal_sdk_core_protos::temporal::api::worker::v1::{PluginInfo, WorkerHeartbeat}; +use temporal_sdk_core_protos::temporal::api::workflowservice::v1::DescribeWorkerRequest; +use temporal_sdk_core_protos::temporal::api::workflowservice::v1::ListWorkersRequest; +use tokio::sync::Notify; +use tokio::time::sleep; +use url::Url; + +fn within_two_minutes_ts(ts: Timestamp) -> bool { + let ts_time = UNIX_EPOCH + Duration::new(ts.seconds as u64, ts.nanos as u32); + + let now = SystemTime::now(); + // ts should be at most 2 minutes before the current time + now.duration_since(ts_time).unwrap() <= Duration::from_secs(2 * 60) +} + +fn within_duration(dur: PbDuration, threshold: Duration) -> bool { + let std_dur = Duration::new(dur.seconds as u64, dur.nanos as u32); + std_dur <= threshold +} + +fn new_no_metrics_starter(wf_name: &str) -> CoreWfStarter { + let runtimeopts = RuntimeOptionsBuilder::default() + .telemetry_options(TelemetryOptionsBuilder::default().build().unwrap()) + .heartbeat_interval(Some(Duration::from_millis(100))) + .build() + .unwrap(); + CoreWfStarter::new_with_runtime(wf_name, CoreRuntime::new_assume_tokio(runtimeopts).unwrap()) +} + +fn to_system_time(ts: Timestamp) -> SystemTime { + UNIX_EPOCH + Duration::new(ts.seconds as u64, ts.nanos as u32) +} + +async fn list_worker_heartbeats( + client: &Arc>, + query: impl Into, +) -> Vec { + let mut raw_client = client.as_ref().clone(); + WorkflowService::list_workers( + &mut raw_client, + ListWorkersRequest { + namespace: client.namespace().to_owned(), + page_size: 200, + next_page_token: Vec::new(), + query: query.into(), + }, + ) + .await + .unwrap() + .into_inner() + .workers_info + .into_iter() + .filter_map(|info| info.worker_heartbeat) + .collect() +} + +// Tests that rely on Prometheus running in a docker container need to start +// with `docker_` and set the `DOCKER_PROMETHEUS_RUNNING` env variable to run +#[rstest::rstest] +#[tokio::test] +async fn docker_worker_heartbeat_basic(#[values("otel", "prom", "no_metrics")] backing: &str) { + if env::var("DOCKER_PROMETHEUS_RUNNING").is_err() { + return; + } + let telemopts = if backing == "no_metrics" { + TelemetryOptionsBuilder::default().build().unwrap() + } else { + get_integ_telem_options() + }; + let runtimeopts = RuntimeOptionsBuilder::default() + .telemetry_options(telemopts) + .heartbeat_interval(Some(Duration::from_millis(100))) + .build() + .unwrap(); + let mut rt = CoreRuntime::new_assume_tokio(runtimeopts).unwrap(); + match backing { + "otel" => { + let url = Some("grpc://localhost:4317") + .map(|x| x.parse::().unwrap()) + .unwrap(); + let mut opts_build = OtelCollectorOptionsBuilder::default(); + let opts = opts_build.url(url).build().unwrap(); + rt.telemetry_mut() + .attach_late_init_metrics(Arc::new(build_otlp_metric_exporter(opts).unwrap())); + } + "prom" => { + let mut opts_build = PrometheusExporterOptionsBuilder::default(); + opts_build.socket_addr(ANY_PORT.parse().unwrap()); + let opts = opts_build.build().unwrap(); + rt.telemetry_mut() + .attach_late_init_metrics(start_prometheus_metric_exporter(opts).unwrap().meter); + } + "no_metrics" => {} + _ => unreachable!(), + } + let wf_name = format!("worker_heartbeat_basic_{backing}"); + let mut starter = CoreWfStarter::new_with_runtime(&wf_name, rt); + starter + .worker_config + .max_outstanding_workflow_tasks(5_usize) + .max_cached_workflows(5_usize) + .max_outstanding_activities(5_usize) + .plugins(vec![ + PluginInfo { + name: "plugin1".to_string(), + version: "1".to_string(), + }, + PluginInfo { + name: "plugin2".to_string(), + version: "2".to_string(), + }, + ]); + let mut worker = starter.worker().await; + let worker_instance_key = worker.worker_instance_key(); + + worker.register_wf(wf_name.to_string(), |ctx: WfContext| async move { + ctx.activity(ActivityOptions { + activity_type: "pass_fail_act".to_string(), + input: "pass".as_json_payload().expect("serializes fine"), + start_to_close_timeout: Some(Duration::from_secs(1)), + ..Default::default() + }) + .await; + Ok(().into()) + }); + + let acts_started = Arc::new(Notify::new()); + let acts_done = Arc::new(Notify::new()); + + let acts_started_act = acts_started.clone(); + let acts_done_act = acts_done.clone(); + worker.register_activity("pass_fail_act", move |_ctx: ActContext, i: String| { + let acts_started = acts_started_act.clone(); + let acts_done = acts_done_act.clone(); + async move { + acts_started.notify_one(); + acts_done.notified().await; + Ok(i) + } + }); + + starter + .start_with_worker(wf_name.clone(), &mut worker) + .await; + + let start_time = AtomicCell::new(None); + let heartbeat_time = AtomicCell::new(None); + + let test_fut = async { + // Give enough time to ensure heartbeat interval has been hit + tokio::time::sleep(Duration::from_millis(110)).await; + acts_started.notified().await; + let client = starter.get_client().await; + let mut raw_client = (*client).clone(); + let workers_list = WorkflowService::list_workers( + &mut raw_client, + ListWorkersRequest { + namespace: client.namespace().to_owned(), + page_size: 100, + next_page_token: Vec::new(), + query: String::new(), + }, + ) + .await + .unwrap() + .into_inner(); + let worker_info = workers_list + .workers_info + .iter() + .find(|worker_info| { + if let Some(hb) = worker_info.worker_heartbeat.as_ref() { + hb.worker_instance_key == worker_instance_key.to_string() + } else { + false + } + }) + .unwrap(); + let heartbeat = worker_info.worker_heartbeat.as_ref().unwrap(); + assert_eq!( + heartbeat.worker_instance_key, + worker_instance_key.to_string() + ); + in_activity_checks(heartbeat, &start_time, &heartbeat_time); + acts_done.notify_one(); + }; + + let runner = async move { + worker.run_until_done().await.unwrap(); + }; + tokio::join!(test_fut, runner); + + let client = starter.get_client().await; + let mut raw_client = (*client).clone(); + let workers_list = WorkflowService::list_workers( + &mut raw_client, + ListWorkersRequest { + namespace: client.namespace().to_owned(), + page_size: 100, + next_page_token: Vec::new(), + query: String::new(), + }, + ) + .await + .unwrap() + .into_inner(); + // Since list_workers finds all workers in the namespace, must find specific worker used in this + // test + let worker_info = workers_list + .workers_info + .iter() + .find(|worker_info| { + if let Some(hb) = worker_info.worker_heartbeat.as_ref() { + hb.worker_instance_key == worker_instance_key.to_string() + } else { + false + } + }) + .unwrap(); + let heartbeat = worker_info.worker_heartbeat.as_ref().unwrap(); + after_shutdown_checks(heartbeat, &wf_name, &start_time, &heartbeat_time); +} + +// Tests that rely on Prometheus running in a docker container need to start +// with `docker_` and set the `DOCKER_PROMETHEUS_RUNNING` env variable to run +#[tokio::test] +async fn docker_worker_heartbeat_tuner() { + if env::var("DOCKER_PROMETHEUS_RUNNING").is_err() { + return; + } + let runtimeopts = RuntimeOptionsBuilder::default() + .telemetry_options(get_integ_telem_options()) + .heartbeat_interval(Some(Duration::from_millis(100))) + .build() + .unwrap(); + let mut rt = CoreRuntime::new_assume_tokio(runtimeopts).unwrap(); + + let url = Some("grpc://localhost:4317") + .map(|x| x.parse::().unwrap()) + .unwrap(); + let mut opts_build = OtelCollectorOptionsBuilder::default(); + let opts = opts_build.url(url).build().unwrap(); + + rt.telemetry_mut() + .attach_late_init_metrics(Arc::new(build_otlp_metric_exporter(opts).unwrap())); + let wf_name = "worker_heartbeat_tuner"; + let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); + let mut tuner = ResourceBasedTuner::new(0.0, 0.0); + tuner + .with_workflow_slots_options(ResourceSlotOptions::new(2, 10, Duration::from_millis(0))) + .with_activity_slots_options(ResourceSlotOptions::new(5, 10, Duration::from_millis(50))); + starter + .worker_config + .workflow_task_poller_behavior(PollerBehavior::Autoscaling { + minimum: 1, + maximum: 200, + initial: 5, + }) + .nexus_task_poller_behavior(PollerBehavior::Autoscaling { + minimum: 1, + maximum: 200, + initial: 5, + }) + .clear_max_outstanding_opts() + .tuner(Arc::new(tuner)); + let mut worker = starter.worker().await; + let worker_instance_key = worker.worker_instance_key(); + + // Run a workflow + worker.register_wf(wf_name.to_string(), |ctx: WfContext| async move { + ctx.activity(ActivityOptions { + activity_type: "pass_fail_act".to_string(), + input: "pass".as_json_payload().expect("serializes fine"), + start_to_close_timeout: Some(Duration::from_secs(1)), + ..Default::default() + }) + .await; + Ok(().into()) + }); + worker.register_activity("pass_fail_act", |_ctx: ActContext, i: String| async move { + Ok(i) + }); + + starter.start_with_worker(wf_name, &mut worker).await; + worker.run_until_done().await.unwrap(); + + let client = starter.get_client().await; + let mut raw_client = (*client).clone(); + let workers_list = WorkflowService::list_workers( + &mut raw_client, + ListWorkersRequest { + namespace: client.namespace().to_owned(), + page_size: 100, + next_page_token: Vec::new(), + query: String::new(), + }, + ) + .await + .unwrap() + .into_inner(); + // Since list_workers finds all workers in the namespace, must find specific worker used in this + // test + let worker_info = workers_list + .workers_info + .iter() + .find(|worker_info| { + if let Some(hb) = worker_info.worker_heartbeat.as_ref() { + hb.worker_instance_key == worker_instance_key.to_string() + } else { + false + } + }) + .unwrap(); + let heartbeat = worker_info.worker_heartbeat.as_ref().unwrap(); + assert!(heartbeat.task_queue.starts_with(wf_name)); + + assert_eq!( + heartbeat + .workflow_task_slots_info + .clone() + .unwrap() + .slot_supplier_kind, + "ResourceBased" + ); + assert_eq!( + heartbeat + .activity_task_slots_info + .clone() + .unwrap() + .slot_supplier_kind, + "ResourceBased" + ); + assert_eq!( + heartbeat + .nexus_task_slots_info + .clone() + .unwrap() + .slot_supplier_kind, + "ResourceBased" + ); + assert_eq!( + heartbeat + .local_activity_slots_info + .clone() + .unwrap() + .slot_supplier_kind, + "ResourceBased" + ); + + let workflow_poller_info = heartbeat.workflow_poller_info.unwrap(); + assert!(workflow_poller_info.is_autoscaling); + assert!(within_two_minutes_ts( + workflow_poller_info.last_successful_poll_time.unwrap() + )); + let sticky_poller_info = heartbeat.workflow_sticky_poller_info.unwrap(); + assert!(sticky_poller_info.is_autoscaling); + assert!(within_two_minutes_ts( + sticky_poller_info.last_successful_poll_time.unwrap() + )); + let nexus_poller_info = heartbeat.nexus_poller_info.unwrap(); + assert!(nexus_poller_info.is_autoscaling); + assert!(nexus_poller_info.last_successful_poll_time.is_none()); + let activity_poller_info = heartbeat.activity_poller_info.unwrap(); + assert!(!activity_poller_info.is_autoscaling); + assert!(within_two_minutes_ts( + activity_poller_info.last_successful_poll_time.unwrap() + )); +} + +fn in_activity_checks( + heartbeat: &WorkerHeartbeat, + start_time: &AtomicCell>, + heartbeat_time: &AtomicCell>, +) { + assert_eq!(heartbeat.status, WorkerStatus::Running as i32); + + let workflow_task_slots = heartbeat.workflow_task_slots_info.clone().unwrap(); + assert_eq!(workflow_task_slots.total_processed_tasks, 1); + assert_eq!(workflow_task_slots.current_available_slots, 5); + assert_eq!(workflow_task_slots.current_used_slots, 0); + assert_eq!(workflow_task_slots.slot_supplier_kind, "Fixed"); + let activity_task_slots = heartbeat.activity_task_slots_info.clone().unwrap(); + assert_eq!(activity_task_slots.current_available_slots, 4); + assert_eq!(activity_task_slots.current_used_slots, 1); + assert_eq!(activity_task_slots.slot_supplier_kind, "Fixed"); + let nexus_task_slots = heartbeat.nexus_task_slots_info.clone().unwrap(); + assert_eq!(nexus_task_slots.current_available_slots, 0); + assert_eq!(nexus_task_slots.current_used_slots, 0); + assert_eq!(nexus_task_slots.slot_supplier_kind, "Fixed"); + let local_activity_task_slots = heartbeat.local_activity_slots_info.clone().unwrap(); + assert_eq!(local_activity_task_slots.current_available_slots, 100); + assert_eq!(local_activity_task_slots.current_used_slots, 0); + assert_eq!(local_activity_task_slots.slot_supplier_kind, "Fixed"); + + let workflow_poller_info = heartbeat.workflow_poller_info.unwrap(); + assert_eq!(workflow_poller_info.current_pollers, 1); + let sticky_poller_info = heartbeat.workflow_sticky_poller_info.unwrap(); + assert_ne!(sticky_poller_info.current_pollers, 0); + let nexus_poller_info = heartbeat.nexus_poller_info.unwrap(); + assert_eq!(nexus_poller_info.current_pollers, 0); + let activity_poller_info = heartbeat.activity_poller_info.unwrap(); + assert_ne!(activity_poller_info.current_pollers, 0); + assert_ne!(heartbeat.current_sticky_cache_size, 0); + start_time.store(Some(heartbeat.start_time.unwrap())); + heartbeat_time.store(Some(heartbeat.heartbeat_time.unwrap())); +} + +fn after_shutdown_checks( + heartbeat: &WorkerHeartbeat, + wf_name: &str, + start_time: &AtomicCell>, + heartbeat_time: &AtomicCell>, +) { + assert_eq!(heartbeat.worker_identity, "integ_tester"); + let host_info = heartbeat.host_info.clone().unwrap(); + assert!(!host_info.host_name.is_empty()); + assert!(!host_info.process_key.is_empty()); + assert!(!host_info.process_id.is_empty()); + assert_ne!(host_info.current_host_cpu_usage, 0.0); + assert_ne!(host_info.current_host_mem_usage, 0.0); + + assert!(heartbeat.task_queue.starts_with(wf_name)); + assert_eq!( + heartbeat.deployment_version.clone().unwrap().build_id, + "test_build_id" + ); + assert_eq!(heartbeat.sdk_name, "temporal-core"); + assert_eq!(heartbeat.sdk_version, "0.1.0"); + assert_eq!(heartbeat.status, WorkerStatus::Shutdown as i32); + + assert_eq!(start_time.load().unwrap(), heartbeat.start_time.unwrap()); + assert_ne!( + heartbeat_time.load().unwrap(), + heartbeat.heartbeat_time.unwrap() + ); + assert!(within_two_minutes_ts(heartbeat.start_time.unwrap())); + assert!(within_two_minutes_ts(heartbeat.heartbeat_time.unwrap())); + assert!( + to_system_time(heartbeat_time.load().unwrap()) + < to_system_time(heartbeat.heartbeat_time.unwrap()) + ); + assert!(within_duration( + heartbeat.elapsed_since_last_heartbeat.unwrap(), + Duration::from_millis(200) + )); + + let workflow_task_slots = heartbeat.workflow_task_slots_info.clone().unwrap(); + assert_eq!(workflow_task_slots.current_available_slots, 5); + assert_eq!(workflow_task_slots.current_used_slots, 1); + assert_eq!(workflow_task_slots.total_processed_tasks, 2); + assert_eq!(workflow_task_slots.slot_supplier_kind, "Fixed"); + let activity_task_slots = heartbeat.activity_task_slots_info.clone().unwrap(); + assert_eq!(activity_task_slots.current_available_slots, 5); + assert_eq!(workflow_task_slots.current_used_slots, 1); + assert_eq!(activity_task_slots.slot_supplier_kind, "Fixed"); + assert_eq!(activity_task_slots.last_interval_processed_tasks, 1); + let nexus_task_slots = heartbeat.nexus_task_slots_info.clone().unwrap(); + assert_eq!(nexus_task_slots.current_available_slots, 0); + assert_eq!(nexus_task_slots.current_used_slots, 0); + assert_eq!(nexus_task_slots.slot_supplier_kind, "Fixed"); + let local_activity_task_slots = heartbeat.local_activity_slots_info.clone().unwrap(); + assert_eq!(local_activity_task_slots.current_available_slots, 100); + assert_eq!(local_activity_task_slots.current_used_slots, 0); + assert_eq!(local_activity_task_slots.slot_supplier_kind, "Fixed"); + + let workflow_poller_info = heartbeat.workflow_poller_info.unwrap(); + assert!(!workflow_poller_info.is_autoscaling); + assert!(within_two_minutes_ts( + workflow_poller_info.last_successful_poll_time.unwrap() + )); + let sticky_poller_info = heartbeat.workflow_sticky_poller_info.unwrap(); + assert!(!sticky_poller_info.is_autoscaling); + assert!(within_two_minutes_ts( + sticky_poller_info.last_successful_poll_time.unwrap() + )); + let nexus_poller_info = heartbeat.nexus_poller_info.unwrap(); + assert!(!nexus_poller_info.is_autoscaling); + assert!(nexus_poller_info.last_successful_poll_time.is_none()); + let activity_poller_info = heartbeat.activity_poller_info.unwrap(); + assert!(!activity_poller_info.is_autoscaling); + assert!(within_two_minutes_ts( + activity_poller_info.last_successful_poll_time.unwrap() + )); + + assert_eq!(heartbeat.total_sticky_cache_hit, 2); + assert_eq!(heartbeat.current_sticky_cache_size, 0); + assert_eq!( + heartbeat.plugins, + vec![ + PluginInfo { + name: "plugin1".to_string(), + version: "1".to_string() + }, + PluginInfo { + name: "plugin2".to_string(), + version: "2".to_string() + } + ] + ); +} + +#[tokio::test] +async fn worker_heartbeat_sticky_cache_miss() { + let wf_name = "worker_heartbeat_cache_miss"; + let mut starter = new_no_metrics_starter(wf_name); + starter + .worker_config + .max_cached_workflows(1_usize) + .max_outstanding_workflow_tasks(2_usize); + + let mut worker = starter.worker().await; + worker.fetch_results = false; + let worker_key = worker.worker_instance_key().to_string(); + let worker_core = worker.core_worker.clone(); + let submitter = worker.get_submitter_handle(); + let wf_opts = starter.workflow_options.clone(); + let client = starter.get_client().await; + let client_for_orchestrator = client.clone(); + + static HISTORY_WF1_ACTIVITY_STARTED: Notify = Notify::const_new(); + static HISTORY_WF1_ACTIVITY_FINISH: Notify = Notify::const_new(); + static HISTORY_WF2_ACTIVITY_STARTED: Notify = Notify::const_new(); + static HISTORY_WF2_ACTIVITY_FINISH: Notify = Notify::const_new(); + + worker.register_wf(wf_name.to_string(), |ctx: WfContext| async move { + let wf_marker = ctx + .get_args() + .first() + .and_then(|p| String::from_json_payload(p).ok()) + .unwrap_or_else(|| "wf1".to_string()); + + ctx.activity(ActivityOptions { + activity_type: "sticky_cache_history_act".to_string(), + input: wf_marker.clone().as_json_payload().expect("serialize"), + start_to_close_timeout: Some(Duration::from_secs(5)), + ..Default::default() + }) + .await; + + Ok(().into()) + }); + worker.register_activity( + "sticky_cache_history_act", + |_ctx: ActContext, marker: String| async move { + match marker.as_str() { + "wf1" => { + HISTORY_WF1_ACTIVITY_STARTED.notify_one(); + HISTORY_WF1_ACTIVITY_FINISH.notified().await; + } + "wf2" => { + HISTORY_WF2_ACTIVITY_STARTED.notify_one(); + HISTORY_WF2_ACTIVITY_FINISH.notified().await; + } + _ => {} + } + Ok(marker) + }, + ); + + let wf1_id = format!("{wf_name}_wf1"); + let wf2_id = format!("{wf_name}_wf2"); + + let orchestrator = async move { + let wf1_run = submitter + .submit_wf( + wf1_id.clone(), + wf_name.to_string(), + vec!["wf1".to_string().as_json_payload().unwrap()], + wf_opts.clone(), + ) + .await + .unwrap(); + + HISTORY_WF1_ACTIVITY_STARTED.notified().await; + + let wf2_run = submitter + .submit_wf( + wf2_id.clone(), + wf_name.to_string(), + vec!["wf2".to_string().as_json_payload().unwrap()], + wf_opts, + ) + .await + .unwrap(); + + HISTORY_WF2_ACTIVITY_STARTED.notified().await; + + HISTORY_WF1_ACTIVITY_FINISH.notify_one(); + let handle1 = client_for_orchestrator.get_untyped_workflow_handle(wf1_id, wf1_run); + handle1 + .get_workflow_result(Default::default()) + .await + .expect("wf1 result"); + + HISTORY_WF2_ACTIVITY_FINISH.notify_one(); + let handle2 = client_for_orchestrator.get_untyped_workflow_handle(wf2_id, wf2_run); + handle2 + .get_workflow_result(Default::default()) + .await + .expect("wf2 result"); + + worker_core.initiate_shutdown(); + }; + + let mut worker_runner = worker; + let runner = async move { + worker_runner.run_until_done().await.unwrap(); + }; + + tokio::join!(orchestrator, runner); + + sleep(Duration::from_millis(200)).await; + let mut heartbeats = + list_worker_heartbeats(&client, format!("WorkerInstanceKey=\"{worker_key}\"")).await; + assert_eq!(heartbeats.len(), 1); + let heartbeat = heartbeats.pop().unwrap(); + + assert!(heartbeat.total_sticky_cache_miss >= 1); + assert_eq!(heartbeat.worker_instance_key, worker_key); +} + +#[tokio::test] +async fn worker_heartbeat_multiple_workers() { + let wf_name = "worker_heartbeat_multi_workers"; + let mut starter = new_no_metrics_starter(wf_name); + starter + .worker_config + .max_outstanding_workflow_tasks(5_usize) + .max_cached_workflows(5_usize); + + let client = starter.get_client().await; + let starting_hb_len = list_worker_heartbeats(&client, String::new()).await.len(); + + let mut worker_a = starter.worker().await; + worker_a.register_wf(wf_name.to_string(), |_ctx: WfContext| async move { + Ok(().into()) + }); + worker_a.register_activity("failing_act", |_ctx: ActContext, _: String| async move { + Ok(()) + }); + + let mut starter_b = starter.clone_no_worker(); + let mut worker_b = starter_b.worker().await; + worker_b.register_wf(wf_name.to_string(), |_ctx: WfContext| async move { + Ok(().into()) + }); + worker_b.register_activity("failing_act", |_ctx: ActContext, _: String| async move { + Ok(()) + }); + + let worker_a_key = worker_a.worker_instance_key().to_string(); + let worker_b_key = worker_b.worker_instance_key().to_string(); + let _ = starter.start_with_worker(wf_name, &mut worker_a).await; + worker_a.run_until_done().await.unwrap(); + + let _ = starter_b.start_with_worker(wf_name, &mut worker_b).await; + worker_b.run_until_done().await.unwrap(); + + sleep(Duration::from_millis(200)).await; + + let all = list_worker_heartbeats(&client, String::new()).await; + let keys: HashSet<_> = all + .iter() + .map(|hb| hb.worker_instance_key.clone()) + .collect(); + assert!(keys.contains(&worker_a_key)); + assert!(keys.contains(&worker_b_key)); + + // Verify both heartbeats contain the same shared process_key + let process_keys: HashSet<_> = all + .iter() + .filter_map(|hb| hb.host_info.as_ref().map(|info| info.process_key.clone())) + .collect(); + assert!(process_keys.len() > starting_hb_len); + + let filtered = + list_worker_heartbeats(&client, format!("WorkerInstanceKey=\"{worker_a_key}\"")).await; + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].worker_instance_key, worker_a_key); + + // Verify describe worker gives the same heartbeat as listworker + let mut raw_client = client.as_ref().clone(); + let describe_worker_a = WorkflowService::describe_worker( + &mut raw_client, + DescribeWorkerRequest { + namespace: client.namespace().to_owned(), + worker_instance_key: worker_a_key.to_string(), + }, + ) + .await + .unwrap() + .into_inner() + .worker_info + .unwrap() + .worker_heartbeat + .unwrap(); + assert_eq!(describe_worker_a, filtered[0]); + + let filtered_b = + list_worker_heartbeats(&client, format!("WorkerInstanceKey = \"{worker_b_key}\"")).await; + assert_eq!(filtered_b.len(), 1); + assert_eq!(filtered_b[0].worker_instance_key, worker_b_key); + let describe_worker_b = WorkflowService::describe_worker( + &mut raw_client, + DescribeWorkerRequest { + namespace: client.namespace().to_owned(), + worker_instance_key: worker_b_key.to_string(), + }, + ) + .await + .unwrap() + .into_inner() + .worker_info + .unwrap() + .worker_heartbeat + .unwrap(); + assert_eq!(describe_worker_b, filtered_b[0]); +} + +#[tokio::test] +async fn worker_heartbeat_failure_metrics() { + const WORKFLOW_CONTINUE_SIGNAL: &str = "workflow-continue"; + + let wf_name = "worker_heartbeat_failure_metrics"; + let mut starter = new_no_metrics_starter(wf_name); + starter.worker_config.max_outstanding_activities(5_usize); + + let mut worker = starter.worker().await; + let worker_instance_key = worker.worker_instance_key(); + static ACT_COUNT: AtomicU64 = AtomicU64::new(0); + static WF_COUNT: AtomicU64 = AtomicU64::new(0); + static ACT_FAIL: Notify = Notify::const_new(); + static WF_FAIL: Notify = Notify::const_new(); + worker.register_wf(wf_name.to_string(), |ctx: WfContext| async move { + let _ = ctx + .activity(ActivityOptions { + activity_type: "failing_act".to_string(), + input: "boom".as_json_payload().expect("serialize"), + start_to_close_timeout: Some(Duration::from_secs(1)), + retry_policy: Some(RetryPolicy { + initial_interval: Some(prost_dur!(from_millis(10))), + backoff_coefficient: 1.0, + maximum_attempts: 4, + ..Default::default() + }), + ..Default::default() + }) + .await; + + if WF_COUNT.load(Ordering::Relaxed) == 0 { + WF_COUNT.fetch_add(1, Ordering::Relaxed); + WF_FAIL.notify_one(); + panic!("expected WF panic"); + } + + // Signal here to avoid workflow from completing and shutdown heartbeat from sending + // before we check workflow_slots.last_interval_failure_tasks + let mut proceed_signal = ctx.make_signal_channel(WORKFLOW_CONTINUE_SIGNAL); + proceed_signal.next().await.unwrap(); + Ok(().into()) + }); + + worker.register_activity("failing_act", |_ctx: ActContext, _: String| async move { + if ACT_COUNT.load(Ordering::Relaxed) == 3 { + return Ok(()); + } + ACT_COUNT.fetch_add(1, Ordering::Relaxed); + ACT_FAIL.notify_one(); + Err(anyhow!("Expected error").into()) + }); + + let worker_key = worker_instance_key.to_string(); + starter.workflow_options.retry_policy = Some(RetryPolicy { + maximum_attempts: 2, + ..Default::default() + }); + + let _ = starter.start_with_worker(wf_name, &mut worker).await; + + let test_fut = async { + ACT_FAIL.notified().await; + let client = starter.get_client().await; + eventually( + || async { + let mut raw_client = (*client).clone(); + + let workers_list = WorkflowService::list_workers( + &mut raw_client, + ListWorkersRequest { + namespace: client.namespace().to_owned(), + page_size: 100, + next_page_token: Vec::new(), + query: String::new(), + }, + ) + .await + .unwrap() + .into_inner(); + let worker_info = workers_list + .workers_info + .iter() + .find(|worker_info| { + if let Some(hb) = worker_info.worker_heartbeat.as_ref() { + hb.worker_instance_key == worker_instance_key.to_string() + } else { + false + } + }) + .unwrap(); + let heartbeat = worker_info.worker_heartbeat.as_ref().unwrap(); + assert_eq!( + heartbeat.worker_instance_key, + worker_instance_key.to_string() + ); + let activity_slots = heartbeat.activity_task_slots_info.clone().unwrap(); + if activity_slots.last_interval_failure_tasks >= 1 { + return Ok(()); + } + Err("activity_slots.last_interval_failure_tasks still 0, retrying") + }, + Duration::from_millis(150), + ) + .await + .unwrap(); + + WF_FAIL.notified().await; + + eventually( + || async { + let mut raw_client = (*client).clone(); + let workers_list = WorkflowService::list_workers( + &mut raw_client, + ListWorkersRequest { + namespace: client.namespace().to_owned(), + page_size: 100, + next_page_token: Vec::new(), + query: String::new(), + }, + ) + .await + .unwrap() + .into_inner(); + let worker_info = workers_list + .workers_info + .iter() + .find(|worker_info| { + if let Some(hb) = worker_info.worker_heartbeat.as_ref() { + hb.worker_instance_key == worker_instance_key.to_string() + } else { + false + } + }) + .unwrap(); + + let heartbeat = worker_info.worker_heartbeat.as_ref().unwrap(); + let workflow_slots = heartbeat.workflow_task_slots_info.clone().unwrap(); + if workflow_slots.last_interval_failure_tasks >= 1 { + return Ok(()); + } + Err("workflow_slots.last_interval_failure_tasks still 0, retrying") + }, + Duration::from_millis(150), + ) + .await + .unwrap(); + client + .signal_workflow_execution( + starter.get_wf_id().to_string(), + String::new(), + WORKFLOW_CONTINUE_SIGNAL.to_string(), + None, + None, + ) + .await + .unwrap(); + }; + + let runner = async move { + worker.run_until_done().await.unwrap(); + }; + tokio::join!(test_fut, runner); + + let client = starter.get_client().await; + let mut heartbeats = + list_worker_heartbeats(&client, format!("WorkerInstanceKey=\"{worker_key}\"")).await; + assert_eq!(heartbeats.len(), 1); + let heartbeat = heartbeats.pop().unwrap(); + + let activity_slots = heartbeat.activity_task_slots_info.unwrap(); + assert_eq!(activity_slots.total_failed_tasks, 3); + + let workflow_slots = heartbeat.workflow_task_slots_info.unwrap(); + assert_eq!(workflow_slots.total_failed_tasks, 1); +} + +#[tokio::test] +async fn worker_heartbeat_no_runtime_heartbeat() { + let wf_name = "worker_heartbeat_no_runtime_heartbeat"; + let runtimeopts = RuntimeOptionsBuilder::default() + .telemetry_options(get_integ_telem_options()) + .heartbeat_interval(None) // Turn heartbeating off + .build() + .unwrap(); + let rt = CoreRuntime::new_assume_tokio(runtimeopts).unwrap(); + let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); + let mut worker = starter.worker().await; + let worker_instance_key = worker.worker_instance_key(); + + worker.register_wf(wf_name.to_owned(), |ctx: WfContext| async move { + ctx.activity(ActivityOptions { + activity_type: "pass_fail_act".to_string(), + input: "pass".as_json_payload().expect("serializes fine"), + start_to_close_timeout: Some(Duration::from_secs(1)), + ..Default::default() + }) + .await; + Ok(().into()) + }); + + worker.register_activity("pass_fail_act", |_ctx: ActContext, i: String| async move { + Ok(i) + }); + + starter + .start_with_worker(wf_name.to_owned(), &mut worker) + .await; + + worker.run_until_done().await.unwrap(); + let client = starter.get_client().await; + let mut raw_client = (*client).clone(); + let workers_list = WorkflowService::list_workers( + &mut raw_client, + ListWorkersRequest { + namespace: client.namespace().to_owned(), + page_size: 100, + next_page_token: Vec::new(), + query: String::new(), + }, + ) + .await + .unwrap() + .into_inner(); + + // Ensure worker has not ever heartbeated + let heartbeat = workers_list.workers_info.iter().find(|worker_info| { + if let Some(hb) = worker_info.worker_heartbeat.as_ref() { + hb.worker_instance_key == worker_instance_key.to_string() + } else { + false + } + }); + assert!(heartbeat.is_none()); +} + +#[tokio::test] +async fn worker_heartbeat_skip_client_worker_set_check() { + let wf_name = "worker_heartbeat_skip_client_worker_set_check"; + let runtimeopts = RuntimeOptionsBuilder::default() + .telemetry_options(get_integ_telem_options()) + .heartbeat_interval(Some(Duration::from_millis(100))) + .build() + .unwrap(); + let rt = CoreRuntime::new_assume_tokio(runtimeopts).unwrap(); + let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt); + starter.worker_config.skip_client_worker_set_check(true); + let mut worker = starter.worker().await; + let worker_instance_key = worker.worker_instance_key(); + + worker.register_wf(wf_name.to_owned(), |ctx: WfContext| async move { + ctx.activity(ActivityOptions { + activity_type: "pass_fail_act".to_string(), + input: "pass".as_json_payload().expect("serializes fine"), + start_to_close_timeout: Some(Duration::from_secs(1)), + ..Default::default() + }) + .await; + Ok(().into()) + }); + + worker.register_activity("pass_fail_act", |_ctx: ActContext, i: String| async move { + Ok(i) + }); + + starter + .start_with_worker(wf_name.to_owned(), &mut worker) + .await; + + worker.run_until_done().await.unwrap(); + let client = starter.get_client().await; + let mut raw_client = (*client).clone(); + let workers_list = WorkflowService::list_workers( + &mut raw_client, + ListWorkersRequest { + namespace: client.namespace().to_owned(), + page_size: 100, + next_page_token: Vec::new(), + query: String::new(), + }, + ) + .await + .unwrap() + .into_inner(); + + // Ensure worker still heartbeats + let heartbeat = workers_list.workers_info.iter().find(|worker_info| { + if let Some(hb) = worker_info.worker_heartbeat.as_ref() { + hb.worker_instance_key == worker_instance_key.to_string() + } else { + false + } + }); + assert!(heartbeat.is_some()); +} diff --git a/tests/main.rs b/tests/main.rs index 8a71f03ca..d48a68bd5 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -22,6 +22,7 @@ mod integ_tests { mod queries_tests; mod update_tests; mod visibility_tests; + mod worker_heartbeat_tests; mod worker_tests; mod worker_versioning_tests; mod workflow_tests; diff --git a/tests/runner.rs b/tests/runner.rs index f2d843968..af763cb4a 100644 --- a/tests/runner.rs +++ b/tests/runner.rs @@ -121,6 +121,10 @@ async fn main() -> Result<(), anyhow::Error> { "system.enableDeploymentVersions=true".to_owned(), "--dynamic-config-value".to_owned(), "component.nexusoperations.recordCancelRequestCompletionEvents=true".to_owned(), + "--dynamic-config-value".to_owned(), + "frontend.WorkerHeartbeatsEnabled=true".to_owned(), + "--dynamic-config-value".to_owned(), + "frontend.ListWorkersEnabled=true".to_owned(), "--http-port".to_string(), "7243".to_string(), "--search-attribute".to_string(), From cfecca3e1d026f1e2f7a706bca87f64398daf9c9 Mon Sep 17 00:00:00 2001 From: Andrew Yuan Date: Fri, 17 Oct 2025 11:31:13 -0700 Subject: [PATCH 3/5] merge --- .clippy.toml | 1 + .github/workflows/heavy.yml | 2 + .github/workflows/per-pr.yml | 42 +- Cargo.toml | 14 +- README.md | 4 + arch_docs/sdks_intro.md | 15 +- client/Cargo.toml | 1 + client/src/lib.rs | 546 +++++++------ client/src/metrics.rs | 3 +- client/src/raw.rs | 726 ++++++++++-------- client/src/replaceable.rs | 253 ++++++ client/src/retry.rs | 15 +- client/src/workflow_handle/mod.rs | 34 +- core-api/src/envconfig.rs | 74 +- core-c-bridge/Cargo.toml | 9 +- .../include/temporal-sdk-core-c-bridge.h | 195 ++++- core-c-bridge/src/client.rs | 617 ++++++++++----- core-c-bridge/src/envconfig.rs | 314 ++++++++ core-c-bridge/src/lib.rs | 1 + core-c-bridge/src/tests/mod.rs | 26 +- core-c-bridge/src/worker.rs | 280 +++++-- core/Cargo.toml | 9 +- core/src/core_tests/updates.rs | 4 +- core/src/internal_flags.rs | 1 + core/src/lib.rs | 92 ++- core/src/protosext/mod.rs | 10 +- core/src/protosext/protocol_messages.rs | 15 +- core/src/retry_logic.rs | 364 ++++++--- core/src/test_help/integ_helpers.rs | 1 - core/src/worker/activities.rs | 13 +- .../activities/activity_heartbeat_manager.rs | 48 +- .../src/worker/activities/local_activities.rs | 38 +- core/src/worker/client.rs | 204 ++--- core/src/worker/client/mocks.rs | 2 +- core/src/worker/mod.rs | 17 +- core/src/worker/nexus.rs | 17 +- .../workflow/machines/patch_state_machine.rs | 11 +- core/src/worker/workflow/mod.rs | 1 + core/src/worker/workflow/workflow_stream.rs | 11 +- .../trybuild/dupe_transitions_fail.stderr | 10 +- sdk-core-protos/Cargo.toml | 9 +- sdk-core-protos/build.rs | 33 +- sdk-core-protos/src/history_builder.rs | 2 +- sdk-core-protos/src/lib.rs | 27 +- sdk-core-protos/src/utilities.rs | 19 +- sdk/Cargo.toml | 10 +- sdk/src/activity_context.rs | 2 +- sdk/src/app_data.rs | 2 +- tests/c_bridge_smoke_test.c | 10 + tests/common/mod.rs | 68 +- tests/integ_tests/client_tests.rs | 37 +- tests/integ_tests/ephemeral_server_tests.rs | 22 +- tests/integ_tests/metrics_tests.rs | 23 +- tests/integ_tests/polling_tests.rs | 402 +++++++--- tests/integ_tests/update_tests.rs | 6 +- tests/integ_tests/worker_heartbeat_tests.rs | 31 +- tests/integ_tests/worker_tests.rs | 293 ++++++- tests/integ_tests/worker_versioning_tests.rs | 57 +- .../integ_tests/workflow_tests/activities.rs | 90 ++- tests/integ_tests/workflow_tests/nexus.rs | 10 + tests/integ_tests/workflow_tests/patches.rs | 12 +- tests/integ_tests/workflow_tests/resets.rs | 49 +- tests/main.rs | 32 +- tests/runner.rs | 50 +- 64 files changed, 3800 insertions(+), 1536 deletions(-) create mode 100644 .clippy.toml create mode 100644 client/src/replaceable.rs create mode 100644 core-c-bridge/src/envconfig.rs create mode 100644 tests/c_bridge_smoke_test.c diff --git a/.clippy.toml b/.clippy.toml new file mode 100644 index 000000000..b9bf2c47f --- /dev/null +++ b/.clippy.toml @@ -0,0 +1 @@ +allow-dbg-in-tests = true \ No newline at end of file diff --git a/.github/workflows/heavy.yml b/.github/workflows/heavy.yml index 3c04d85fe..e3833f764 100644 --- a/.github/workflows/heavy.yml +++ b/.github/workflows/heavy.yml @@ -19,6 +19,8 @@ jobs: with: submodules: recursive - uses: dtolnay/rust-toolchain@stable + with: + toolchain: 1.90.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: diff --git a/.github/workflows/per-pr.yml b/.github/workflows/per-pr.yml index 8a26e8a06..9c7ce75c5 100644 --- a/.github/workflows/per-pr.yml +++ b/.github/workflows/per-pr.yml @@ -21,7 +21,7 @@ jobs: submodules: recursive - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.88.0 + toolchain: 1.90.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: @@ -49,16 +49,16 @@ jobs: - os: ubuntu-latest - os: ubuntu-arm runsOn: ubuntu-24.04-arm64-2-core - - os: macos-intel - runsOn: macos-13 - os: macos-arm runsOn: macos-14 + - os: macos-intel + runsOn: macos-15-intel runs-on: ${{ matrix.runsOn || matrix.os }} steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.88.0 + toolchain: 1.90.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: @@ -97,16 +97,16 @@ jobs: - os: ubuntu-latest - os: ubuntu-arm runsOn: ubuntu-24.04-arm64-2-core - - os: macos-intel - runsOn: macos-13 - os: macos-arm runsOn: macos-14 + - os: macos-intel + runsOn: macos-15-intel runs-on: ${{ matrix.runsOn || matrix.os }} steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.88.0 + toolchain: 1.90.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: @@ -139,7 +139,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.88.0 + toolchain: 1.90.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: @@ -164,7 +164,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.88.0 + toolchain: 1.90.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: @@ -177,3 +177,27 @@ jobs: compose-file: ./docker/docker-compose-ci.yaml - uses: Swatinem/rust-cache@v2 - run: cargo integ-test docker_ + + c-bridge-static-link-test: + name: "C bridge static link test" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: 1.90.0 + - name: Install protoc + uses: arduino/setup-protoc@v3 + with: + # TODO: Upgrade proto once https://github.com/arduino/setup-protoc/issues/99 is fixed + version: "23.x" + repo-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Build crate as static library + run: cargo rustc --package temporal-sdk-core-c-bridge --features xz2-static -- --crate-type=staticlib + + - name: Build C test program + run: gcc -I./core-c-bridge/include tests/c_bridge_smoke_test.c target/debug/deps/libtemporal_sdk_core_c_bridge.a -lpthread -ldl -lm -o test + + - name: Run C test program + run: ./test diff --git a/Cargo.toml b/Cargo.toml index 915fce921..a4eef1f75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,15 +25,19 @@ derive_more = { version = "2.0", features = [ "try_into", ] } thiserror = "2" -tonic = "0.13" -tonic-build = "0.13" -opentelemetry = { version = "0.30", features = ["metrics"] } -prost = "0.13" -prost-types = "0.13" +tonic = "0.14" +tonic-prost = "0.14" +tonic-prost-build = "0.14" +opentelemetry = { version = "0.31", features = ["metrics"] } +prost = "0.14" +prost-types = { version = "0.7", package = "prost-wkt-types" } [workspace.lints.rust] unreachable_pub = "warn" +[workspace.lints.clippy] +dbg_macro = "warn" + [profile.release-lto] inherits = "release" lto = true diff --git a/README.md b/README.md index a481574df..aee4960b5 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,10 @@ use an already-running server by passing `-s external`. Run load tests with `cargo test --test heavy_tests`. +NOTE: Integration tests should pass locally, if running on MacOS and you see integration tests consistently failing +with an error that mentions `Too many open files`, this is likely due to `ulimit -n` being too low. You can raise +it temporarily (current shell) with `ulimit -n 65535`, or add it to your `~/.zshrc` file to apply to all shells. + ## Formatting To format all code run: diff --git a/arch_docs/sdks_intro.md b/arch_docs/sdks_intro.md index a61b645a7..f8d33d717 100644 --- a/arch_docs/sdks_intro.md +++ b/arch_docs/sdks_intro.md @@ -119,10 +119,13 @@ refer back to this diagram during later explanations. ```mermaid %%{ init : { "theme" : "default", "themeVariables" : { "background" : "#fff" }}}%% sequenceDiagram - participant Client - participant Server - participant SDK - participant UserCode as "User Code" + box rgb(255, 255, 255) + participant Client + participant Server + participant SDK + participant UserCode as "User Code" + end + Client ->> Server: StartWorkflowExecution RPC Server ->> Server: Generate Workflow Task & persist Server -->> Client: runID @@ -280,7 +283,7 @@ with a command already in history. If the order of arguments to `asyncio.gather` is swapped in the previous example Workflow, replaying the Workflow will result in a nondeterminism error (often referred to as an "NDE"): -``` +``` temporalio.workflow.NondeterminismError: Workflow activation completion failed: Failure { failure: Some(Failure { message: "[TMPRL1100] Nondeterminism error: Timer machine does not handle this event: HistoryEvent(id: 5, ActivityTaskScheduled)", source: "", stack_trace: "", encoded_attributes: None, @@ -293,4 +296,4 @@ This error occurs because the "dual" events in the history are not in the same o produced by the modified Workflow code. The timer machine, expecting a `TimerStarted` event, encounters an `ActivityTaskScheduled` event instead, leading to a nondeterminism error. In other words, we did not produce the same commands in the same order. This mechanism ensures Workflow state -consistency. \ No newline at end of file +consistency. diff --git a/client/Cargo.toml b/client/Cargo.toml index 50095bfde..f083ed777 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -20,6 +20,7 @@ backoff = "0.4" base64 = "0.22" derive_builder = { workspace = true } derive_more = { workspace = true } +dyn-clone = "1.0" bytes = "1.10" futures-util = { version = "0.3", default-features = false } futures-retry = "0.6.0" diff --git a/client/src/lib.rs b/client/src/lib.rs index 4d6088f87..4d633aa3c 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -13,6 +13,7 @@ mod metrics; #[doc(hidden)] pub mod proxy; mod raw; +mod replaceable; mod retry; mod worker_registry; mod workflow_handle; @@ -23,6 +24,7 @@ pub use crate::{ }; pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME}; pub use raw::{CloudService, HealthService, OperatorService, TestService, WorkflowService}; +pub use replaceable::SharedReplaceableClient; pub use temporal_sdk_core_protos::temporal::api::{ enums::v1::ArchivalState, filter::v1::{StartTimeFilter, StatusFilter, WorkflowExecutionFilter, WorkflowTypeFilter}, @@ -41,7 +43,7 @@ pub use workflow_handle::{ use crate::{ metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext}, - raw::{AttachMetricLabels, sealed::RawClientLike}, + raw::AttachMetricLabels, sealed::WfHandleClient, workflow_handle::UntypedWorkflowHandle, }; @@ -76,7 +78,7 @@ use temporal_sdk_core_protos::{ }, }; use tonic::{ - Code, + Code, IntoRequest, body::Body, client::GrpcService, codegen::InterceptedService, @@ -518,8 +520,7 @@ impl ClientOptions { pub async fn connect_no_namespace( &self, metrics_meter: Option, - ) -> Result>, ClientInitError> - { + ) -> Result>, ClientInitError> { self.connect_no_namespace_with_service_override(metrics_meter, None) .await } @@ -534,8 +535,7 @@ impl ClientOptions { &self, metrics_meter: Option, service_override: Option, - ) -> Result>, ClientInitError> - { + ) -> Result>, ClientInitError> { let service = if let Some(service_override) = service_override { GrpcMetricSvc { inner: ChannelOrGrpcOverride::GrpcOverride(service_override), @@ -595,7 +595,7 @@ impl ClientOptions { }; if !self.skip_get_system_info { match client - .get_system_info(GetSystemInfoRequest::default()) + .get_system_info(GetSystemInfoRequest::default().into_request()) .await { Ok(sysinfo) => { @@ -614,11 +614,13 @@ impl ClientOptions { /// Passes it through if TLS options not set. async fn add_tls_to_channel(&self, mut channel: Endpoint) -> Result { if let Some(tls_cfg) = &self.tls_cfg { - let mut tls = tonic::transport::ClientTlsConfig::new().with_native_roots(); + let mut tls = tonic::transport::ClientTlsConfig::new(); if let Some(root_cert) = &tls_cfg.server_root_ca_cert { let server_root_ca_cert = Certificate::from_pem(root_cert); tls = tls.ca_certificate(server_root_ca_cert); + } else { + tls = tls.with_native_roots(); } if let Some(domain) = &tls_cfg.domain { @@ -737,14 +739,13 @@ impl Interceptor for ServiceCallInterceptor { } /// Aggregates various services exposed by the Temporal server -#[derive(Debug, Clone)] -pub struct TemporalServiceClient { - svc: T, - workflow_svc_client: OnceLock>, - operator_svc_client: OnceLock>, - cloud_svc_client: OnceLock>, - test_svc_client: OnceLock>, - health_svc_client: OnceLock>, +#[derive(Clone)] +pub struct TemporalServiceClient { + workflow_svc_client: Box, + operator_svc_client: Box, + cloud_svc_client: Box, + test_svc_client: Box, + health_svc_client: Box, } /// We up the limit on incoming messages from server from the 4Mb default to 128Mb. If for @@ -759,136 +760,100 @@ fn get_decode_max_size() -> usize { }) } -impl TemporalServiceClient -where - T: Clone, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - ::Error: Into + Send, -{ - fn new(svc: T) -> Self { +impl TemporalServiceClient { + fn new(svc: T) -> Self + where + T: GrpcService + Send + Sync + Clone + 'static, + T::ResponseBody: tonic::codegen::Body + Send + 'static, + T::Error: Into, + ::Error: Into + Send, + >::Future: Send, + { + let workflow_svc_client = Box::new( + WorkflowServiceClient::new(svc.clone()) + .max_decoding_message_size(get_decode_max_size()), + ); + let operator_svc_client = Box::new( + OperatorServiceClient::new(svc.clone()) + .max_decoding_message_size(get_decode_max_size()), + ); + let cloud_svc_client = Box::new( + CloudServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()), + ); + let test_svc_client = Box::new( + TestServiceClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()), + ); + let health_svc_client = Box::new( + HealthClient::new(svc.clone()).max_decoding_message_size(get_decode_max_size()), + ); + Self { - svc, - workflow_svc_client: OnceLock::new(), - operator_svc_client: OnceLock::new(), - cloud_svc_client: OnceLock::new(), - test_svc_client: OnceLock::new(), - health_svc_client: OnceLock::new(), + workflow_svc_client, + operator_svc_client, + cloud_svc_client, + test_svc_client, + health_svc_client, } } + + /// Create a service client from implementations of the individual underlying services. Useful + /// for mocking out service implementations. + pub fn from_services( + workflow: Box, + operator: Box, + cloud: Box, + test: Box, + health: Box, + ) -> Self { + Self { + workflow_svc_client: workflow, + operator_svc_client: operator, + cloud_svc_client: cloud, + test_svc_client: test, + health_svc_client: health, + } + } + /// Get the underlying workflow service client - pub fn workflow_svc(&self) -> &WorkflowServiceClient { - self.workflow_svc_client.get_or_init(|| { - WorkflowServiceClient::new(self.svc.clone()) - .max_decoding_message_size(get_decode_max_size()) - }) + pub fn workflow_svc(&self) -> Box { + self.workflow_svc_client.clone() } /// Get the underlying operator service client - pub fn operator_svc(&self) -> &OperatorServiceClient { - self.operator_svc_client.get_or_init(|| { - OperatorServiceClient::new(self.svc.clone()) - .max_decoding_message_size(get_decode_max_size()) - }) + pub fn operator_svc(&self) -> Box { + self.operator_svc_client.clone() } /// Get the underlying cloud service client - pub fn cloud_svc(&self) -> &CloudServiceClient { - self.cloud_svc_client.get_or_init(|| { - CloudServiceClient::new(self.svc.clone()) - .max_decoding_message_size(get_decode_max_size()) - }) + pub fn cloud_svc(&self) -> Box { + self.cloud_svc_client.clone() } /// Get the underlying test service client - pub fn test_svc(&self) -> &TestServiceClient { - self.test_svc_client.get_or_init(|| { - TestServiceClient::new(self.svc.clone()) - .max_decoding_message_size(get_decode_max_size()) - }) + pub fn test_svc(&self) -> Box { + self.test_svc_client.clone() } /// Get the underlying health service client - pub fn health_svc(&self) -> &HealthClient { - self.health_svc_client.get_or_init(|| { - HealthClient::new(self.svc.clone()).max_decoding_message_size(get_decode_max_size()) - }) - } - /// Get the underlying workflow service client mutably - pub fn workflow_svc_mut(&mut self) -> &mut WorkflowServiceClient { - let _ = self.workflow_svc(); - self.workflow_svc_client.get_mut().unwrap() - } - /// Get the underlying operator service client mutably - pub fn operator_svc_mut(&mut self) -> &mut OperatorServiceClient { - let _ = self.operator_svc(); - self.operator_svc_client.get_mut().unwrap() - } - /// Get the underlying cloud service client mutably - pub fn cloud_svc_mut(&mut self) -> &mut CloudServiceClient { - let _ = self.cloud_svc(); - self.cloud_svc_client.get_mut().unwrap() - } - /// Get the underlying test service client mutably - pub fn test_svc_mut(&mut self) -> &mut TestServiceClient { - let _ = self.test_svc(); - self.test_svc_client.get_mut().unwrap() - } - /// Get the underlying health service client mutably - pub fn health_svc_mut(&mut self) -> &mut HealthClient { - let _ = self.health_svc(); - self.health_svc_client.get_mut().unwrap() + pub fn health_svc(&self) -> Box { + self.health_svc_client.clone() } } -/// A [WorkflowServiceClient] with the default interceptors attached. -pub type WorkflowServiceClientWithMetrics = WorkflowServiceClient; -/// An [OperatorServiceClient] with the default interceptors attached. -pub type OperatorServiceClientWithMetrics = OperatorServiceClient; -/// An [TestServiceClient] with the default interceptors attached. -pub type TestServiceClientWithMetrics = TestServiceClient; -/// A [TemporalServiceClient] with the default interceptors attached. -pub type TemporalServiceClientWithMetrics = TemporalServiceClient; -type InterceptedMetricsSvc = InterceptedService; - /// Contains an instance of a namespace-bound client for interacting with the Temporal server -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Client { /// Client for interacting with workflow service - inner: ConfiguredClient, + inner: ConfiguredClient, /// The namespace this client interacts with namespace: String, } impl Client { /// Create a new client from an existing configured lower level client and a namespace - pub fn new( - client: ConfiguredClient, - namespace: String, - ) -> Self { + pub fn new(client: ConfiguredClient, namespace: String) -> Self { Client { inner: client, namespace, } } - /// Return an auto-retrying version of the underling grpc client (instrumented with metrics - /// collection, if enabled). - /// - /// Note that it is reasonably cheap to clone the returned type if you need to own it. Such - /// clones will keep re-using the same channel. - pub fn raw_retry_client(&self) -> RetryClient { - RetryClient::new( - self.raw_client().clone(), - self.inner.options.retry_config.clone(), - ) - } - - /// Access the underling grpc client. This raw client is not bound to a specific namespace. - /// - /// Note that it is reasonably cheap to clone the returned type if you need to own it. Such - /// clones will keep re-using the same channel. - pub fn raw_client(&self) -> &WorkflowServiceClientWithMetrics { - self.inner.workflow_svc() - } - /// Return the options this client was initialized with pub fn options(&self) -> &ClientOptions { &self.inner.options @@ -900,12 +865,12 @@ impl Client { } /// Returns a reference to the underlying client - pub fn inner(&self) -> &ConfiguredClient { + pub fn inner(&self) -> &ConfiguredClient { &self.inner } /// Consumes self and returns the underlying client - pub fn into_inner(self) -> ConfiguredClient { + pub fn into_inner(self) -> ConfiguredClient { self.inner } @@ -916,12 +881,12 @@ impl Client { } impl NamespacedClient for Client { - fn namespace(&self) -> &str { - &self.namespace + fn namespace(&self) -> String { + self.namespace.clone() } - fn get_identity(&self) -> &str { - &self.inner.options.identity + fn identity(&self) -> String { + self.inner.options.identity.clone() } } @@ -1244,9 +1209,9 @@ pub trait WorkflowClientTrait: NamespacedClient { /// A client that is bound to a namespace pub trait NamespacedClient { /// Returns the namespace this client is bound to - fn namespace(&self) -> &str; + fn namespace(&self) -> String; /// Returns the client identity - fn get_identity(&self) -> &str; + fn identity(&self) -> String; } /// Optional fields supplied at the start of workflow execution @@ -1386,15 +1351,7 @@ impl From for Priority { #[async_trait::async_trait] impl WorkflowClientTrait for T where - T: RawClientLike + NamespacedClient + Clone + Send + Sync + 'static, - ::SvcType: GrpcService + Send + Clone + 'static, - <::SvcType as GrpcService>::ResponseBody: - tonic::codegen::Body + Send + 'static, - <::SvcType as GrpcService>::Error: - Into, - <::SvcType as GrpcService>::Future: Send, - <<::SvcType as GrpcService>::ResponseBody - as tonic::codegen::Body>::Error: Into + Send, + T: WorkflowService + NamespacedClient + Clone + Send + Sync + 'static, { async fn start_workflow( &self, @@ -1407,35 +1364,38 @@ where ) -> Result { Ok(self .clone() - .start_workflow_execution(StartWorkflowExecutionRequest { - namespace: self.namespace().to_owned(), - input: input.into_payloads(), - workflow_id, - workflow_type: Some(WorkflowType { - name: workflow_type, - }), - task_queue: Some(TaskQueue { - name: task_queue, - kind: TaskQueueKind::Unspecified as i32, - normal_name: "".to_string(), - }), - request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), - workflow_id_reuse_policy: options.id_reuse_policy as i32, - workflow_id_conflict_policy: options.id_conflict_policy as i32, - workflow_execution_timeout: options - .execution_timeout - .and_then(|d| d.try_into().ok()), - workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()), - workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()), - search_attributes: options.search_attributes.map(|d| d.into()), - cron_schedule: options.cron_schedule.unwrap_or_default(), - request_eager_execution: options.enable_eager_workflow_start, - retry_policy: options.retry_policy, - links: options.links, - completion_callbacks: options.completion_callbacks, - priority: options.priority.map(Into::into), - ..Default::default() - }) + .start_workflow_execution( + StartWorkflowExecutionRequest { + namespace: self.namespace(), + input: input.into_payloads(), + workflow_id, + workflow_type: Some(WorkflowType { + name: workflow_type, + }), + task_queue: Some(TaskQueue { + name: task_queue, + kind: TaskQueueKind::Unspecified as i32, + normal_name: "".to_string(), + }), + request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), + workflow_id_reuse_policy: options.id_reuse_policy as i32, + workflow_id_conflict_policy: options.id_conflict_policy as i32, + workflow_execution_timeout: options + .execution_timeout + .and_then(|d| d.try_into().ok()), + workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()), + workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()), + search_attributes: options.search_attributes.map(|d| d.into()), + cron_schedule: options.cron_schedule.unwrap_or_default(), + request_eager_execution: options.enable_eager_workflow_start, + retry_policy: options.retry_policy, + links: options.links, + completion_callbacks: options.completion_callbacks, + priority: options.priority.map(Into::into), + ..Default::default() + } + .into_request(), + ) .await? .into_inner()) } @@ -1446,14 +1406,14 @@ where run_id: String, ) -> Result { let request = ResetStickyTaskQueueRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), execution: Some(WorkflowExecution { workflow_id, run_id, }), }; Ok( - WorkflowService::reset_sticky_task_queue(&mut self.clone(), request) + WorkflowService::reset_sticky_task_queue(&mut self.clone(), request.into_request()) .await? .into_inner(), ) @@ -1464,17 +1424,20 @@ where task_token: TaskToken, result: Option, ) -> Result { - Ok(self.clone().respond_activity_task_completed( - RespondActivityTaskCompletedRequest { - task_token: task_token.0, - result, - identity: self.get_identity().to_owned(), - namespace: self.namespace().to_owned(), - ..Default::default() - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .respond_activity_task_completed( + RespondActivityTaskCompletedRequest { + task_token: task_token.0, + result, + identity: self.identity(), + namespace: self.namespace(), + ..Default::default() + } + .into_request(), + ) + .await? + .into_inner()) } async fn record_activity_heartbeat( @@ -1482,16 +1445,19 @@ where task_token: TaskToken, details: Option, ) -> Result { - Ok(self.clone().record_activity_task_heartbeat( - RecordActivityTaskHeartbeatRequest { - task_token: task_token.0, - details, - identity: self.get_identity().to_owned(), - namespace: self.namespace().to_owned(), - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .record_activity_task_heartbeat( + RecordActivityTaskHeartbeatRequest { + task_token: task_token.0, + details, + identity: self.identity(), + namespace: self.namespace(), + } + .into_request(), + ) + .await? + .into_inner()) } async fn cancel_activity_task( @@ -1499,17 +1465,20 @@ where task_token: TaskToken, details: Option, ) -> Result { - Ok(self.clone().respond_activity_task_canceled( - RespondActivityTaskCanceledRequest { - task_token: task_token.0, - details, - identity: self.get_identity().to_owned(), - namespace: self.namespace().to_owned(), - ..Default::default() - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .respond_activity_task_canceled( + RespondActivityTaskCanceledRequest { + task_token: task_token.0, + details, + identity: self.identity(), + namespace: self.namespace(), + ..Default::default() + } + .into_request(), + ) + .await? + .into_inner()) } async fn signal_workflow_execution( @@ -1520,19 +1489,21 @@ where payloads: Option, request_id: Option, ) -> Result { - Ok(WorkflowService::signal_workflow_execution(&mut self.clone(), + Ok(WorkflowService::signal_workflow_execution( + &mut self.clone(), SignalWorkflowExecutionRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), workflow_execution: Some(WorkflowExecution { workflow_id, run_id, }), signal_name, input: payloads, - identity: self.get_identity().to_owned(), + identity: self.identity(), request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), ..Default::default() - }, + } + .into_request(), ) .await? .into_inner()) @@ -1543,9 +1514,10 @@ where options: SignalWithStartOptions, workflow_options: WorkflowOptions, ) -> Result { - Ok(WorkflowService::signal_with_start_workflow_execution(&mut self.clone(), + Ok(WorkflowService::signal_with_start_workflow_execution( + &mut self.clone(), SignalWithStartWorkflowExecutionRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), workflow_id: options.workflow_id, workflow_type: Some(WorkflowType { name: options.workflow_type, @@ -1558,7 +1530,7 @@ where input: options.input, signal_name: options.signal_name, signal_input: options.signal_input, - identity: self.get_identity().to_owned(), + identity: self.identity(), request_id: options .request_id .unwrap_or_else(|| Uuid::new_v4().to_string()), @@ -1575,7 +1547,8 @@ where cron_schedule: workflow_options.cron_schedule.unwrap_or_default(), header: options.signal_header, ..Default::default() - }, + } + .into_request(), ) .await? .into_inner()) @@ -1587,19 +1560,22 @@ where run_id: String, query: WorkflowQuery, ) -> Result { - Ok(self.clone().query_workflow( - QueryWorkflowRequest { - namespace: self.namespace().to_owned(), - execution: Some(WorkflowExecution { - workflow_id, - run_id, - }), - query: Some(query), - query_reject_condition: 1, - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .query_workflow( + QueryWorkflowRequest { + namespace: self.namespace(), + execution: Some(WorkflowExecution { + workflow_id, + run_id, + }), + query: Some(query), + query_reject_condition: 1, + } + .into_request(), + ) + .await? + .into_inner()) } async fn describe_workflow_execution( @@ -1607,14 +1583,16 @@ where workflow_id: String, run_id: Option, ) -> Result { - Ok(WorkflowService::describe_workflow_execution(&mut self.clone(), + Ok(WorkflowService::describe_workflow_execution( + &mut self.clone(), DescribeWorkflowExecutionRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), execution: Some(WorkflowExecution { workflow_id, run_id: run_id.unwrap_or_default(), }), - }, + } + .into_request(), ) .await? .into_inner()) @@ -1626,16 +1604,18 @@ where run_id: Option, page_token: Vec, ) -> Result { - Ok(WorkflowService::get_workflow_execution_history(&mut self.clone(), + Ok(WorkflowService::get_workflow_execution_history( + &mut self.clone(), GetWorkflowExecutionHistoryRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), execution: Some(WorkflowExecution { workflow_id, run_id: run_id.unwrap_or_default(), }), next_page_token: page_token, ..Default::default() - }, + } + .into_request(), ) .await? .into_inner()) @@ -1648,22 +1628,25 @@ where reason: String, request_id: Option, ) -> Result { - Ok(self.clone().request_cancel_workflow_execution( - RequestCancelWorkflowExecutionRequest { - namespace: self.namespace().to_owned(), - workflow_execution: Some(WorkflowExecution { - workflow_id, - run_id: run_id.unwrap_or_default(), - }), - identity: self.get_identity().to_owned(), - request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), - first_execution_run_id: "".to_string(), - reason, - links: vec![], - }, - ) - .await? - .into_inner()) + Ok(self + .clone() + .request_cancel_workflow_execution( + RequestCancelWorkflowExecutionRequest { + namespace: self.namespace(), + workflow_execution: Some(WorkflowExecution { + workflow_id, + run_id: run_id.unwrap_or_default(), + }), + identity: self.identity(), + request_id: request_id.unwrap_or_else(|| Uuid::new_v4().to_string()), + first_execution_run_id: "".to_string(), + reason, + links: vec![], + } + .into_request(), + ) + .await? + .into_inner()) } async fn terminate_workflow_execution( @@ -1671,19 +1654,21 @@ where workflow_id: String, run_id: Option, ) -> Result { - Ok(WorkflowService::terminate_workflow_execution(&mut self.clone(), + Ok(WorkflowService::terminate_workflow_execution( + &mut self.clone(), TerminateWorkflowExecutionRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), workflow_execution: Some(WorkflowExecution { workflow_id, run_id: run_id.unwrap_or_default(), }), reason: "".to_string(), details: None, - identity: self.get_identity().to_owned(), + identity: self.identity(), first_execution_run_id: "".to_string(), links: vec![], - }, + } + .into_request(), ) .await? .into_inner()) @@ -1695,23 +1680,25 @@ where ) -> Result { let req = Into::::into(options); Ok( - WorkflowService::register_namespace(&mut self.clone(),req) + WorkflowService::register_namespace(&mut self.clone(), req.into_request()) .await? .into_inner(), ) } async fn list_namespaces(&self) -> Result { - Ok(WorkflowService::list_namespaces(&mut self.clone(), - ListNamespacesRequest::default(), + Ok(WorkflowService::list_namespaces( + &mut self.clone(), + ListNamespacesRequest::default().into_request(), ) .await? .into_inner()) } async fn describe_namespace(&self, namespace: Namespace) -> Result { - Ok(WorkflowService::describe_namespace(&mut self.clone(), - namespace.into_describe_namespace_request(), + Ok(WorkflowService::describe_namespace( + &mut self.clone(), + namespace.into_describe_namespace_request().into_request(), ) .await? .into_inner()) @@ -1724,14 +1711,16 @@ where start_time_filter: Option, filters: Option, ) -> Result { - Ok(WorkflowService::list_open_workflow_executions(&mut self.clone(), + Ok(WorkflowService::list_open_workflow_executions( + &mut self.clone(), ListOpenWorkflowExecutionsRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), maximum_page_size, next_page_token, start_time_filter, filters, - }, + } + .into_request(), ) .await? .into_inner()) @@ -1744,14 +1733,16 @@ where start_time_filter: Option, filters: Option, ) -> Result { - Ok(WorkflowService::list_closed_workflow_executions(&mut self.clone(), + Ok(WorkflowService::list_closed_workflow_executions( + &mut self.clone(), ListClosedWorkflowExecutionsRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), maximum_page_size, next_page_token, start_time_filter, filters, - }, + } + .into_request(), ) .await? .into_inner()) @@ -1763,13 +1754,15 @@ where next_page_token: Vec, query: String, ) -> Result { - Ok(WorkflowService::list_workflow_executions(&mut self.clone(), + Ok(WorkflowService::list_workflow_executions( + &mut self.clone(), ListWorkflowExecutionsRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), page_size, next_page_token, query, - }, + } + .into_request(), ) .await? .into_inner()) @@ -1781,21 +1774,24 @@ where next_page_token: Vec, query: String, ) -> Result { - Ok(WorkflowService::list_archived_workflow_executions(&mut self.clone(), + Ok(WorkflowService::list_archived_workflow_executions( + &mut self.clone(), ListArchivedWorkflowExecutionsRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), page_size, next_page_token, query, - }, + } + .into_request(), ) .await? .into_inner()) } async fn get_search_attributes(&self) -> Result { - Ok(WorkflowService::get_search_attributes(&mut self.clone(), - GetSearchAttributesRequest {}, + Ok(WorkflowService::get_search_attributes( + &mut self.clone(), + GetSearchAttributesRequest {}.into_request(), ) .await? .into_inner()) @@ -1809,9 +1805,10 @@ where wait_policy: update::v1::WaitPolicy, args: Option, ) -> Result { - Ok(WorkflowService::update_workflow_execution(&mut self.clone(), + Ok(WorkflowService::update_workflow_execution( + &mut self.clone(), UpdateWorkflowExecutionRequest { - namespace: self.namespace().to_owned(), + namespace: self.namespace(), workflow_execution: Some(WorkflowExecution { workflow_id, run_id, @@ -1820,7 +1817,7 @@ where request: Some(update::v1::Request { meta: Some(update::v1::Meta { update_id: "".into(), - identity: self.get_identity().to_owned(), + identity: self.identity(), }), input: Some(update::v1::Input { header: None, @@ -1829,7 +1826,8 @@ where }), }), ..Default::default() - }, + } + .into_request(), ) .await? .into_inner()) @@ -1837,17 +1835,9 @@ where } mod sealed { - use crate::{InterceptedMetricsSvc, RawClientLike, WorkflowClientTrait}; - - pub trait WfHandleClient: - WorkflowClientTrait + RawClientLike - { - } - - impl WfHandleClient for T where - T: WorkflowClientTrait + RawClientLike - { - } + use crate::{WorkflowClientTrait, WorkflowService}; + pub trait WfHandleClient: WorkflowClientTrait + WorkflowService {} + impl WfHandleClient for T where T: WorkflowClientTrait + WorkflowService {} } /// Additional methods for workflow clients @@ -1863,7 +1853,7 @@ pub trait WfClientExt: WfHandleClient + Sized + Clone { UntypedWorkflowHandle::new( self.clone(), WorkflowExecutionInfo { - namespace: self.namespace().to_string(), + namespace: self.namespace(), workflow_id: workflow_id.into(), run_id: if rid.is_empty() { None } else { Some(rid) }, }, diff --git a/client/src/metrics.rs b/client/src/metrics.rs index d2dbbf279..aeceb4a67 100644 --- a/client/src/metrics.rs +++ b/client/src/metrics.rs @@ -208,7 +208,7 @@ fn code_as_screaming_snake(code: &Code) -> &'static str { /// Implements metrics functionality for gRPC (really, any http) calls #[derive(Debug, Clone)] -pub struct GrpcMetricSvc { +pub(crate) struct GrpcMetricSvc { pub(crate) inner: ChannelOrGrpcOverride, // If set to none, metrics are a no-op pub(crate) metrics: Option, @@ -230,6 +230,7 @@ impl fmt::Debug for ChannelOrGrpcOverride { } } +// TODO: Rewrite as a RawGrpcCaller implementation impl Service> for GrpcMetricSvc { type Response = http::Response; type Error = Box; diff --git a/client/src/raw.rs b/client/src/raw.rs index 26d08203d..b59dac696 100644 --- a/client/src/raw.rs +++ b/client/src/raw.rs @@ -3,14 +3,14 @@ //! happen. use crate::{ - Client, ConfiguredClient, InterceptedMetricsSvc, LONG_POLL_TIMEOUT, RequestExt, RetryClient, + Client, ConfiguredClient, LONG_POLL_TIMEOUT, RequestExt, RetryClient, SharedReplaceableClient, TEMPORAL_NAMESPACE_HEADER_KEY, TemporalServiceClient, metrics::{namespace_kv, task_queue_kv}, - raw::sealed::RawClientLike, worker_registry::{ClientWorkerSet, Slot}, }; +use dyn_clone::DynClone; use futures_util::{FutureExt, TryFutureExt, future::BoxFuture}; -use std::sync::Arc; +use std::{any::Any, marker::PhantomData, sync::Arc}; use temporal_sdk_core_api::telemetry::metrics::MetricKeyValue; use temporal_sdk_core_protos::{ grpc::health::v1::{health_client::HealthClient, *}, @@ -29,115 +29,183 @@ use tonic::{ metadata::{AsciiMetadataValue, KeyAndValueRef}, }; -pub(super) mod sealed { - use super::*; - - /// Something that has access to the raw grpc services - #[async_trait::async_trait] - pub trait RawClientLike: Send { - type SvcType: Send + Sync + Clone + 'static; - - /// Return a ref to the workflow service client instance - fn workflow_client(&self) -> &WorkflowServiceClient; - - /// Return a mutable ref to the workflow service client instance - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient; +/// Something that has access to the raw grpc services +trait RawClientProducer { + /// Returns information about workers associated with this client. Implementers outside of + /// core can safely return `None`. + fn get_workers_info(&self) -> Option>; - /// Return a ref to the operator service client instance - fn operator_client(&self) -> &OperatorServiceClient; + /// Return a workflow service client instance + fn workflow_client(&mut self) -> Box; - /// Return a mutable ref to the operator service client instance - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient; + /// Return a mutable ref to the operator service client instance + fn operator_client(&mut self) -> Box; - /// Return a ref to the cloud service client instance - fn cloud_client(&self) -> &CloudServiceClient; + /// Return a mutable ref to the cloud service client instance + fn cloud_client(&mut self) -> Box; - /// Return a mutable ref to the cloud service client instance - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient; + /// Return a mutable ref to the test service client instance + fn test_client(&mut self) -> Box; - /// Return a ref to the test service client instance - fn test_client(&self) -> &TestServiceClient; + /// Return a mutable ref to the health service client instance + fn health_client(&mut self) -> Box; +} - /// Return a mutable ref to the test service client instance - fn test_client_mut(&mut self) -> &mut TestServiceClient; +/// Any client that can make gRPC calls. The default implementation simply invokes the passed-in +/// function. Implementers may override this to provide things like retry behavior, ex: +/// [RetryClient]. +#[async_trait::async_trait] +trait RawGrpcCaller: Send + Sync + 'static { + async fn call( + &mut self, + _call_name: &'static str, + mut callfn: F, + req: Request, + ) -> Result, Status> + where + Req: Clone + Unpin + Send + Sync + 'static, + Resp: Send + 'static, + F: Send + Sync + Unpin + 'static, + for<'a> F: + FnMut(&'a mut Self, Request) -> BoxFuture<'static, Result, Status>>, + { + callfn(self, req).await + } +} - /// Return a ref to the health service client instance - fn health_client(&self) -> &HealthClient; +trait ErasedRawClient: Send + Sync + 'static { + fn erased_call( + &mut self, + call_name: &'static str, + op: &mut dyn ErasedCallOp, + ) -> BoxFuture<'static, Result>, Status>>; +} - /// Return a mutable ref to the health service client instance - fn health_client_mut(&mut self) -> &mut HealthClient; +trait ErasedCallOp: Send { + fn invoke( + &mut self, + raw: &mut dyn ErasedRawClient, + call_name: &'static str, + ) -> BoxFuture<'static, Result>, Status>>; +} - /// Return a registry with workers using this client instance - fn get_workers_info(&self) -> Option>; +struct CallShim { + callfn: F, + seed_req: Option>, + _resp: PhantomData, +} - async fn call( - &mut self, - _call_name: &'static str, - mut callfn: F, - req: Request, - ) -> Result, Status> - where - Req: Clone + Unpin + Send + Sync + 'static, - F: FnMut(&mut Self, Request) -> BoxFuture<'static, Result, Status>>, - F: Send + Sync + Unpin + 'static, - { - callfn(self, req).await +impl CallShim { + fn new(callfn: F, seed_req: Request) -> Self { + Self { + callfn, + seed_req: Some(seed_req), + _resp: PhantomData, } } } -#[async_trait::async_trait] -impl RawClientLike for RetryClient +impl ErasedCallOp for CallShim where - RC: RawClientLike + 'static, - T: Send + Sync + Clone + 'static, + Req: Clone + Unpin + Send + Sync + 'static, + Resp: Send + 'static, + F: Send + Sync + Unpin + 'static, + for<'a> F: FnMut( + &'a mut dyn ErasedRawClient, + Request, + ) -> BoxFuture<'static, Result, Status>>, { - type SvcType = T; - - fn workflow_client(&self) -> &WorkflowServiceClient { - self.get_client().workflow_client() - } - - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.get_client_mut().workflow_client_mut() - } - - fn operator_client(&self) -> &OperatorServiceClient { - self.get_client().operator_client() + fn invoke( + &mut self, + raw: &mut dyn ErasedRawClient, + _call_name: &'static str, + ) -> BoxFuture<'static, Result>, Status>> { + (self.callfn)( + raw, + self.seed_req + .take() + .expect("CallShim must have request populated"), + ) + .map(|res| res.map(|payload| payload.map(|t| Box::new(t) as Box))) + .boxed() } +} - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.get_client_mut().operator_client_mut() +#[async_trait::async_trait] +impl RawGrpcCaller for dyn ErasedRawClient { + async fn call( + &mut self, + call_name: &'static str, + callfn: F, + req: Request, + ) -> Result, Status> + where + Req: Clone + Unpin + Send + Sync + 'static, + Resp: Send + 'static, + F: Send + Sync + Unpin + 'static, + for<'a> F: FnMut( + &'a mut dyn ErasedRawClient, + Request, + ) -> BoxFuture<'static, Result, Status>>, + { + let mut shim = CallShim::new(callfn, req); + let erased_resp = ErasedRawClient::erased_call(self, call_name, &mut shim).await?; + Ok(erased_resp.map(|boxed| { + *boxed + .downcast() + .expect("RawGrpcCaller erased response type mismatch") + })) } +} - fn cloud_client(&self) -> &CloudServiceClient { - self.get_client().cloud_client() +impl ErasedRawClient for T +where + T: RawGrpcCaller + 'static, +{ + fn erased_call( + &mut self, + call_name: &'static str, + op: &mut dyn ErasedCallOp, + ) -> BoxFuture<'static, Result>, Status>> { + let raw: &mut dyn ErasedRawClient = self; + op.invoke(raw, call_name) } +} - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.get_client_mut().cloud_client_mut() +impl RawClientProducer for RetryClient +where + RC: RawClientProducer + 'static, +{ + fn get_workers_info(&self) -> Option> { + self.get_client().get_workers_info() } - fn test_client(&self) -> &TestServiceClient { - self.get_client().test_client() + fn workflow_client(&mut self) -> Box { + self.get_client_mut().workflow_client() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.get_client_mut().test_client_mut() + fn operator_client(&mut self) -> Box { + self.get_client_mut().operator_client() } - fn health_client(&self) -> &HealthClient { - self.get_client().health_client() + fn cloud_client(&mut self) -> Box { + self.get_client_mut().cloud_client() } - fn health_client_mut(&mut self) -> &mut HealthClient { - self.get_client_mut().health_client_mut() + fn test_client(&mut self) -> Box { + self.get_client_mut().test_client() } - fn get_workers_info(&self) -> Option> { - self.get_client().get_workers_info() + fn health_client(&mut self) -> Box { + self.get_client_mut().health_client() } +} +#[async_trait::async_trait] +impl RawGrpcCaller for RetryClient +where + RC: RawGrpcCaller + 'static, +{ async fn call( &mut self, call_name: &'static str, @@ -163,182 +231,142 @@ where } } -impl RawClientLike for TemporalServiceClient +/// Helper for cloning a tonic request as long as the inner message may be cloned. +fn req_cloner(cloneme: &Request) -> Request { + let msg = cloneme.get_ref().clone(); + let mut new_req = Request::new(msg); + let new_met = new_req.metadata_mut(); + for kv in cloneme.metadata().iter() { + match kv { + KeyAndValueRef::Ascii(k, v) => { + new_met.insert(k, v.clone()); + } + KeyAndValueRef::Binary(k, v) => { + new_met.insert_bin(k, v.clone()); + } + } + } + *new_req.extensions_mut() = cloneme.extensions().clone(); + new_req +} + +impl RawClientProducer for SharedReplaceableClient where - T: Send + Sync + Clone + 'static, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - ::Error: Into + Send, + RC: RawClientProducer + Clone + Send + Sync + 'static, { - type SvcType = T; - - fn workflow_client(&self) -> &WorkflowServiceClient { - self.workflow_svc() + fn get_workers_info(&self) -> Option> { + self.inner_cow().get_workers_info() + } + fn workflow_client(&mut self) -> Box { + self.inner_mut_refreshed().workflow_client() } - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.workflow_svc_mut() + fn operator_client(&mut self) -> Box { + self.inner_mut_refreshed().operator_client() } - fn operator_client(&self) -> &OperatorServiceClient { - self.operator_svc() + fn cloud_client(&mut self) -> Box { + self.inner_mut_refreshed().cloud_client() } - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.operator_svc_mut() + fn test_client(&mut self) -> Box { + self.inner_mut_refreshed().test_client() } - fn cloud_client(&self) -> &CloudServiceClient { - self.cloud_svc() + fn health_client(&mut self) -> Box { + self.inner_mut_refreshed().health_client() } +} + +#[async_trait::async_trait] +impl RawGrpcCaller for SharedReplaceableClient where + RC: RawGrpcCaller + Clone + Sync + 'static +{ +} - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.cloud_svc_mut() +impl RawClientProducer for TemporalServiceClient { + fn get_workers_info(&self) -> Option> { + None } - fn test_client(&self) -> &TestServiceClient { - self.test_svc() + fn workflow_client(&mut self) -> Box { + self.workflow_svc() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.test_svc_mut() + fn operator_client(&mut self) -> Box { + self.operator_svc() } - fn health_client(&self) -> &HealthClient { - self.health_svc() + fn cloud_client(&mut self) -> Box { + self.cloud_svc() } - fn health_client_mut(&mut self) -> &mut HealthClient { - self.health_svc_mut() + fn test_client(&mut self) -> Box { + self.test_svc() } - fn get_workers_info(&self) -> Option> { - None + fn health_client(&mut self) -> Box { + self.health_svc() } } -impl RawClientLike for ConfiguredClient> -where - T: Send + Sync + Clone + 'static, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - ::Error: Into + Send, -{ - type SvcType = T; +impl RawGrpcCaller for TemporalServiceClient {} - fn workflow_client(&self) -> &WorkflowServiceClient { - self.client.workflow_client() +impl RawClientProducer for ConfiguredClient { + fn get_workers_info(&self) -> Option> { + Some(self.workers()) } - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.client.workflow_client_mut() + fn workflow_client(&mut self) -> Box { + self.client.workflow_client() } - fn operator_client(&self) -> &OperatorServiceClient { + fn operator_client(&mut self) -> Box { self.client.operator_client() } - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.client.operator_client_mut() - } - - fn cloud_client(&self) -> &CloudServiceClient { + fn cloud_client(&mut self) -> Box { self.client.cloud_client() } - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.client.cloud_client_mut() - } - - fn test_client(&self) -> &TestServiceClient { + fn test_client(&mut self) -> Box { self.client.test_client() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.client.test_client_mut() - } - - fn health_client(&self) -> &HealthClient { + fn health_client(&mut self) -> Box { self.client.health_client() } +} - fn health_client_mut(&mut self) -> &mut HealthClient { - self.client.health_client_mut() - } +impl RawGrpcCaller for ConfiguredClient {} +impl RawClientProducer for Client { fn get_workers_info(&self) -> Option> { - Some(self.workers()) + self.inner.get_workers_info() } -} -impl RawClientLike for Client { - type SvcType = InterceptedMetricsSvc; - - fn workflow_client(&self) -> &WorkflowServiceClient { + fn workflow_client(&mut self) -> Box { self.inner.workflow_client() } - fn workflow_client_mut(&mut self) -> &mut WorkflowServiceClient { - self.inner.workflow_client_mut() - } - - fn operator_client(&self) -> &OperatorServiceClient { + fn operator_client(&mut self) -> Box { self.inner.operator_client() } - fn operator_client_mut(&mut self) -> &mut OperatorServiceClient { - self.inner.operator_client_mut() - } - - fn cloud_client(&self) -> &CloudServiceClient { + fn cloud_client(&mut self) -> Box { self.inner.cloud_client() } - fn cloud_client_mut(&mut self) -> &mut CloudServiceClient { - self.inner.cloud_client_mut() - } - - fn test_client(&self) -> &TestServiceClient { + fn test_client(&mut self) -> Box { self.inner.test_client() } - fn test_client_mut(&mut self) -> &mut TestServiceClient { - self.inner.test_client_mut() - } - - fn health_client(&self) -> &HealthClient { + fn health_client(&mut self) -> Box { self.inner.health_client() } - - fn health_client_mut(&mut self) -> &mut HealthClient { - self.inner.health_client_mut() - } - - fn get_workers_info(&self) -> Option> { - self.inner.get_workers_info() - } } -/// Helper for cloning a tonic request as long as the inner message may be cloned. -fn req_cloner(cloneme: &Request) -> Request { - let msg = cloneme.get_ref().clone(); - let mut new_req = Request::new(msg); - let new_met = new_req.metadata_mut(); - for kv in cloneme.metadata().iter() { - match kv { - KeyAndValueRef::Ascii(k, v) => { - new_met.insert(k, v.clone()); - } - KeyAndValueRef::Binary(k, v) => { - new_met.insert_bin(k, v.clone()); - } - } - } - *new_req.extensions_mut() = cloneme.extensions().clone(); - new_req -} +impl RawGrpcCaller for Client {} #[derive(Clone, Debug)] pub(super) struct AttachMetricLabels { @@ -368,62 +396,27 @@ impl AttachMetricLabels { #[derive(Copy, Clone, Debug)] pub(super) struct IsUserLongPoll; -// Blanket impl the trait for all raw-client-like things. Since the trait default-implements -// everything, there's nothing to actually implement. -impl WorkflowService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ -} -impl OperatorService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ -} -impl CloudService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ -} -impl TestService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ -} -impl HealthService for RC -where - RC: RawClientLike, - T: GrpcService + Send + Clone + 'static, - T::ResponseBody: tonic::codegen::Body + Send + 'static, - T::Error: Into, - T::Future: Send, - ::Error: Into + Send, -{ +macro_rules! proxy_def { + ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty, defaults) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + _request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>> { + async { Ok(tonic::Response::new(<$resp>::default())) }.boxed() + } + }; + ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + _request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>>; + }; } - /// Helps re-declare gRPC client methods /// -/// There are two forms: +/// There are four forms: /// /// * The first takes a closure that can modify the request. This is only called once, before the /// actual rpc call is made, and before determinations are made about the kind of call (long poll @@ -431,22 +424,23 @@ where /// * The second takes three closures. The first can modify the request like in the first form. /// The second can modify the request and return a value, and is called right before every call /// (including on retries). The third is called with the response to the call after it resolves. -macro_rules! proxy { +/// * The third and fourth are equivalents of the above that skip calling through the `call` method +/// and are implemented directly on the generated gRPC clients (IE: the bottom of the stack). +macro_rules! proxy_impl { ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty $(, $closure:expr)?) => { #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] fn $method( &mut self, - request: impl tonic::IntoRequest<$req>, - ) -> BoxFuture<'_, Result, tonic::Status>> { #[allow(unused_mut)] - let mut as_req = request.into_request(); - $( type_closure_arg(&mut as_req, $closure); )* + mut request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>> { + $( type_closure_arg(&mut request, $closure); )* #[allow(unused_mut)] let fact = |c: &mut Self, mut req: tonic::Request<$req>| { - let mut c = c.$client_meth().clone(); + let mut c = c.$client_meth(); async move { c.$method(req).await }.boxed() }; - self.call(stringify!($method), fact, as_req) + self.call(stringify!($method), fact, request) } }; ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty, @@ -454,52 +448,108 @@ macro_rules! proxy { #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] fn $method( &mut self, - request: impl tonic::IntoRequest<$req>, + mut request: tonic::Request<$req>, ) -> BoxFuture<'_, Result, tonic::Status>> { - #[allow(unused_mut)] - let mut as_req = request.into_request(); - type_closure_arg(&mut as_req, $closure_request); + type_closure_arg(&mut request, $closure_request); #[allow(unused_mut)] let fact = |c: &mut Self, mut req: tonic::Request<$req>| { - let data = type_closure_two_arg(&mut req, c.get_workers_info().unwrap(), - $closure_before); - let mut c = c.$client_meth().clone(); + let data = type_closure_two_arg(&mut req, c.get_workers_info(), $closure_before); + let mut c = c.$client_meth(); async move { type_closure_two_arg(c.$method(req).await, data, $closure_after) }.boxed() }; - self.call(stringify!($method), fact, as_req) + self.call(stringify!($method), fact, request) + } + }; + ($client_type:tt, $method:ident, $req:ty, $resp:ty $(, $closure:expr)?) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + #[allow(unused_mut)] + mut request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>> { + $( type_closure_arg(&mut request, $closure); )* + async move { <$client_type<_>>::$method(self, request).await }.boxed() + } + }; + ($client_type:tt, $method:ident, $req:ty, $resp:ty, + $closure_request:expr, $closure_before:expr, $closure_after:expr) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + mut request: tonic::Request<$req>, + ) -> BoxFuture<'_, Result, tonic::Status>> { + type_closure_arg(&mut request, $closure_request); + let data = type_closure_two_arg(&mut request, Option::>::None, + $closure_before); + async move { + type_closure_two_arg(<$client_type<_>>::$method(self, request).await, + data, $closure_after) + }.boxed() } }; } -macro_rules! proxier { - ( $trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; - $(($method:ident, $req:ty, $resp:ty - $(, $closure:expr $(, $closure_before:expr, $closure_after:expr)?)? );)* ) => { +macro_rules! proxier_impl { + ($trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; + [$( proxy_def!($($def_args:tt)*); )*]; + $(($method:ident, $req:ty, $resp:ty + $(, $closure:expr $(, $closure_before:expr, $closure_after:expr)?)? );)* ) => { #[cfg(test)] const $impl_list_name: &'static [&'static str] = &[$(stringify!($method)),*]; - /// Trait version of the generated client with modifications to attach appropriate metric - /// labels or whatever else to requests - pub trait $trait_name: RawClientLike + + #[doc = concat!("Trait version of [", stringify!($client_type), "]")] + pub trait $trait_name: Send + Sync + DynClone + { + $( proxy_def!($($def_args)*); )* + } + dyn_clone::clone_trait_object!($trait_name); + + impl $trait_name for RC + where + RC: RawGrpcCaller + RawClientProducer + Clone, + { + $( + proxy_impl!($client_type, $client_meth, $method, $req, $resp + $(,$closure $(,$closure_before, $closure_after)*)*); + )* + } + + impl RawGrpcCaller for $client_type {} + + impl $trait_name for $client_type where - // Yo this is wild - ::SvcType: GrpcService + Send + Clone + 'static, - <::SvcType as GrpcService>::ResponseBody: - tonic::codegen::Body + Send + 'static, - <::SvcType as GrpcService>::Error: - Into, - <::SvcType as GrpcService>::Future: Send, - <<::SvcType as GrpcService>::ResponseBody - as tonic::codegen::Body>::Error: Into + Send, + T: GrpcService + Clone + Send + Sync + 'static, + T::ResponseBody: tonic::codegen::Body + Send + 'static, + T::Error: Into, + ::Error: Into + Send, + >::Future: Send { $( - proxy!($client_type, $client_meth, $method, $req, $resp - $(,$closure $(,$closure_before, $closure_after)*)*); + proxy_impl!($client_type, $method, $req, $resp + $(,$closure $(,$closure_before, $closure_after)*)*); )* } }; } +macro_rules! proxier { + ( $trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; + $(($method:ident, $req:ty, $resp:ty + $(, $closure:expr $(, $closure_before:expr, $closure_after:expr)?)? );)* ) => { + proxier_impl!($trait_name; $impl_list_name; $client_type; $client_meth; + [$(proxy_def!($client_type, $client_meth, $method, $req, $resp);)*]; + $(($method, $req, $resp $(, $closure $(, $closure_before, $closure_after)?)?);)*); + }; + ( $trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; defaults; + $(($method:ident, $req:ty, $resp:ty + $(, $closure:expr $(, $closure_before:expr, $closure_after:expr)?)? );)* ) => { + proxier_impl!($trait_name; $impl_list_name; $client_type; $client_meth; + [$(proxy_def!($client_type, $client_meth, $method, $req, $resp, defaults);)*]; + $(($method, $req, $resp $(, $closure $(, $closure_before, $closure_after)?)?);)*); + }; +} + macro_rules! namespaced_request { ($req:ident) => {{ let ns_str = $req.get_ref().namespace.clone(); @@ -526,7 +576,7 @@ fn type_closure_two_arg(arg1: R, arg2: T, f: impl FnOnce(R, T) -> S) -> } proxier! { - WorkflowService; ALL_IMPLEMENTED_WORKFLOW_SERVICE_RPCS; WorkflowServiceClient; workflow_client_mut; + WorkflowService; ALL_IMPLEMENTED_WORKFLOW_SERVICE_RPCS; WorkflowServiceClient; workflow_client; defaults; ( register_namespace, RegisterNamespaceRequest, @@ -578,21 +628,25 @@ proxier! { r.extensions_mut().insert(labels); }, |r, workers| { - let mut slot: Option> = None; - let req_mut = r.get_mut(); - if req_mut.request_eager_execution { - let namespace = req_mut.namespace.clone(); - let task_queue = req_mut.task_queue.as_ref() - .map(|tq| tq.name.clone()).unwrap_or_default(); - match workers.try_reserve_wft_slot(namespace, task_queue) { - Some(s) => slot = Some(s), - None => req_mut.request_eager_execution = false + if let Some(workers) = workers { + let mut slot: Option> = None; + let req_mut = r.get_mut(); + if req_mut.request_eager_execution { + let namespace = req_mut.namespace.clone(); + let task_queue = req_mut.task_queue.as_ref() + .map(|tq| tq.name.clone()).unwrap_or_default(); + match workers.try_reserve_wft_slot(namespace, task_queue) { + Some(s) => slot = Some(s), + None => req_mut.request_eager_execution = false + } } + slot + } else { + None } - slot }, |resp, slot| { - if let Some(mut s) = slot + if let Some(s) = slot && let Ok(response) = resp.as_ref() && let Some(task) = response.get_ref().clone().eager_workflow_task && let Err(e) = s.schedule_wft(task) { @@ -1403,7 +1457,7 @@ proxier! { } proxier! { - OperatorService; ALL_IMPLEMENTED_OPERATOR_SERVICE_RPCS; OperatorServiceClient; operator_client_mut; + OperatorService; ALL_IMPLEMENTED_OPERATOR_SERVICE_RPCS; OperatorServiceClient; operator_client; defaults; (add_search_attributes, AddSearchAttributesRequest, AddSearchAttributesResponse); (remove_search_attributes, RemoveSearchAttributesRequest, RemoveSearchAttributesResponse); (list_search_attributes, ListSearchAttributesRequest, ListSearchAttributesResponse); @@ -1424,7 +1478,7 @@ proxier! { } proxier! { - CloudService; ALL_IMPLEMENTED_CLOUD_SERVICE_RPCS; CloudServiceClient; cloud_client_mut; + CloudService; ALL_IMPLEMENTED_CLOUD_SERVICE_RPCS; CloudServiceClient; cloud_client; defaults; (get_users, cloudreq::GetUsersRequest, cloudreq::GetUsersResponse); (get_user, cloudreq::GetUserRequest, cloudreq::GetUserResponse); (create_user, cloudreq::CreateUserRequest, cloudreq::CreateUserResponse); @@ -1499,7 +1553,7 @@ proxier! { } proxier! { - TestService; ALL_IMPLEMENTED_TEST_SERVICE_RPCS; TestServiceClient; test_client_mut; + TestService; ALL_IMPLEMENTED_TEST_SERVICE_RPCS; TestServiceClient; test_client; defaults; (lock_time_skipping, LockTimeSkippingRequest, LockTimeSkippingResponse); (unlock_time_skipping, UnlockTimeSkippingRequest, UnlockTimeSkippingResponse); (sleep, SleepRequest, SleepResponse); @@ -1509,7 +1563,7 @@ proxier! { } proxier! { - HealthService; ALL_IMPLEMENTED_HEALTH_SERVICE_RPCS; HealthClient; health_client_mut; + HealthService; ALL_IMPLEMENTED_HEALTH_SERVICE_RPCS; HealthClient; health_client; (check, HealthCheckRequest, HealthCheckResponse); (watch, HealthCheckRequest, tonic::codec::Streaming); } @@ -1522,6 +1576,7 @@ mod tests { use temporal_sdk_core_protos::temporal::api::{ operatorservice::v1::DeleteNamespaceRequest, workflowservice::v1::ListNamespacesRequest, }; + use tonic::IntoRequest; // Just to help make sure some stuff compiles. Not run. #[allow(dead_code)] @@ -1532,7 +1587,7 @@ mod tests { let list_ns_req = ListNamespacesRequest::default(); let fact = |c: &mut RetryClient<_>, req| { - let mut c = c.workflow_client_mut().clone(); + let mut c = c.workflow_client(); async move { c.list_namespaces(req).await }.boxed() }; retry_client @@ -1543,7 +1598,7 @@ mod tests { // Operator svc method let op_del_ns_req = DeleteNamespaceRequest::default(); let fact = |c: &mut RetryClient<_>, req| { - let mut c = c.operator_client_mut().clone(); + let mut c = c.operator_client(); async move { c.delete_namespace(req).await }.boxed() }; retry_client @@ -1554,7 +1609,7 @@ mod tests { // Cloud svc method let cloud_del_ns_req = cloudreq::DeleteNamespaceRequest::default(); let fact = |c: &mut RetryClient<_>, req| { - let mut c = c.cloud_client_mut().clone(); + let mut c = c.cloud_client(); async move { c.delete_namespace(req).await }.boxed() }; retry_client @@ -1563,17 +1618,23 @@ mod tests { .unwrap(); // Verify calling through traits works - retry_client.list_namespaces(list_ns_req).await.unwrap(); + retry_client + .list_namespaces(list_ns_req.into_request()) + .await + .unwrap(); // Have to disambiguate operator and cloud service - OperatorService::delete_namespace(&mut retry_client, op_del_ns_req) + OperatorService::delete_namespace(&mut retry_client, op_del_ns_req.into_request()) .await .unwrap(); - CloudService::delete_namespace(&mut retry_client, cloud_del_ns_req) + CloudService::delete_namespace(&mut retry_client, cloud_del_ns_req.into_request()) .await .unwrap(); - retry_client.get_current_time(()).await.unwrap(); retry_client - .check(HealthCheckRequest::default()) + .get_current_time(().into_request()) + .await + .unwrap(); + retry_client + .check(HealthCheckRequest::default().into_request()) .await .unwrap(); } @@ -1639,4 +1700,65 @@ mod tests { let proto_def = include_str!("../../sdk-core-protos/protos/grpc/health/v1/health.proto"); verify_methods(proto_def, ALL_IMPLEMENTED_HEALTH_SERVICE_RPCS); } + + #[tokio::test] + async fn can_mock_workflow_service() { + #[derive(Clone)] + struct MyFakeServices {} + impl RawGrpcCaller for MyFakeServices {} + impl WorkflowService for MyFakeServices { + fn list_namespaces( + &mut self, + _request: Request, + ) -> BoxFuture<'_, Result, Status>> { + async { + Ok(Response::new(ListNamespacesResponse { + namespaces: vec![DescribeNamespaceResponse { + failover_version: 12345, + ..Default::default() + }], + ..Default::default() + })) + } + .boxed() + } + } + impl OperatorService for MyFakeServices {} + impl CloudService for MyFakeServices {} + impl TestService for MyFakeServices {} + // Health service isn't possible to create a default impl for. + impl HealthService for MyFakeServices { + fn check( + &mut self, + _request: tonic::Request, + ) -> BoxFuture<'_, Result, tonic::Status>> + { + todo!() + } + fn watch( + &mut self, + _request: tonic::Request, + ) -> BoxFuture< + '_, + Result< + tonic::Response>, + tonic::Status, + >, + > { + todo!() + } + } + let mut mocked_client = TemporalServiceClient::from_services( + Box::new(MyFakeServices {}), + Box::new(MyFakeServices {}), + Box::new(MyFakeServices {}), + Box::new(MyFakeServices {}), + Box::new(MyFakeServices {}), + ); + let r = mocked_client + .list_namespaces(ListNamespacesRequest::default().into_request()) + .await + .unwrap(); + assert_eq!(r.into_inner().namespaces[0].failover_version, 12345); + } } diff --git a/client/src/replaceable.rs b/client/src/replaceable.rs new file mode 100644 index 000000000..ada6b2b23 --- /dev/null +++ b/client/src/replaceable.rs @@ -0,0 +1,253 @@ +use crate::NamespacedClient; +use std::{ + borrow::Cow, + sync::{ + Arc, RwLock, + atomic::{AtomicU32, Ordering}, + }, +}; + +/// A client wrapper that allows replacing the underlying client at a later point in time. +/// Clones of this struct have a shared reference to the underlying client, and each clone also +/// has its own cached clone of the underlying client. Before every service call, a check is made +/// whether the shared client was replaced, and the cached clone is updated accordingly. +/// +/// This struct is fully thread-safe, and it works in a lock-free manner except when the client is +/// being replaced. A read-write lock is used then, with minimal locking time. +#[derive(Debug)] +pub struct SharedReplaceableClient +where + C: Clone + Send + Sync, +{ + shared_data: Arc>, + cloned_client: C, + cloned_generation: u32, +} + +#[derive(Debug)] +struct SharedClientData +where + C: Clone + Send + Sync, +{ + client: RwLock, + generation: AtomicU32, +} + +impl SharedClientData +where + C: Clone + Send + Sync, +{ + fn fetch(&self) -> (C, u32) { + let lock = self.client.read().unwrap(); + let client = lock.clone(); + // Loading generation under lock to ensure the client won't be updated in the meantime. + let generation = self.generation.load(Ordering::Acquire); + (client, generation) + } + + fn fetch_newer_than(&self, current_generation: u32) -> Option<(C, u32)> { + // fetch() will do a second atomic load, but it's necessary to avoid a race condition. + (current_generation != self.generation.load(Ordering::Acquire)).then(|| self.fetch()) + } + + fn replace_client(&self, client: C) { + let mut lock = self.client.write().unwrap(); + *lock = client; + // Updating generation under lock to guarantee consistency when multiple threads replace the + // client at the same time. The client stored last is always the one with latest generation. + self.generation.fetch_add(1, Ordering::AcqRel); + } +} + +impl SharedReplaceableClient +where + C: Clone + Send + Sync, +{ + /// Creates the initial instance of replaceable client with the provided underlying client. + /// Use [`clone()`](Self::clone) method to create more instances that share the same underlying client. + pub fn new(client: C) -> Self { + let cloned_client = client.clone(); + Self { + shared_data: Arc::new(SharedClientData { + client: RwLock::new(client), + generation: AtomicU32::new(0), + }), + cloned_client, + cloned_generation: 0, + } + } + + /// Replaces the client for all instances that share this instance's underlying client. + pub fn replace_client(&self, new_client: C) { + self.shared_data.replace_client(new_client); // cloned_client will be updated on next mutable call + } + + /// Returns a clone of the underlying client. + pub fn inner_clone(&self) -> C { + self.inner_cow().into_owned() + } + + /// Returns an immutable reference to this instance's cached clone of the underlying client if + /// it's up to date, or a fresh clone of the shared client otherwise. Because it's an immutable + /// method, it will not update this instance's cached clone. For this reason, prefer to use + /// [`inner_mut_refreshed()`](Self::inner_mut_refreshed) when possible. + pub fn inner_cow(&self) -> Cow<'_, C> { + self.shared_data + .fetch_newer_than(self.cloned_generation) + .map(|(c, _)| Cow::Owned(c)) + .unwrap_or_else(|| Cow::Borrowed(&self.cloned_client)) + } + + /// Returns a mutable reference to this instance's cached clone of the underlying client. If the + /// cached clone is not up to date, it's refreshed before the reference is returned. This method + /// is called automatically by most other mutable methods, in particular by all service calls, + /// so most of the time it doesn't need to be called directly. + /// + /// While this method allows mutable access to the underlying client, any configuration changes + /// will not be shared with other instances, and will be lost if the client gets replaced from + /// anywhere. To make configuration changes, use [`replace_client()`](Self::replace_client) instead. + pub fn inner_mut_refreshed(&mut self) -> &mut C { + if let Some((client, generation)) = + self.shared_data.fetch_newer_than(self.cloned_generation) + { + self.cloned_client = client; + self.cloned_generation = generation; + } + &mut self.cloned_client + } +} + +impl Clone for SharedReplaceableClient +where + C: Clone + Send + Sync, +{ + /// Creates a new instance of replaceable client that shares the underlying client with this + /// instance. Replacing a client in either instance will replace it for both instances, and all + /// other clones too. + fn clone(&self) -> Self { + // self's cloned_client could've been modified through a mutable reference, + // so for consistent behavior, we need to fetch it from shared_data. + let (client, generation) = self.shared_data.fetch(); + Self { + shared_data: self.shared_data.clone(), + cloned_client: client, + cloned_generation: generation, + } + } +} + +impl NamespacedClient for SharedReplaceableClient +where + C: NamespacedClient + Clone + Send + Sync, +{ + fn namespace(&self) -> String { + self.inner_cow().namespace() + } + + fn identity(&self) -> String { + self.inner_cow().identity() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::NamespacedClient; + use std::borrow::Cow; + + #[derive(Debug, Clone)] + struct StubClient { + identity: String, + } + + impl StubClient { + fn new(identity: &str) -> Self { + Self { + identity: identity.to_owned(), + } + } + } + + impl NamespacedClient for StubClient { + fn namespace(&self) -> String { + "default".into() + } + + fn identity(&self) -> String { + self.identity.clone() + } + } + + #[test] + fn cow_returns_reference_before_and_clone_after_refresh() { + let mut client = SharedReplaceableClient::new(StubClient::new("1")); + let Cow::Borrowed(inner) = client.inner_cow() else { + panic!("expected borrowed inner"); + }; + assert_eq!(inner.identity, "1"); + + client.replace_client(StubClient::new("2")); + let Cow::Owned(inner) = client.inner_cow() else { + panic!("expected owned inner"); + }; + assert_eq!(inner.identity, "2"); + + assert_eq!(client.inner_mut_refreshed().identity, "2"); + let Cow::Borrowed(inner) = client.inner_cow() else { + panic!("expected borrowed inner"); + }; + assert_eq!(inner.identity, "2"); + } + + #[test] + fn client_replaced_in_clones() { + let original1 = SharedReplaceableClient::new(StubClient::new("1")); + let clone1 = original1.clone(); + assert_eq!(original1.identity(), "1"); + assert_eq!(clone1.identity(), "1"); + + original1.replace_client(StubClient::new("2")); + assert_eq!(original1.identity(), "2"); + assert_eq!(clone1.identity(), "2"); + + let original2 = SharedReplaceableClient::new(StubClient::new("3")); + let clone2 = original2.clone(); + assert_eq!(original2.identity(), "3"); + assert_eq!(clone2.identity(), "3"); + + clone2.replace_client(StubClient::new("4")); + assert_eq!(original2.identity(), "4"); + assert_eq!(clone2.identity(), "4"); + assert_eq!(original1.identity(), "2"); + assert_eq!(clone1.identity(), "2"); + } + + #[test] + fn client_replaced_from_multiple_threads() { + let mut client = SharedReplaceableClient::new(StubClient::new("original")); + std::thread::scope(|scope| { + for thread_no in 0..100 { + let mut client = client.clone(); + scope.spawn(move || { + for i in 0..1000 { + let old_generation = client.cloned_generation; + client.inner_mut_refreshed(); + let current_generation = client.cloned_generation; + assert!(current_generation >= old_generation); + let replace_identity = format!("{thread_no}-{i}"); + client.replace_client(StubClient::new(&replace_identity)); + client.inner_mut_refreshed(); + assert!(client.cloned_generation > current_generation); + let refreshed_identity = client.identity(); + if refreshed_identity.split('-').next().unwrap() == thread_no.to_string() { + assert_eq!(replace_identity, refreshed_identity); + } + } + }); + } + }); + client.inner_mut_refreshed(); + assert_eq!(client.cloned_generation, 100_000); + assert!(client.identity().ends_with("-999")); + } +} diff --git a/client/src/retry.rs b/client/src/retry.rs index b02ff0528..698ac6e28 100644 --- a/client/src/retry.rs +++ b/client/src/retry.rs @@ -1,5 +1,5 @@ use crate::{ - Client, ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY, + ERROR_RETURNED_DUE_TO_SHORT_CIRCUIT, IsWorkerTaskLongPoll, MESSAGE_TOO_LARGE_KEY, NamespacedClient, NoRetryOnMatching, Result, RetryConfig, raw::IsUserLongPoll, }; use backoff::{Clock, SystemClock, backoff::Backoff, exponential::ExponentialBackoff}; @@ -98,13 +98,16 @@ impl RetryClient { } } -impl NamespacedClient for RetryClient { - fn namespace(&self) -> &str { - &self.client.namespace +impl NamespacedClient for RetryClient +where + SG: NamespacedClient, +{ + fn namespace(&self) -> String { + self.client.namespace() } - fn get_identity(&self) -> &str { - &self.client.options().identity + fn identity(&self) -> String { + self.client.identity() } } diff --git a/client/src/workflow_handle/mod.rs b/client/src/workflow_handle/mod.rs index af6d7ab16..6c112a888 100644 --- a/client/src/workflow_handle/mod.rs +++ b/client/src/workflow_handle/mod.rs @@ -1,4 +1,4 @@ -use crate::{InterceptedMetricsSvc, RawClientLike, WorkflowService}; +use crate::WorkflowService; use anyhow::{anyhow, bail}; use std::{fmt::Debug, marker::PhantomData}; use temporal_sdk_core_protos::{ @@ -11,6 +11,7 @@ use temporal_sdk_core_protos::{ workflowservice::v1::GetWorkflowExecutionHistoryRequest, }, }; +use tonic::IntoRequest; /// Enumerates terminal states for a particular workflow execution // TODO: Add non-proto failure types, flesh out details, etc. @@ -81,7 +82,7 @@ impl WorkflowExecutionInfo { /// Bind the workflow info to a specific client, turning it into a workflow handle pub fn bind_untyped(self, client: CT) -> UntypedWorkflowHandle where - CT: RawClientLike + Clone, + CT: WorkflowService + Clone, { UntypedWorkflowHandle::new(client, self) } @@ -92,7 +93,7 @@ pub(crate) type UntypedWorkflowHandle = WorkflowHandle>; impl WorkflowHandle where - CT: RawClientLike + Clone, + CT: WorkflowService + Clone, // TODO: Make more generic, capable of (de)serialization w/ serde RT: FromPayloadsExt, { @@ -125,18 +126,21 @@ where let server_res = self .client .clone() - .get_workflow_execution_history(GetWorkflowExecutionHistoryRequest { - namespace: self.info.namespace.to_string(), - execution: Some(WorkflowExecution { - workflow_id: self.info.workflow_id.clone(), - run_id: run_id.clone(), - }), - skip_archival: true, - wait_new_event: true, - history_event_filter_type: HistoryEventFilterType::CloseEvent as i32, - next_page_token: next_page_tok.clone(), - ..Default::default() - }) + .get_workflow_execution_history( + GetWorkflowExecutionHistoryRequest { + namespace: self.info.namespace.to_string(), + execution: Some(WorkflowExecution { + workflow_id: self.info.workflow_id.clone(), + run_id: run_id.clone(), + }), + skip_archival: true, + wait_new_event: true, + history_event_filter_type: HistoryEventFilterType::CloseEvent as i32, + next_page_token: next_page_tok.clone(), + ..Default::default() + } + .into_request(), + ) .await? .into_inner(); diff --git a/core-api/src/envconfig.rs b/core-api/src/envconfig.rs index a52918627..38335c280 100644 --- a/core-api/src/envconfig.rs +++ b/core-api/src/envconfig.rs @@ -358,9 +358,6 @@ pub fn load_client_config_profile( profile.load_from_env(env_vars)?; } - // Apply API key → TLS auto-enabling logic - profile.apply_api_key_tls_logic(); - Ok(profile) } @@ -528,14 +525,6 @@ impl ClientConfigProfile { } Ok(()) } - - /// Apply automatic TLS enabling when API key is present - pub fn apply_api_key_tls_logic(&mut self) { - if self.api_key.is_some() && self.tls.is_none() { - // If API key is present but no TLS config exists, create one with TLS enabled - self.tls = Some(ClientConfigTLS::default()); - } - } } /// Helper to check if any of the given environment variables are set. @@ -1582,27 +1571,6 @@ address = "localhost:7233" assert_eq!(profile.namespace.as_ref().unwrap(), "env-namespace"); } - #[test] - fn test_api_key_tls_auto_enable() { - // Test 1: When API key is present, TLS should be automatically enabled - let toml_str = r#" -[profile.default] -api_key = "my-api-key" -"#; - - let options = LoadClientConfigProfileOptions { - config_source: Some(DataSource::Data(toml_str.as_bytes().to_vec())), - ..Default::default() - }; - - let profile = load_client_config_profile(options, None).unwrap(); - - // TLS should be enabled due to API key presence - assert!(profile.tls.is_some()); - let tls = profile.tls.as_ref().unwrap(); - assert_eq!(tls.disabled, None); // Not explicitly set - } - #[test] fn test_no_api_key_no_tls_is_none() { // Test that if no API key is present and no TLS block exists, TLS config is None @@ -1624,12 +1592,38 @@ address = "some-address" #[test] fn test_load_client_config_profile_from_system_env() { - // Set up system env vars. These tests can't be run in parallel. - unsafe { - std::env::set_var("TEMPORAL_ADDRESS", "system-address"); - std::env::set_var("TEMPORAL_NAMESPACE", "system-namespace"); - } + let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_string()); + let output = std::process::Command::new(cargo) + .arg("test") + .arg("-F") + .arg("envconfig") + .arg("envconfig::tests::test_load_client_config_profile_from_system_env_impl") + .arg("--") + .arg("--exact") + .arg("--ignored") + .env("TEMPORAL_ADDRESS", "system-address") + .env("TEMPORAL_NAMESPACE", "system-namespace") + .output() + .expect("Failed to execute subprocess test"); + + assert!( + output.status.success(), + "Subprocess test failed:\nstdout: {}\nstderr: {}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr), + ); + } + #[test] + #[ignore] // Only run when explicitly called + fn test_load_client_config_profile_from_system_env_impl() { + // Check if we're in the right context + if std::env::var("TEMPORAL_ADDRESS").is_err() + || std::env::var("TEMPORAL_NAMESPACE").is_err() + { + eprintln!("Skipping test - required env vars not set"); + return; // Early return instead of panic + } let options = LoadClientConfigProfileOptions { disable_file: true, // Don't load from any files ..Default::default() @@ -1639,12 +1633,6 @@ address = "some-address" let profile = load_client_config_profile(options, None).unwrap(); assert_eq!(profile.address.as_ref().unwrap(), "system-address"); assert_eq!(profile.namespace.as_ref().unwrap(), "system-namespace"); - - // Clean up - unsafe { - std::env::remove_var("TEMPORAL_ADDRESS"); - std::env::remove_var("TEMPORAL_NAMESPACE"); - } } #[test] diff --git a/core-c-bridge/Cargo.toml b/core-c-bridge/Cargo.toml index e6e7578fe..e2e361aee 100644 --- a/core-c-bridge/Cargo.toml +++ b/core-c-bridge/Cargo.toml @@ -10,6 +10,7 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1.0" async-trait = "0.1" +crossbeam-utils = "0.8" futures-util = { version = "0.3", default-features = false } http = "1.3" libc = "0.2" @@ -19,6 +20,7 @@ prost = { workspace = true } # cause non-determinism. rand = "0.9.2" rand_pcg = "0.9.0" +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = "1.47" tokio-stream = "0.1" @@ -26,6 +28,8 @@ tokio-util = "0.7" tonic = { workspace = true } tracing = "0.1" url = "2.5" +# This is only needed as an explicit dependency so we can enable static as a feature +xz2 = { version = "0.1" } [dependencies.temporal-client] path = "../client" @@ -36,6 +40,7 @@ features = ["ephemeral-server"] [dependencies.temporal-sdk-core-api] path = "../core-api" +features = ["envconfig"] [dependencies.temporal-sdk-core-protos] path = "../sdk-core-protos" @@ -44,6 +49,8 @@ path = "../sdk-core-protos" futures-util = "0.3" thiserror = { workspace = true } - [build-dependencies] cbindgen = { version = "0.29", default-features = false } + +[features] +xz2-static = ["xz2/static"] diff --git a/core-c-bridge/include/temporal-sdk-core-c-bridge.h b/core-c-bridge/include/temporal-sdk-core-c-bridge.h index b6ee3e275..3390fd2e3 100644 --- a/core-c-bridge/include/temporal-sdk-core-c-bridge.h +++ b/core-c-bridge/include/temporal-sdk-core-c-bridge.h @@ -81,6 +81,8 @@ typedef struct TemporalCoreRandom TemporalCoreRandom; typedef struct TemporalCoreRuntime TemporalCoreRuntime; +typedef struct TemporalCoreSlotReserveCompletionCtx TemporalCoreSlotReserveCompletionCtx; + typedef struct TemporalCoreWorker TemporalCoreWorker; typedef struct TemporalCoreWorkerReplayPusher TemporalCoreWorkerReplayPusher; @@ -240,6 +242,53 @@ typedef void (*TemporalCoreClientRpcCallCallback)(void *user_data, const struct TemporalCoreByteArray *failure_message, const struct TemporalCoreByteArray *failure_details); +/** + * OrFail result for client config loading operations. + * Either success or fail will be null, but never both. + * If success is not null, it contains JSON-serialized client configuration data. + * If fail is not null, it contains UTF-8 encoded error message. + * The returned ByteArrays must be freed by the caller. + */ +typedef struct TemporalCoreClientEnvConfigOrFail { + const struct TemporalCoreByteArray *success; + const struct TemporalCoreByteArray *fail; +} TemporalCoreClientEnvConfigOrFail; + +/** + * Options for loading client configuration. + */ +typedef struct TemporalCoreClientEnvConfigLoadOptions { + struct TemporalCoreByteArrayRef path; + struct TemporalCoreByteArrayRef data; + bool config_file_strict; + struct TemporalCoreByteArrayRef env_vars; +} TemporalCoreClientEnvConfigLoadOptions; + +/** + * OrFail result for client config profile loading operations. + * Either success or fail will be null, but never both. + * If success is not null, it contains JSON-serialized client configuration profile data. + * If fail is not null, it contains UTF-8 encoded error message. + * The returned ByteArrays must be freed by the caller. + */ +typedef struct TemporalCoreClientEnvConfigProfileOrFail { + const struct TemporalCoreByteArray *success; + const struct TemporalCoreByteArray *fail; +} TemporalCoreClientEnvConfigProfileOrFail; + +/** + * Options for loading a specific client configuration profile. + */ +typedef struct TemporalCoreClientEnvConfigProfileLoadOptions { + struct TemporalCoreByteArrayRef profile; + struct TemporalCoreByteArrayRef path; + struct TemporalCoreByteArrayRef data; + bool disable_file; + bool disable_env; + bool config_file_strict; + struct TemporalCoreByteArrayRef env_vars; +} TemporalCoreClientEnvConfigProfileLoadOptions; + typedef union TemporalCoreMetricAttributeValue { struct TemporalCoreByteArrayRef string_value; int64_t int_value; @@ -524,18 +573,17 @@ typedef struct TemporalCoreSlotReserveCtx { struct TemporalCoreByteArrayRef worker_identity; struct TemporalCoreByteArrayRef worker_build_id; bool is_sticky; - void *token_src; } TemporalCoreSlotReserveCtx; -typedef void (*TemporalCoreCustomReserveSlotCallback)(const struct TemporalCoreSlotReserveCtx *ctx, - void *sender); +typedef void (*TemporalCoreCustomSlotSupplierReserveCallback)(const struct TemporalCoreSlotReserveCtx *ctx, + const struct TemporalCoreSlotReserveCompletionCtx *completion_ctx, + void *user_data); -typedef void (*TemporalCoreCustomCancelReserveCallback)(void *token_source); +typedef void (*TemporalCoreCustomSlotSupplierCancelReserveCallback)(const struct TemporalCoreSlotReserveCompletionCtx *completion_ctx, + void *user_data); -/** - * Must return C#-tracked id for the permit. A zero value means no permit was reserved. - */ -typedef uintptr_t (*TemporalCoreCustomTryReserveSlotCallback)(const struct TemporalCoreSlotReserveCtx *ctx); +typedef uintptr_t (*TemporalCoreCustomSlotSupplierTryReserveCallback)(const struct TemporalCoreSlotReserveCtx *ctx, + void *user_data); typedef enum TemporalCoreSlotInfo_Tag { WorkflowSlotInfo, @@ -575,32 +623,87 @@ typedef struct TemporalCoreSlotInfo { typedef struct TemporalCoreSlotMarkUsedCtx { struct TemporalCoreSlotInfo slot_info; /** - * C# id for the slot permit. + * Lang-issued permit ID. */ uintptr_t slot_permit; } TemporalCoreSlotMarkUsedCtx; -typedef void (*TemporalCoreCustomMarkSlotUsedCallback)(const struct TemporalCoreSlotMarkUsedCtx *ctx); +typedef void (*TemporalCoreCustomSlotSupplierMarkUsedCallback)(const struct TemporalCoreSlotMarkUsedCtx *ctx, + void *user_data); typedef struct TemporalCoreSlotReleaseCtx { const struct TemporalCoreSlotInfo *slot_info; /** - * C# id for the slot permit. + * Lang-issued permit ID. */ uintptr_t slot_permit; } TemporalCoreSlotReleaseCtx; -typedef void (*TemporalCoreCustomReleaseSlotCallback)(const struct TemporalCoreSlotReleaseCtx *ctx); +typedef void (*TemporalCoreCustomSlotSupplierReleaseCallback)(const struct TemporalCoreSlotReleaseCtx *ctx, + void *user_data); + +typedef bool (*TemporalCoreCustomSlotSupplierAvailableSlotsCallback)(uintptr_t *available_slots, + void *user_data); -typedef void (*TemporalCoreCustomSlotImplFreeCallback)(const struct TemporalCoreCustomSlotSupplierCallbacks *userimpl); +typedef void (*TemporalCoreCustomSlotSupplierFreeCallback)(const struct TemporalCoreCustomSlotSupplierCallbacks *userimpl); typedef struct TemporalCoreCustomSlotSupplierCallbacks { - TemporalCoreCustomReserveSlotCallback reserve; - TemporalCoreCustomCancelReserveCallback cancel_reserve; - TemporalCoreCustomTryReserveSlotCallback try_reserve; - TemporalCoreCustomMarkSlotUsedCallback mark_used; - TemporalCoreCustomReleaseSlotCallback release; - TemporalCoreCustomSlotImplFreeCallback free; + /** + * Called to initiate asynchronous slot reservation. `ctx` contains information about + * reservation request. The pointer is only valid for the duration of the function call; the + * implementation should copy the data out of it for later use, and return as soon as possible. + * + * When slot is reserved, the implementation should call [`temporal_core_complete_async_reserve`] + * with the same `completion_ctx` as passed to this function. Reservation cannot be cancelled + * by Lang, but it can be cancelled by Core through [`cancel_reserve`](Self::cancel_reserve) + * callback. If reservation was cancelled, [`temporal_core_complete_async_cancel_reserve`] + * should be called instead. + * + * Slot reservation cannot error. The implementation should recover from errors and keep trying + * to reserve a slot until it eventually succeeds, or until reservation is cancelled by Core. + */ + TemporalCoreCustomSlotSupplierReserveCallback reserve; + /** + * Called to cancel slot reservation. `completion_ctx` specifies which reservation is being + * cancelled; the matching [`reserve`](Self::reserve) call was made with the same `completion_ctx`. + * After cancellation, the implementation should call [`temporal_core_complete_async_cancel_reserve`] + * with the same `completion_ctx`. Calling [`temporal_core_complete_async_reserve`] is not + * needed after cancellation. + */ + TemporalCoreCustomSlotSupplierCancelReserveCallback cancel_reserve; + /** + * Called to try an immediate slot reservation. The callback should return 0 if immediate + * reservation is not currently possible, or permit ID if reservation was successful. Permit ID + * is arbitrary, but must be unique among live reservations as it's later used for [`mark_used`](Self::mark_used) + * and [`release`](Self::release) callbacks. + */ + TemporalCoreCustomSlotSupplierTryReserveCallback try_reserve; + /** + * Called after successful reservation to mark slot as used. See [`SlotSupplier`](temporal_sdk_core_api::worker::SlotSupplier) + * trait for details. + */ + TemporalCoreCustomSlotSupplierMarkUsedCallback mark_used; + /** + * Called to free a previously reserved slot. + */ + TemporalCoreCustomSlotSupplierReleaseCallback release; + /** + * Called to retrieve the number of available slots if known. If the implementation knows how + * many slots are available at the moment, it should set the value behind the `available_slots` + * pointer and return true. If that number is unknown, it should return false. + * + * This function pointer can be set to null. It will be treated as if the number of available + * slots is never known. + */ + TemporalCoreCustomSlotSupplierAvailableSlotsCallback available_slots; + /** + * Called when the slot supplier is being dropped. All resources should be freed. + */ + TemporalCoreCustomSlotSupplierFreeCallback free; + /** + * Passed as an extra argument to the callbacks. + */ + void *user_data; } TemporalCoreCustomSlotSupplierCallbacks; typedef struct TemporalCoreCustomSlotSupplierCallbacksImpl { @@ -772,6 +875,20 @@ void temporal_core_client_rpc_call(struct TemporalCoreClient *client, void *user_data, TemporalCoreClientRpcCallCallback callback); +/** + * Load all client profiles from given sources. + * Returns ClientConfigOrFail with either success JSON or error message. + * The returned ByteArrays must be freed by the caller. + */ +struct TemporalCoreClientEnvConfigOrFail temporal_core_client_env_config_load(const struct TemporalCoreClientEnvConfigLoadOptions *options); + +/** + * Load a single client profile from given sources with env overrides. + * Returns ClientConfigProfileOrFail with either success JSON or error message. + * The returned ByteArrays must be freed by the caller. + */ +struct TemporalCoreClientEnvConfigProfileOrFail temporal_core_client_env_config_profile_load(const struct TemporalCoreClientEnvConfigProfileLoadOptions *options); + struct TemporalCoreMetricMeter *temporal_core_metric_meter_new(struct TemporalCoreRuntime *runtime); void temporal_core_metric_meter_free(struct TemporalCoreMetricMeter *meter); @@ -924,10 +1041,44 @@ struct TemporalCoreWorkerReplayPushResult temporal_core_worker_replay_push(struc struct TemporalCoreByteArrayRef workflow_id, struct TemporalCoreByteArrayRef history); -void temporal_core_complete_async_reserve(void *sender, uintptr_t permit_id); +/** + * Completes asynchronous slot reservation started by a call to [`CustomSlotSupplierCallbacks::reserve`]. + * + * `completion_ctx` must be the same as the one passed to the matching [`reserve`](CustomSlotSupplierCallbacks::reserve) + * call. `permit_id` is arbitrary, but must be unique among live reservations as it's later used + * for [`mark_used`](CustomSlotSupplierCallbacks::mark_used) and [`release`](CustomSlotSupplierCallbacks::release) + * callbacks. + * + * This function returns true if the reservation was completed successfully, or false if the + * reservation was cancelled before completion. If this function returns false, the implementation + * should call [`temporal_core_complete_async_cancel_reserve`] with the same `completion_ctx`. + * + * **Caution:** if this function returns true, `completion_ctx` gets freed. Afterwards, calling + * either [`temporal_core_complete_async_reserve`] or [`temporal_core_complete_async_cancel_reserve`] + * with the same `completion_ctx` will cause **memory corruption!** + */ +bool temporal_core_complete_async_reserve(const struct TemporalCoreSlotReserveCompletionCtx *completion_ctx, + uintptr_t permit_id); -void temporal_core_set_reserve_cancel_target(struct TemporalCoreSlotReserveCtx *ctx, - void *token_ptr); +/** + * Completes cancellation of asynchronous slot reservation. + * + * Cancellation can only be initiated by Core. It's done by calling [`CustomSlotSupplierCallbacks::cancel_reserve`] + * after an earlier call to [`CustomSlotSupplierCallbacks::reserve`]. + * + * `completion_ctx` must be the same as the one passed to the matching [`cancel_reserve`](CustomSlotSupplierCallbacks::cancel_reserve) + * call. + * + * This function returns true on successful cancellation, or false if cancellation was not + * requested for the given `completion_ctx`. A false value indicates there's likely a logic bug in + * the implementation where it doesn't correctly wait for [`cancel_reserve`](CustomSlotSupplierCallbacks::cancel_reserve) + * callback to be called. + * + * **Caution:** if this function returns true, `completion_ctx` gets freed. Afterwards, calling + * either [`temporal_core_complete_async_reserve`] or [`temporal_core_complete_async_cancel_reserve`] + * with the same `completion_ctx` will cause **memory corruption!** + */ +bool temporal_core_complete_async_cancel_reserve(const struct TemporalCoreSlotReserveCompletionCtx *completion_ctx); #ifdef __cplusplus } // extern "C" diff --git a/core-c-bridge/src/client.rs b/core-c-bridge/src/client.rs index 28ba7799d..fabdde13f 100644 --- a/core-c-bridge/src/client.rs +++ b/core-c-bridge/src/client.rs @@ -16,8 +16,8 @@ use std::{ use temporal_client::{ ClientKeepAliveConfig, ClientOptions as CoreClientOptions, ClientOptionsBuilder, ClientTlsConfig, CloudService, ConfiguredClient, HealthService, HttpConnectProxyOptions, - OperatorService, RetryClient, RetryConfig, TemporalServiceClientWithMetrics, TestService, - TlsConfig, WorkflowService, callback_based, + OperatorService, RetryClient, RetryConfig, TemporalServiceClient, TestService, TlsConfig, + WorkflowService, callback_based, }; use tokio::sync::oneshot; use tonic::metadata::MetadataKey; @@ -79,7 +79,7 @@ pub struct ClientHttpConnectProxyOptions { pub password: ByteArrayRef, } -type CoreClient = RetryClient>; +type CoreClient = RetryClient>; pub struct Client { pub(crate) runtime: Runtime, @@ -528,16 +528,6 @@ pub extern "C" fn temporal_core_client_rpc_call( }); } -macro_rules! rpc_call { - ($client:ident, $call:ident, $call_name:ident) => { - if $call.retry { - rpc_resp($client.$call_name(rpc_req($call)?).await) - } else { - rpc_resp($client.into_inner().$call_name(rpc_req($call)?).await) - } - }; -} - macro_rules! rpc_call_on_trait { ($client:ident, $call:ident, $trait:tt, $call_name:ident) => { if $call.retry { @@ -555,119 +545,311 @@ async fn call_workflow_service( let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "CountWorkflowExecutions" => rpc_call!(client, call, count_workflow_executions), - "CreateSchedule" => rpc_call!(client, call, create_schedule), - "CreateWorkflowRule" => rpc_call!(client, call, create_workflow_rule), - "DeleteSchedule" => rpc_call!(client, call, delete_schedule), - "DeleteWorkerDeployment" => rpc_call!(client, call, delete_worker_deployment), + "CountWorkflowExecutions" => { + rpc_call_on_trait!(client, call, WorkflowService, count_workflow_executions) + } + "CreateSchedule" => rpc_call_on_trait!(client, call, WorkflowService, create_schedule), + "CreateWorkflowRule" => { + rpc_call_on_trait!(client, call, WorkflowService, create_workflow_rule) + } + "DeleteSchedule" => rpc_call_on_trait!(client, call, WorkflowService, delete_schedule), + "DeleteWorkerDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, delete_worker_deployment) + } "DeleteWorkerDeploymentVersion" => { - rpc_call!(client, call, delete_worker_deployment_version) - } - "DeleteWorkflowExecution" => rpc_call!(client, call, delete_workflow_execution), - "DeleteWorkflowRule" => rpc_call!(client, call, delete_workflow_rule), - "DeprecateNamespace" => rpc_call!(client, call, deprecate_namespace), - "DescribeBatchOperation" => rpc_call!(client, call, describe_batch_operation), - "DescribeDeployment" => rpc_call!(client, call, describe_deployment), - "DescribeNamespace" => rpc_call!(client, call, describe_namespace), - "DescribeSchedule" => rpc_call!(client, call, describe_schedule), - "DescribeTaskQueue" => rpc_call!(client, call, describe_task_queue), - "DescribeWorker" => rpc_call!(client, call, describe_worker), - "DescribeWorkerDeployment" => rpc_call!(client, call, describe_worker_deployment), + rpc_call_on_trait!( + client, + call, + WorkflowService, + delete_worker_deployment_version + ) + } + "DeleteWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, delete_workflow_execution) + } + "DeleteWorkflowRule" => { + rpc_call_on_trait!(client, call, WorkflowService, delete_workflow_rule) + } + "DeprecateNamespace" => { + rpc_call_on_trait!(client, call, WorkflowService, deprecate_namespace) + } + "DescribeBatchOperation" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_batch_operation) + } + "DescribeDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_deployment) + } + "DescribeNamespace" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_namespace) + } + "DescribeSchedule" => rpc_call_on_trait!(client, call, WorkflowService, describe_schedule), + "DescribeTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_task_queue) + } + "DescribeWorker" => rpc_call_on_trait!(client, call, WorkflowService, describe_worker), + "DescribeWorkerDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_worker_deployment) + } "DescribeWorkerDeploymentVersion" => { - rpc_call!(client, call, describe_worker_deployment_version) - } - "DescribeWorkflowExecution" => rpc_call!(client, call, describe_workflow_execution), - "DescribeWorkflowRule" => rpc_call!(client, call, describe_workflow_rule), - "ExecuteMultiOperation" => rpc_call!(client, call, execute_multi_operation), - "FetchWorkerConfig" => rpc_call!(client, call, fetch_worker_config), - "GetClusterInfo" => rpc_call!(client, call, get_cluster_info), - "GetCurrentDeployment" => rpc_call!(client, call, get_current_deployment), - "GetDeploymentReachability" => rpc_call!(client, call, get_deployment_reachability), - "GetSearchAttributes" => rpc_call!(client, call, get_search_attributes), - "GetSystemInfo" => rpc_call!(client, call, get_system_info), + rpc_call_on_trait!( + client, + call, + WorkflowService, + describe_worker_deployment_version + ) + } + "DescribeWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_workflow_execution) + } + "DescribeWorkflowRule" => { + rpc_call_on_trait!(client, call, WorkflowService, describe_workflow_rule) + } + "ExecuteMultiOperation" => { + rpc_call_on_trait!(client, call, WorkflowService, execute_multi_operation) + } + "FetchWorkerConfig" => { + rpc_call_on_trait!(client, call, WorkflowService, fetch_worker_config) + } + "GetClusterInfo" => rpc_call_on_trait!(client, call, WorkflowService, get_cluster_info), + "GetCurrentDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, get_current_deployment) + } + "GetDeploymentReachability" => { + rpc_call_on_trait!(client, call, WorkflowService, get_deployment_reachability) + } + "GetSearchAttributes" => { + rpc_call_on_trait!(client, call, WorkflowService, get_search_attributes) + } + "GetSystemInfo" => rpc_call_on_trait!(client, call, WorkflowService, get_system_info), "GetWorkerBuildIdCompatibility" => { - rpc_call!(client, call, get_worker_build_id_compatibility) + rpc_call_on_trait!( + client, + call, + WorkflowService, + get_worker_build_id_compatibility + ) } "GetWorkerTaskReachability" => { - rpc_call!(client, call, get_worker_task_reachability) + rpc_call_on_trait!(client, call, WorkflowService, get_worker_task_reachability) } - "GetWorkerVersioningRules" => rpc_call!(client, call, get_worker_versioning_rules), - "GetWorkflowExecutionHistory" => rpc_call!(client, call, get_workflow_execution_history), + "GetWorkerVersioningRules" => { + rpc_call_on_trait!(client, call, WorkflowService, get_worker_versioning_rules) + } + "GetWorkflowExecutionHistory" => rpc_call_on_trait!( + client, + call, + WorkflowService, + get_workflow_execution_history + ), "GetWorkflowExecutionHistoryReverse" => { - rpc_call!(client, call, get_workflow_execution_history_reverse) + rpc_call_on_trait!( + client, + call, + WorkflowService, + get_workflow_execution_history_reverse + ) } "ListArchivedWorkflowExecutions" => { - rpc_call!(client, call, list_archived_workflow_executions) - } - "ListBatchOperations" => rpc_call!(client, call, list_batch_operations), - "ListClosedWorkflowExecutions" => rpc_call!(client, call, list_closed_workflow_executions), - "ListDeployments" => rpc_call!(client, call, list_deployments), - "ListNamespaces" => rpc_call!(client, call, list_namespaces), - "ListOpenWorkflowExecutions" => rpc_call!(client, call, list_open_workflow_executions), - "ListScheduleMatchingTimes" => rpc_call!(client, call, list_schedule_matching_times), - "ListSchedules" => rpc_call!(client, call, list_schedules), - "ListTaskQueuePartitions" => rpc_call!(client, call, list_task_queue_partitions), - "ListWorkerDeployments" => rpc_call!(client, call, list_worker_deployments), - "ListWorkers" => rpc_call!(client, call, list_workers), - "ListWorkflowExecutions" => rpc_call!(client, call, list_workflow_executions), - "ListWorkflowRules" => rpc_call!(client, call, list_workflow_rules), - "PatchSchedule" => rpc_call!(client, call, patch_schedule), - "PauseActivity" => rpc_call!(client, call, pause_activity), - "PollActivityTaskQueue" => rpc_call!(client, call, poll_activity_task_queue), - "PollNexusTaskQueue" => rpc_call!(client, call, poll_nexus_task_queue), - "PollWorkflowExecutionUpdate" => rpc_call!(client, call, poll_workflow_execution_update), - "PollWorkflowTaskQueue" => rpc_call!(client, call, poll_workflow_task_queue), - "QueryWorkflow" => rpc_call!(client, call, query_workflow), - "RecordActivityTaskHeartbeat" => rpc_call!(client, call, record_activity_task_heartbeat), + rpc_call_on_trait!( + client, + call, + WorkflowService, + list_archived_workflow_executions + ) + } + "ListBatchOperations" => { + rpc_call_on_trait!(client, call, WorkflowService, list_batch_operations) + } + "ListClosedWorkflowExecutions" => rpc_call_on_trait!( + client, + call, + WorkflowService, + list_closed_workflow_executions + ), + "ListDeployments" => rpc_call_on_trait!(client, call, WorkflowService, list_deployments), + "ListNamespaces" => rpc_call_on_trait!(client, call, WorkflowService, list_namespaces), + "ListOpenWorkflowExecutions" => { + rpc_call_on_trait!(client, call, WorkflowService, list_open_workflow_executions) + } + "ListScheduleMatchingTimes" => { + rpc_call_on_trait!(client, call, WorkflowService, list_schedule_matching_times) + } + "ListSchedules" => rpc_call_on_trait!(client, call, WorkflowService, list_schedules), + "ListTaskQueuePartitions" => { + rpc_call_on_trait!(client, call, WorkflowService, list_task_queue_partitions) + } + "ListWorkerDeployments" => { + rpc_call_on_trait!(client, call, WorkflowService, list_worker_deployments) + } + "ListWorkers" => rpc_call_on_trait!(client, call, WorkflowService, list_workers), + "ListWorkflowExecutions" => { + rpc_call_on_trait!(client, call, WorkflowService, list_workflow_executions) + } + "ListWorkflowRules" => { + rpc_call_on_trait!(client, call, WorkflowService, list_workflow_rules) + } + "PatchSchedule" => rpc_call_on_trait!(client, call, WorkflowService, patch_schedule), + "PauseActivity" => rpc_call_on_trait!(client, call, WorkflowService, pause_activity), + "PollActivityTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, poll_activity_task_queue) + } + "PollNexusTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, poll_nexus_task_queue) + } + "PollWorkflowExecutionUpdate" => rpc_call_on_trait!( + client, + call, + WorkflowService, + poll_workflow_execution_update + ), + "PollWorkflowTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, poll_workflow_task_queue) + } + "QueryWorkflow" => rpc_call_on_trait!(client, call, WorkflowService, query_workflow), + "RecordActivityTaskHeartbeat" => rpc_call_on_trait!( + client, + call, + WorkflowService, + record_activity_task_heartbeat + ), "RecordActivityTaskHeartbeatById" => { - rpc_call!(client, call, record_activity_task_heartbeat_by_id) + rpc_call_on_trait!( + client, + call, + WorkflowService, + record_activity_task_heartbeat_by_id + ) + } + "RecordWorkerHeartbeat" => { + rpc_call_on_trait!(client, call, WorkflowService, record_worker_heartbeat) + } + "RegisterNamespace" => { + rpc_call_on_trait!(client, call, WorkflowService, register_namespace) } - "RecordWorkerHeartbeat" => rpc_call!(client, call, record_worker_heartbeat), - "RegisterNamespace" => rpc_call!(client, call, register_namespace), "RequestCancelWorkflowExecution" => { - rpc_call!(client, call, request_cancel_workflow_execution) + rpc_call_on_trait!( + client, + call, + WorkflowService, + request_cancel_workflow_execution + ) + } + "ResetActivity" => rpc_call_on_trait!(client, call, WorkflowService, reset_activity), + "ResetStickyTaskQueue" => { + rpc_call_on_trait!(client, call, WorkflowService, reset_sticky_task_queue) } - "ResetActivity" => rpc_call!(client, call, reset_activity), - "ResetStickyTaskQueue" => rpc_call!(client, call, reset_sticky_task_queue), - "ResetWorkflowExecution" => rpc_call!(client, call, reset_workflow_execution), - "RespondActivityTaskCanceled" => rpc_call!(client, call, respond_activity_task_canceled), + "ResetWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, reset_workflow_execution) + } + "RespondActivityTaskCanceled" => rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_canceled + ), "RespondActivityTaskCanceledById" => { - rpc_call!(client, call, respond_activity_task_canceled_by_id) + rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_canceled_by_id + ) } - "RespondActivityTaskCompleted" => rpc_call!(client, call, respond_activity_task_completed), + "RespondActivityTaskCompleted" => rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_completed + ), "RespondActivityTaskCompletedById" => { - rpc_call!(client, call, respond_activity_task_completed_by_id) + rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_completed_by_id + ) + } + "RespondActivityTaskFailed" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_activity_task_failed) } - "RespondActivityTaskFailed" => rpc_call!(client, call, respond_activity_task_failed), "RespondActivityTaskFailedById" => { - rpc_call!(client, call, respond_activity_task_failed_by_id) - } - "RespondNexusTaskCompleted" => rpc_call!(client, call, respond_nexus_task_completed), - "RespondNexusTaskFailed" => rpc_call!(client, call, respond_nexus_task_failed), - "RespondQueryTaskCompleted" => rpc_call!(client, call, respond_query_task_completed), - "RespondWorkflowTaskCompleted" => rpc_call!(client, call, respond_workflow_task_completed), - "RespondWorkflowTaskFailed" => rpc_call!(client, call, respond_workflow_task_failed), - "ScanWorkflowExecutions" => rpc_call!(client, call, scan_workflow_executions), - "SetCurrentDeployment" => rpc_call!(client, call, set_current_deployment), + rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_activity_task_failed_by_id + ) + } + "RespondNexusTaskCompleted" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_nexus_task_completed) + } + "RespondNexusTaskFailed" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_nexus_task_failed) + } + "RespondQueryTaskCompleted" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_query_task_completed) + } + "RespondWorkflowTaskCompleted" => rpc_call_on_trait!( + client, + call, + WorkflowService, + respond_workflow_task_completed + ), + "RespondWorkflowTaskFailed" => { + rpc_call_on_trait!(client, call, WorkflowService, respond_workflow_task_failed) + } + "ScanWorkflowExecutions" => { + rpc_call_on_trait!(client, call, WorkflowService, scan_workflow_executions) + } + "SetCurrentDeployment" => { + rpc_call_on_trait!(client, call, WorkflowService, set_current_deployment) + } "SetWorkerDeploymentCurrentVersion" => { - rpc_call!(client, call, set_worker_deployment_current_version) + rpc_call_on_trait!( + client, + call, + WorkflowService, + set_worker_deployment_current_version + ) } "SetWorkerDeploymentManager" => { - rpc_call!(client, call, set_worker_deployment_manager) + rpc_call_on_trait!(client, call, WorkflowService, set_worker_deployment_manager) } "SetWorkerDeploymentRampingVersion" => { - rpc_call!(client, call, set_worker_deployment_ramping_version) + rpc_call_on_trait!( + client, + call, + WorkflowService, + set_worker_deployment_ramping_version + ) } - "ShutdownWorker" => rpc_call!(client, call, shutdown_worker), + "ShutdownWorker" => rpc_call_on_trait!(client, call, WorkflowService, shutdown_worker), "SignalWithStartWorkflowExecution" => { - rpc_call!(client, call, signal_with_start_workflow_execution) - } - "SignalWorkflowExecution" => rpc_call!(client, call, signal_workflow_execution), - "StartWorkflowExecution" => rpc_call!(client, call, start_workflow_execution), - "StartBatchOperation" => rpc_call!(client, call, start_batch_operation), - "StopBatchOperation" => rpc_call!(client, call, stop_batch_operation), - "TerminateWorkflowExecution" => rpc_call!(client, call, terminate_workflow_execution), - "TriggerWorkflowRule" => rpc_call!(client, call, trigger_workflow_rule), + rpc_call_on_trait!( + client, + call, + WorkflowService, + signal_with_start_workflow_execution + ) + } + "SignalWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, signal_workflow_execution) + } + "StartWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, start_workflow_execution) + } + "StartBatchOperation" => { + rpc_call_on_trait!(client, call, WorkflowService, start_batch_operation) + } + "StopBatchOperation" => { + rpc_call_on_trait!(client, call, WorkflowService, stop_batch_operation) + } + "TerminateWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, terminate_workflow_execution) + } + "TriggerWorkflowRule" => { + rpc_call_on_trait!(client, call, WorkflowService, trigger_workflow_rule) + } "UnpauseActivity" => { rpc_call_on_trait!(client, call, WorkflowService, unpause_activity) } @@ -675,19 +857,45 @@ async fn call_workflow_service( rpc_call_on_trait!(client, call, WorkflowService, update_activity_options) } "UpdateNamespace" => rpc_call_on_trait!(client, call, WorkflowService, update_namespace), - "UpdateSchedule" => rpc_call!(client, call, update_schedule), - "UpdateTaskQueueConfig" => rpc_call!(client, call, update_task_queue_config), - "UpdateWorkerConfig" => rpc_call!(client, call, update_worker_config), + "UpdateSchedule" => rpc_call_on_trait!(client, call, WorkflowService, update_schedule), + "UpdateTaskQueueConfig" => { + rpc_call_on_trait!(client, call, WorkflowService, update_task_queue_config) + } + "UpdateWorkerConfig" => { + rpc_call_on_trait!(client, call, WorkflowService, update_worker_config) + } "UpdateWorkerDeploymentVersionMetadata" => { - rpc_call!(client, call, update_worker_deployment_version_metadata) + rpc_call_on_trait!( + client, + call, + WorkflowService, + update_worker_deployment_version_metadata + ) + } + "UpdateWorkerVersioningRules" => rpc_call_on_trait!( + client, + call, + WorkflowService, + update_worker_versioning_rules + ), + "UpdateWorkflowExecution" => { + rpc_call_on_trait!(client, call, WorkflowService, update_workflow_execution) } - "UpdateWorkerVersioningRules" => rpc_call!(client, call, update_worker_versioning_rules), - "UpdateWorkflowExecution" => rpc_call!(client, call, update_workflow_execution), "UpdateWorkflowExecutionOptions" => { - rpc_call!(client, call, update_workflow_execution_options) + rpc_call_on_trait!( + client, + call, + WorkflowService, + update_workflow_execution_options + ) } "UpdateWorkerBuildIdCompatibility" => { - rpc_call!(client, call, update_worker_build_id_compatibility) + rpc_call_on_trait!( + client, + call, + WorkflowService, + update_worker_build_id_compatibility + ) } rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } @@ -700,8 +908,12 @@ async fn call_operator_service( let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "AddOrUpdateRemoteCluster" => rpc_call!(client, call, add_or_update_remote_cluster), - "AddSearchAttributes" => rpc_call!(client, call, add_search_attributes), + "AddOrUpdateRemoteCluster" => { + rpc_call_on_trait!(client, call, OperatorService, add_or_update_remote_cluster) + } + "AddSearchAttributes" => { + rpc_call_on_trait!(client, call, OperatorService, add_search_attributes) + } "CreateNexusEndpoint" => { rpc_call_on_trait!(client, call, OperatorService, create_nexus_endpoint) } @@ -709,13 +921,20 @@ async fn call_operator_service( "DeleteNexusEndpoint" => { rpc_call_on_trait!(client, call, OperatorService, delete_nexus_endpoint) } - "DeleteWorkflowExecution" => rpc_call!(client, call, delete_workflow_execution), "GetNexusEndpoint" => rpc_call_on_trait!(client, call, OperatorService, get_nexus_endpoint), - "ListClusters" => rpc_call!(client, call, list_clusters), - "ListNexusEndpoints" => rpc_call!(client, call, list_nexus_endpoints), - "ListSearchAttributes" => rpc_call!(client, call, list_search_attributes), - "RemoveRemoteCluster" => rpc_call!(client, call, remove_remote_cluster), - "RemoveSearchAttributes" => rpc_call!(client, call, remove_search_attributes), + "ListClusters" => rpc_call_on_trait!(client, call, OperatorService, list_clusters), + "ListNexusEndpoints" => { + rpc_call_on_trait!(client, call, OperatorService, list_nexus_endpoints) + } + "ListSearchAttributes" => { + rpc_call_on_trait!(client, call, OperatorService, list_search_attributes) + } + "RemoveRemoteCluster" => { + rpc_call_on_trait!(client, call, OperatorService, remove_remote_cluster) + } + "RemoveSearchAttributes" => { + rpc_call_on_trait!(client, call, OperatorService, remove_search_attributes) + } "UpdateNexusEndpoint" => { rpc_call_on_trait!(client, call, OperatorService, update_nexus_endpoint) } @@ -727,68 +946,116 @@ async fn call_cloud_service(client: &CoreClient, call: &RpcCallOptions) -> anyho let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "AddNamespaceRegion" => rpc_call!(client, call, add_namespace_region), - "AddUserGroupMember" => rpc_call!(client, call, add_user_group_member), - "CreateApiKey" => rpc_call!(client, call, create_api_key), - "CreateNamespace" => rpc_call!(client, call, create_namespace), - "CreateNamespaceExportSink" => rpc_call!(client, call, create_namespace_export_sink), + "AddNamespaceRegion" => { + rpc_call_on_trait!(client, call, CloudService, add_namespace_region) + } + "AddUserGroupMember" => { + rpc_call_on_trait!(client, call, CloudService, add_user_group_member) + } + "CreateApiKey" => rpc_call_on_trait!(client, call, CloudService, create_api_key), + "CreateNamespace" => rpc_call_on_trait!(client, call, CloudService, create_namespace), + "CreateNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, create_namespace_export_sink) + } "CreateNexusEndpoint" => { rpc_call_on_trait!(client, call, CloudService, create_nexus_endpoint) } - "CreateServiceAccount" => rpc_call!(client, call, create_service_account), - "CreateUserGroup" => rpc_call!(client, call, create_user_group), - "CreateUser" => rpc_call!(client, call, create_user), - "DeleteApiKey" => rpc_call!(client, call, delete_api_key), + "CreateServiceAccount" => { + rpc_call_on_trait!(client, call, CloudService, create_service_account) + } + "CreateUserGroup" => rpc_call_on_trait!(client, call, CloudService, create_user_group), + "CreateUser" => rpc_call_on_trait!(client, call, CloudService, create_user), + "DeleteApiKey" => rpc_call_on_trait!(client, call, CloudService, delete_api_key), "DeleteNamespace" => rpc_call_on_trait!(client, call, CloudService, delete_namespace), - "DeleteNamespaceExportSink" => rpc_call!(client, call, delete_namespace_export_sink), - "DeleteNamespaceRegion" => rpc_call!(client, call, delete_namespace_region), + "DeleteNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, delete_namespace_export_sink) + } + "DeleteNamespaceRegion" => { + rpc_call_on_trait!(client, call, CloudService, delete_namespace_region) + } "DeleteNexusEndpoint" => { rpc_call_on_trait!(client, call, CloudService, delete_nexus_endpoint) } - "DeleteServiceAccount" => rpc_call!(client, call, delete_service_account), - "DeleteUserGroup" => rpc_call!(client, call, delete_user_group), - "DeleteUser" => rpc_call!(client, call, delete_user), - "FailoverNamespaceRegion" => rpc_call!(client, call, failover_namespace_region), - "GetAccount" => rpc_call!(client, call, get_account), - "GetApiKey" => rpc_call!(client, call, get_api_key), - "GetApiKeys" => rpc_call!(client, call, get_api_keys), - "GetAsyncOperation" => rpc_call!(client, call, get_async_operation), - "GetNamespace" => rpc_call!(client, call, get_namespace), - "GetNamespaceExportSink" => rpc_call!(client, call, get_namespace_export_sink), - "GetNamespaceExportSinks" => rpc_call!(client, call, get_namespace_export_sinks), - "GetNamespaces" => rpc_call!(client, call, get_namespaces), + "DeleteServiceAccount" => { + rpc_call_on_trait!(client, call, CloudService, delete_service_account) + } + "DeleteUserGroup" => rpc_call_on_trait!(client, call, CloudService, delete_user_group), + "DeleteUser" => rpc_call_on_trait!(client, call, CloudService, delete_user), + "FailoverNamespaceRegion" => { + rpc_call_on_trait!(client, call, CloudService, failover_namespace_region) + } + "GetAccount" => rpc_call_on_trait!(client, call, CloudService, get_account), + "GetApiKey" => rpc_call_on_trait!(client, call, CloudService, get_api_key), + "GetApiKeys" => rpc_call_on_trait!(client, call, CloudService, get_api_keys), + "GetAsyncOperation" => rpc_call_on_trait!(client, call, CloudService, get_async_operation), + "GetNamespace" => rpc_call_on_trait!(client, call, CloudService, get_namespace), + "GetNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, get_namespace_export_sink) + } + "GetNamespaceExportSinks" => { + rpc_call_on_trait!(client, call, CloudService, get_namespace_export_sinks) + } + "GetNamespaces" => rpc_call_on_trait!(client, call, CloudService, get_namespaces), "GetNexusEndpoint" => rpc_call_on_trait!(client, call, CloudService, get_nexus_endpoint), - "GetNexusEndpoints" => rpc_call!(client, call, get_nexus_endpoints), - "GetRegion" => rpc_call!(client, call, get_region), - "GetRegions" => rpc_call!(client, call, get_regions), - "GetServiceAccount" => rpc_call!(client, call, get_service_account), - "GetServiceAccounts" => rpc_call!(client, call, get_service_accounts), - "GetUsage" => rpc_call!(client, call, get_usage), - "GetUserGroup" => rpc_call!(client, call, get_user_group), - "GetUserGroupMembers" => rpc_call!(client, call, get_user_group_members), - "GetUserGroups" => rpc_call!(client, call, get_user_groups), - "GetUser" => rpc_call!(client, call, get_user), - "GetUsers" => rpc_call!(client, call, get_users), - "RemoveUserGroupMember" => rpc_call!(client, call, remove_user_group_member), - "RenameCustomSearchAttribute" => rpc_call!(client, call, rename_custom_search_attribute), - "SetUserGroupNamespaceAccess" => rpc_call!(client, call, set_user_group_namespace_access), - "SetUserNamespaceAccess" => rpc_call!(client, call, set_user_namespace_access), - "UpdateAccount" => rpc_call!(client, call, update_account), - "UpdateApiKey" => rpc_call!(client, call, update_api_key), + "GetNexusEndpoints" => rpc_call_on_trait!(client, call, CloudService, get_nexus_endpoints), + "GetRegion" => rpc_call_on_trait!(client, call, CloudService, get_region), + "GetRegions" => rpc_call_on_trait!(client, call, CloudService, get_regions), + "GetServiceAccount" => rpc_call_on_trait!(client, call, CloudService, get_service_account), + "GetServiceAccounts" => { + rpc_call_on_trait!(client, call, CloudService, get_service_accounts) + } + "GetUsage" => rpc_call_on_trait!(client, call, CloudService, get_usage), + "GetUserGroup" => rpc_call_on_trait!(client, call, CloudService, get_user_group), + "GetUserGroupMembers" => { + rpc_call_on_trait!(client, call, CloudService, get_user_group_members) + } + "GetUserGroups" => rpc_call_on_trait!(client, call, CloudService, get_user_groups), + "GetUser" => rpc_call_on_trait!(client, call, CloudService, get_user), + "GetUsers" => rpc_call_on_trait!(client, call, CloudService, get_users), + "RemoveUserGroupMember" => { + rpc_call_on_trait!(client, call, CloudService, remove_user_group_member) + } + "RenameCustomSearchAttribute" => { + rpc_call_on_trait!(client, call, CloudService, rename_custom_search_attribute) + } + "SetUserGroupNamespaceAccess" => { + rpc_call_on_trait!(client, call, CloudService, set_user_group_namespace_access) + } + "SetUserNamespaceAccess" => { + rpc_call_on_trait!(client, call, CloudService, set_user_namespace_access) + } + "UpdateAccount" => rpc_call_on_trait!(client, call, CloudService, update_account), + "UpdateApiKey" => rpc_call_on_trait!(client, call, CloudService, update_api_key), "UpdateNamespace" => rpc_call_on_trait!(client, call, CloudService, update_namespace), - "UpdateNamespaceExportSink" => rpc_call!(client, call, update_namespace_export_sink), + "UpdateNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, update_namespace_export_sink) + } "UpdateNexusEndpoint" => { rpc_call_on_trait!(client, call, CloudService, update_nexus_endpoint) } - "UpdateServiceAccount" => rpc_call!(client, call, update_service_account), - "UpdateUserGroup" => rpc_call!(client, call, update_user_group), - "UpdateUser" => rpc_call!(client, call, update_user), - "ValidateNamespaceExportSink" => rpc_call!(client, call, validate_namespace_export_sink), - "UpdateNamespaceTags" => rpc_call!(client, call, update_namespace_tags), - "CreateConnectivityRule" => rpc_call!(client, call, create_connectivity_rule), - "GetConnectivityRule" => rpc_call!(client, call, get_connectivity_rule), - "GetConnectivityRules" => rpc_call!(client, call, get_connectivity_rules), - "DeleteConnectivityRule" => rpc_call!(client, call, delete_connectivity_rule), + "UpdateServiceAccount" => { + rpc_call_on_trait!(client, call, CloudService, update_service_account) + } + "UpdateUserGroup" => rpc_call_on_trait!(client, call, CloudService, update_user_group), + "UpdateUser" => rpc_call_on_trait!(client, call, CloudService, update_user), + "ValidateNamespaceExportSink" => { + rpc_call_on_trait!(client, call, CloudService, validate_namespace_export_sink) + } + "UpdateNamespaceTags" => { + rpc_call_on_trait!(client, call, CloudService, update_namespace_tags) + } + "CreateConnectivityRule" => { + rpc_call_on_trait!(client, call, CloudService, create_connectivity_rule) + } + "GetConnectivityRule" => { + rpc_call_on_trait!(client, call, CloudService, get_connectivity_rule) + } + "GetConnectivityRules" => { + rpc_call_on_trait!(client, call, CloudService, get_connectivity_rules) + } + "DeleteConnectivityRule" => { + rpc_call_on_trait!(client, call, CloudService, delete_connectivity_rule) + } rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -797,12 +1064,14 @@ async fn call_test_service(client: &CoreClient, call: &RpcCallOptions) -> anyhow let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "GetCurrentTime" => rpc_call!(client, call, get_current_time), - "LockTimeSkipping" => rpc_call!(client, call, lock_time_skipping), - "SleepUntil" => rpc_call!(client, call, sleep_until), - "Sleep" => rpc_call!(client, call, sleep), - "UnlockTimeSkippingWithSleep" => rpc_call!(client, call, unlock_time_skipping_with_sleep), - "UnlockTimeSkipping" => rpc_call!(client, call, unlock_time_skipping), + "GetCurrentTime" => rpc_call_on_trait!(client, call, TestService, get_current_time), + "LockTimeSkipping" => rpc_call_on_trait!(client, call, TestService, lock_time_skipping), + "SleepUntil" => rpc_call_on_trait!(client, call, TestService, sleep_until), + "Sleep" => rpc_call_on_trait!(client, call, TestService, sleep), + "UnlockTimeSkippingWithSleep" => { + rpc_call_on_trait!(client, call, TestService, unlock_time_skipping_with_sleep) + } + "UnlockTimeSkipping" => rpc_call_on_trait!(client, call, TestService, unlock_time_skipping), rpc => Err(anyhow::anyhow!("Unknown RPC call {rpc}")), } } @@ -814,7 +1083,7 @@ async fn call_health_service( let rpc = call.rpc.to_str(); let mut client = client.clone(); match rpc { - "Check" => rpc_call!(client, call, check), + "Check" => rpc_call_on_trait!(client, call, HealthService, check), "Watch" => Err(anyhow::anyhow!( "Health service Watch method is not implemented in C bridge" )), diff --git a/core-c-bridge/src/envconfig.rs b/core-c-bridge/src/envconfig.rs new file mode 100644 index 000000000..0e1f2d0de --- /dev/null +++ b/core-c-bridge/src/envconfig.rs @@ -0,0 +1,314 @@ +use crate::{ByteArray, ByteArrayRef}; +use serde::Serialize; +use std::collections::HashMap; +use temporal_sdk_core_api::envconfig::{ + self, ClientConfig as CoreClientConfig, ClientConfigCodec as CoreClientConfigCodec, + ClientConfigProfile as CoreClientConfigProfile, ClientConfigTLS as CoreClientConfigTLS, + DataSource as CoreDataSource, LoadClientConfigOptions, LoadClientConfigProfileOptions, +}; + +/// OrFail result for client config loading operations. +/// Either success or fail will be null, but never both. +/// If success is not null, it contains JSON-serialized client configuration data. +/// If fail is not null, it contains UTF-8 encoded error message. +/// The returned ByteArrays must be freed by the caller. +#[repr(C)] +pub struct ClientEnvConfigOrFail { + pub success: *const ByteArray, + pub fail: *const ByteArray, +} + +/// OrFail result for client config profile loading operations. +/// Either success or fail will be null, but never both. +/// If success is not null, it contains JSON-serialized client configuration profile data. +/// If fail is not null, it contains UTF-8 encoded error message. +/// The returned ByteArrays must be freed by the caller. +#[repr(C)] +pub struct ClientEnvConfigProfileOrFail { + pub success: *const ByteArray, + pub fail: *const ByteArray, +} + +/// Options for loading client configuration. +#[repr(C)] +pub struct ClientEnvConfigLoadOptions { + pub path: ByteArrayRef, + pub data: ByteArrayRef, + pub config_file_strict: bool, + pub env_vars: ByteArrayRef, +} + +/// Options for loading a specific client configuration profile. +#[repr(C)] +pub struct ClientEnvConfigProfileLoadOptions { + pub profile: ByteArrayRef, + pub path: ByteArrayRef, + pub data: ByteArrayRef, + pub disable_file: bool, + pub disable_env: bool, + pub config_file_strict: bool, + pub env_vars: ByteArrayRef, +} + +// Wrapper types for JSON serialization +#[derive(Serialize)] +struct ClientEnvConfig { + profiles: HashMap, +} + +impl From for ClientEnvConfig { + fn from(c: CoreClientConfig) -> Self { + Self { + profiles: c.profiles.into_iter().map(|(k, v)| (k, v.into())).collect(), + } + } +} + +#[derive(Serialize)] +struct ClientEnvConfigProfile { + #[serde(skip_serializing_if = "Option::is_none")] + address: Option, + #[serde(skip_serializing_if = "Option::is_none")] + namespace: Option, + #[serde(skip_serializing_if = "Option::is_none")] + api_key: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + codec: Option, + #[serde(skip_serializing_if = "HashMap::is_empty")] + grpc_meta: HashMap, +} + +impl From for ClientEnvConfigProfile { + fn from(c: CoreClientConfigProfile) -> Self { + Self { + address: c.address, + namespace: c.namespace, + api_key: c.api_key, + tls: c.tls.map(Into::into), + codec: c.codec.map(Into::into), + grpc_meta: c.grpc_meta, + } + } +} + +#[derive(Serialize)] +struct ClientEnvConfigTLS { + #[serde(skip_serializing_if = "Option::is_none")] + disabled: Option, + #[serde(skip_serializing_if = "Option::is_none")] + server_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + server_ca_cert: Option, + #[serde(skip_serializing_if = "Option::is_none")] + client_cert: Option, + #[serde(skip_serializing_if = "Option::is_none")] + client_key: Option, +} + +impl From for ClientEnvConfigTLS { + fn from(c: CoreClientConfigTLS) -> Self { + Self { + disabled: c.disabled, + server_name: c.server_name, + server_ca_cert: c.server_ca_cert.map(Into::into), + client_cert: c.client_cert.map(Into::into), + client_key: c.client_key.map(Into::into), + } + } +} + +#[derive(Serialize)] +struct ClientEnvConfigCodec { + #[serde(skip_serializing_if = "Option::is_none")] + endpoint: Option, + #[serde(skip_serializing_if = "Option::is_none")] + auth: Option, +} + +impl From for ClientEnvConfigCodec { + fn from(c: CoreClientConfigCodec) -> Self { + Self { + endpoint: c.endpoint, + auth: c.auth, + } + } +} + +#[derive(Serialize)] +struct DataSource { + #[serde(skip_serializing_if = "Option::is_none")] + path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + data: Option>, +} + +impl From for DataSource { + fn from(c: CoreDataSource) -> Self { + match c { + CoreDataSource::Path(p) => Self { + path: Some(p), + data: None, + }, + CoreDataSource::Data(d) => Self { + path: None, + data: Some(d), + }, + } + } +} + +// Helper functions +fn parse_config_source( + path: &ByteArrayRef, + data: &ByteArrayRef, +) -> Result, String> { + if !path.data.is_null() && path.size > 0 { + Ok(Some(CoreDataSource::Path(path.to_string()))) + } else if !data.data.is_null() && data.size > 0 { + Ok(Some(CoreDataSource::Data(data.to_vec()))) + } else { + Ok(None) + } +} + +fn parse_env_vars(env_vars: &ByteArrayRef) -> Result>, String> { + if env_vars.data.is_null() || env_vars.size == 0 { + return Ok(None); + } + + let env_json = std::str::from_utf8(env_vars.to_slice()) + .map_err(|e| format!("Invalid env vars UTF-8: {e}"))?; + + serde_json::from_str(env_json) + .map(Some) + .map_err(|e| format!("Invalid env vars JSON: {e}")) +} + +// Simple helper to handle serialization errors consistently +fn serialize_or_error(data: T) -> Result<*const ByteArray, *const ByteArray> { + match serde_json::to_vec(&data) { + Ok(json_bytes) => { + let result = ByteArray::from_vec(json_bytes); + Ok(result.into_raw()) + } + Err(e) => { + let err = ByteArray::from_utf8(format!("Failed to serialize: {e}")); + Err(err.into_raw()) + } + } +} + +/// Load all client profiles from given sources. +/// Returns ClientConfigOrFail with either success JSON or error message. +/// The returned ByteArrays must be freed by the caller. +#[unsafe(no_mangle)] +pub extern "C" fn temporal_core_client_env_config_load( + options: *const ClientEnvConfigLoadOptions, +) -> ClientEnvConfigOrFail { + if options.is_null() { + let err = ByteArray::from_utf8("Options cannot be null".to_string()); + return ClientEnvConfigOrFail { + success: std::ptr::null(), + fail: err.into_raw(), + }; + } + + let result = || -> Result { + let opts = unsafe { &*options }; + let env_vars_map = parse_env_vars(&opts.env_vars)?; + + let load_options = LoadClientConfigOptions { + config_source: parse_config_source(&opts.path, &opts.data)?, + config_file_strict: opts.config_file_strict, + }; + + let core_config = envconfig::load_client_config(load_options, env_vars_map.as_ref()) + .map_err(|e| e.to_string())?; + + Ok(core_config.into()) + }; + + match result() { + Ok(data) => match serialize_or_error(data) { + Ok(success) => ClientEnvConfigOrFail { + success, + fail: std::ptr::null(), + }, + Err(fail) => ClientEnvConfigOrFail { + success: std::ptr::null(), + fail, + }, + }, + Err(e) => { + let err = ByteArray::from_utf8(e); + ClientEnvConfigOrFail { + success: std::ptr::null(), + fail: err.into_raw(), + } + } + } +} + +/// Load a single client profile from given sources with env overrides. +/// Returns ClientConfigProfileOrFail with either success JSON or error message. +/// The returned ByteArrays must be freed by the caller. +#[unsafe(no_mangle)] +pub extern "C" fn temporal_core_client_env_config_profile_load( + options: *const ClientEnvConfigProfileLoadOptions, +) -> ClientEnvConfigProfileOrFail { + if options.is_null() { + let err = ByteArray::from_utf8("Options cannot be null".to_string()); + return ClientEnvConfigProfileOrFail { + success: std::ptr::null(), + fail: err.into_raw(), + }; + } + + let result = || -> Result { + let opts = unsafe { &*options }; + + let profile_name = if !opts.profile.data.is_null() && opts.profile.size > 0 { + Some(opts.profile.to_string()) + } else { + None + }; + + let config_source = parse_config_source(&opts.path, &opts.data)?; + let env_vars_map = parse_env_vars(&opts.env_vars)?; + + let load_options = LoadClientConfigProfileOptions { + config_source, + config_file_profile: profile_name, + config_file_strict: opts.config_file_strict, + disable_file: opts.disable_file, + disable_env: opts.disable_env, + }; + + let profile = envconfig::load_client_config_profile(load_options, env_vars_map.as_ref()) + .map_err(|e| e.to_string())?; + + Ok(profile.into()) + }; + + match result() { + Ok(data) => match serialize_or_error(data) { + Ok(success) => ClientEnvConfigProfileOrFail { + success, + fail: std::ptr::null(), + }, + Err(fail) => ClientEnvConfigProfileOrFail { + success: std::ptr::null(), + fail, + }, + }, + Err(e) => { + let err = ByteArray::from_utf8(e); + ClientEnvConfigProfileOrFail { + success: std::ptr::null(), + fail: err.into_raw(), + } + } + } +} diff --git a/core-c-bridge/src/lib.rs b/core-c-bridge/src/lib.rs index a4427f3b0..643310b37 100644 --- a/core-c-bridge/src/lib.rs +++ b/core-c-bridge/src/lib.rs @@ -9,6 +9,7 @@ )] pub mod client; +pub mod envconfig; pub mod metric; pub mod random; pub mod runtime; diff --git a/core-c-bridge/src/tests/mod.rs b/core-c-bridge/src/tests/mod.rs index 697f2cafc..169847d43 100644 --- a/core-c-bridge/src/tests/mod.rs +++ b/core-c-bridge/src/tests/mod.rs @@ -317,7 +317,7 @@ fn test_simple_callback_override() { })) .unwrap(); let start_resp = StartWorkflowExecutionResponse::decode(&*start_resp_raw).unwrap(); - assert!(start_resp.run_id == "run-id for my-workflow-id"); + assert_eq!(start_resp.run_id, "run-id for my-workflow-id"); // Try a query where a query failure will actually be delivered as failure details. // However, we don't currently have temporal_sdk_core_protos::google::rpc::Status in @@ -336,23 +336,23 @@ fn test_simple_callback_override() { .unwrap_err() .downcast::() .unwrap(); - assert!(query_err.status_code == tonic::Code::Internal as u32); - assert!(query_err.message == "query-fail"); - assert!( + assert_eq!(query_err.status_code, tonic::Code::Internal as u32); + assert_eq!(query_err.message, "query-fail"); + assert_eq!( Failure::decode(query_err.details.as_ref().unwrap().as_slice()) .unwrap() - .message - == "intentional failure" + .message, + "intentional failure" ); // Confirm we got the expected calls - assert!( - *CALLBACK_OVERRIDE_CALLS.lock().unwrap() - == vec![ - "service: temporal.api.workflowservice.v1.WorkflowService, rpc: GetSystemInfo", - "service: temporal.api.workflowservice.v1.WorkflowService, rpc: StartWorkflowExecution", - "service: temporal.api.workflowservice.v1.WorkflowService, rpc: QueryWorkflow" - ] + assert_eq!( + *CALLBACK_OVERRIDE_CALLS.lock().unwrap(), + vec![ + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: GetSystemInfo", + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: StartWorkflowExecution", + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: QueryWorkflow" + ] ); }); } diff --git a/core-c-bridge/src/worker.rs b/core-c-bridge/src/worker.rs index c752fdae0..a68f5696a 100644 --- a/core-c-bridge/src/worker.rs +++ b/core-c-bridge/src/worker.rs @@ -2,9 +2,11 @@ use crate::{ ByteArray, ByteArrayRef, ByteArrayRefArray, UserDataHandle, client::Client, runtime::Runtime, }; use anyhow::{Context, bail}; +use crossbeam_utils::atomic::AtomicCell; use prost::Message; use std::{ collections::{HashMap, HashSet}, + num::NonZero, sync::Arc, time::Duration, }; @@ -28,8 +30,8 @@ use temporal_sdk_core_protos::{ temporal::api::history::v1::History, }; use tokio::sync::{ + Notify, mpsc::{Sender, channel}, - oneshot, }; use tokio_stream::wrappers::ReceiverStream; @@ -165,14 +167,24 @@ struct CustomSlotSupplier { unsafe impl Send for CustomSlotSupplier {} unsafe impl Sync for CustomSlotSupplier {} -pub type CustomReserveSlotCallback = - unsafe extern "C" fn(ctx: *const SlotReserveCtx, sender: *mut libc::c_void); -pub type CustomCancelReserveCallback = unsafe extern "C" fn(token_source: *mut libc::c_void); -/// Must return C#-tracked id for the permit. A zero value means no permit was reserved. -pub type CustomTryReserveSlotCallback = unsafe extern "C" fn(ctx: *const SlotReserveCtx) -> usize; -pub type CustomMarkSlotUsedCallback = unsafe extern "C" fn(ctx: *const SlotMarkUsedCtx); -pub type CustomReleaseSlotCallback = unsafe extern "C" fn(ctx: *const SlotReleaseCtx); -pub type CustomSlotImplFreeCallback = +pub type CustomSlotSupplierReserveCallback = unsafe extern "C" fn( + ctx: *const SlotReserveCtx, + completion_ctx: *const SlotReserveCompletionCtx, + user_data: *mut libc::c_void, +); +pub type CustomSlotSupplierCancelReserveCallback = unsafe extern "C" fn( + completion_ctx: *const SlotReserveCompletionCtx, + user_data: *mut libc::c_void, +); +pub type CustomSlotSupplierTryReserveCallback = + unsafe extern "C" fn(ctx: *const SlotReserveCtx, user_data: *mut libc::c_void) -> usize; +pub type CustomSlotSupplierMarkUsedCallback = + unsafe extern "C" fn(ctx: *const SlotMarkUsedCtx, user_data: *mut libc::c_void); +pub type CustomSlotSupplierReleaseCallback = + unsafe extern "C" fn(ctx: *const SlotReleaseCtx, user_data: *mut libc::c_void); +pub type CustomSlotSupplierAvailableSlotsCallback = + Option bool>; +pub type CustomSlotSupplierFreeCallback = unsafe extern "C" fn(userimpl: *const CustomSlotSupplierCallbacks); #[repr(C)] @@ -181,12 +193,46 @@ pub struct CustomSlotSupplierCallbacksImpl(pub *const CustomSlotSupplierCallback #[repr(C)] pub struct CustomSlotSupplierCallbacks { - pub reserve: CustomReserveSlotCallback, - pub cancel_reserve: CustomCancelReserveCallback, - pub try_reserve: CustomTryReserveSlotCallback, - pub mark_used: CustomMarkSlotUsedCallback, - pub release: CustomReleaseSlotCallback, - pub free: CustomSlotImplFreeCallback, + /// Called to initiate asynchronous slot reservation. `ctx` contains information about + /// reservation request. The pointer is only valid for the duration of the function call; the + /// implementation should copy the data out of it for later use, and return as soon as possible. + /// + /// When slot is reserved, the implementation should call [`temporal_core_complete_async_reserve`] + /// with the same `completion_ctx` as passed to this function. Reservation cannot be cancelled + /// by Lang, but it can be cancelled by Core through [`cancel_reserve`](Self::cancel_reserve) + /// callback. If reservation was cancelled, [`temporal_core_complete_async_cancel_reserve`] + /// should be called instead. + /// + /// Slot reservation cannot error. The implementation should recover from errors and keep trying + /// to reserve a slot until it eventually succeeds, or until reservation is cancelled by Core. + pub reserve: CustomSlotSupplierReserveCallback, + /// Called to cancel slot reservation. `completion_ctx` specifies which reservation is being + /// cancelled; the matching [`reserve`](Self::reserve) call was made with the same `completion_ctx`. + /// After cancellation, the implementation should call [`temporal_core_complete_async_cancel_reserve`] + /// with the same `completion_ctx`. Calling [`temporal_core_complete_async_reserve`] is not + /// needed after cancellation. + pub cancel_reserve: CustomSlotSupplierCancelReserveCallback, + /// Called to try an immediate slot reservation. The callback should return 0 if immediate + /// reservation is not currently possible, or permit ID if reservation was successful. Permit ID + /// is arbitrary, but must be unique among live reservations as it's later used for [`mark_used`](Self::mark_used) + /// and [`release`](Self::release) callbacks. + pub try_reserve: CustomSlotSupplierTryReserveCallback, + /// Called after successful reservation to mark slot as used. See [`SlotSupplier`](temporal_sdk_core_api::worker::SlotSupplier) + /// trait for details. + pub mark_used: CustomSlotSupplierMarkUsedCallback, + /// Called to free a previously reserved slot. + pub release: CustomSlotSupplierReleaseCallback, + /// Called to retrieve the number of available slots if known. If the implementation knows how + /// many slots are available at the moment, it should set the value behind the `available_slots` + /// pointer and return true. If that number is unknown, it should return false. + /// + /// This function pointer can be set to null. It will be treated as if the number of available + /// slots is never known. + pub available_slots: CustomSlotSupplierAvailableSlotsCallback, + /// Called when the slot supplier is being dropped. All resources should be freed. + pub free: CustomSlotSupplierFreeCallback, + /// Passed as an extra argument to the callbacks. + pub user_data: *mut libc::c_void, } impl CustomSlotSupplierCallbacksImpl { @@ -223,8 +269,6 @@ pub struct SlotReserveCtx { pub worker_identity: ByteArrayRef, pub worker_build_id: ByteArrayRef, pub is_sticky: bool, - // The C# side will store a pointer here to the cancellation token source - pub token_src: *mut libc::c_void, } unsafe impl Send for SlotReserveCtx {} @@ -249,31 +293,67 @@ pub enum SlotInfo { #[repr(C)] pub struct SlotMarkUsedCtx { pub slot_info: SlotInfo, - /// C# id for the slot permit. + /// Lang-issued permit ID. pub slot_permit: usize, } #[repr(C)] pub struct SlotReleaseCtx { pub slot_info: *const SlotInfo, - /// C# id for the slot permit. + /// Lang-issued permit ID. pub slot_permit: usize, } -struct CancelReserveGuard { - token_src: *mut libc::c_void, - callback: CustomCancelReserveCallback, +pub struct SlotReserveCompletionCtx { + state: AtomicCell, + notify: Notify, } -impl Drop for CancelReserveGuard { + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SlotReserveOperationState { + Pending, + Cancelled, + Completed(NonZero), +} + +struct CancelReserveGuard<'a, SK: SlotKind + Send + Sync> { + slot_supplier: &'a CustomSlotSupplier, + completion_ctx: Arc, + completed: bool, +} + +impl<'a, SK: SlotKind + Send + Sync> Drop for CancelReserveGuard<'a, SK> { fn drop(&mut self) { - if !self.token_src.is_null() { + // do not cancel if already completed + if !self.completed { + let state = self + .completion_ctx + .state + .swap(SlotReserveOperationState::Cancelled); unsafe { - (self.callback)(self.token_src); + let inner = &*self.slot_supplier.inner.0; + match state { + SlotReserveOperationState::Cancelled => { + // This situation should never happen, but on the other hand, it doesn't + // result in any unsafety, deadlock or leak. It's safe to ignore it, but in + // debug builds we'd like to know it happened. + debug_assert!(false, "slot reservation cancelled twice") + } + SlotReserveOperationState::Pending => { + (inner.cancel_reserve)(Arc::as_ptr(&self.completion_ctx), inner.user_data) + } + SlotReserveOperationState::Completed(slot_permit) => (inner.release)( + &SlotReleaseCtx { + slot_info: std::ptr::null(), + slot_permit: slot_permit.into(), + }, + inner.user_data, + ), + } } } } } -unsafe impl Send for CancelReserveGuard {} #[async_trait::async_trait] impl temporal_sdk_core_api::worker::SlotSupplier @@ -282,36 +362,53 @@ impl temporal_sdk_core_api::worker::SlotSupplier type SlotKind = SK; async fn reserve_slot(&self, ctx: &dyn SlotReservationContext) -> SlotSupplierPermit { - let (tx, rx) = oneshot::channel(); let ctx = Self::convert_reserve_ctx(ctx); - let tx = Box::into_raw(Box::new(tx)) as *mut libc::c_void; + let completion_ctx = Arc::new(SlotReserveCompletionCtx { + state: AtomicCell::new(SlotReserveOperationState::Pending), + notify: Notify::new(), + }); unsafe { - let _drop_guard = CancelReserveGuard { - token_src: ctx.token_src, - callback: (*self.inner.0).cancel_reserve, - }; - ((*self.inner.0).reserve)(&ctx, tx); - rx.await.expect("reserve channel is not closed") + let inner = &*self.inner.0; + (inner.reserve)(&ctx, Arc::into_raw(completion_ctx.clone()), inner.user_data); + } + let mut guard = CancelReserveGuard { + slot_supplier: self, + completion_ctx, + completed: false, + }; + // if the future is dropped before this await resolves, the guard is dropped which triggers cancellation + guard.completion_ctx.notify.notified().await; + guard.completed = true; + match guard.completion_ctx.state.load() { + SlotReserveOperationState::Completed(permit_id) => { + SlotSupplierPermit::with_user_data::(permit_id.get()) + } + other => panic!("Unexpected slot reservation state: expected Completed, got {other:?}"), } } fn try_reserve_slot(&self, ctx: &dyn SlotReservationContext) -> Option { let ctx = Self::convert_reserve_ctx(ctx); - let permit_id = unsafe { ((*self.inner.0).try_reserve)(&ctx) }; + let permit_id = unsafe { ((*self.inner.0).try_reserve)(&ctx, (*self.inner.0).user_data) }; if permit_id == 0 { None } else { - Some(SlotSupplierPermit::with_user_data(permit_id)) + Some(SlotSupplierPermit::with_user_data::(permit_id)) } } fn mark_slot_used(&self, ctx: &dyn SlotMarkUsedContext) { let ctx = SlotMarkUsedCtx { slot_info: Self::convert_slot_info(ctx.info().downcast()), - slot_permit: ctx.permit().user_data::().copied().unwrap_or(0), + slot_permit: ctx + .permit() + .user_data::() + .copied() + .expect("permit user data should be usize"), }; unsafe { - ((*self.inner.0).mark_used)(&ctx); + let inner = &*self.inner.0; + (inner.mark_used)(&ctx, inner.user_data); } } @@ -323,15 +420,26 @@ impl temporal_sdk_core_api::worker::SlotSupplier } let ctx = SlotReleaseCtx { slot_info: info_ptr, - slot_permit: ctx.permit().user_data::().copied().unwrap_or(0), + slot_permit: ctx + .permit() + .user_data::() + .copied() + .expect("permit user data should be usize"), }; unsafe { - ((*self.inner.0).release)(&ctx); + let inner = &*self.inner.0; + (inner.release)(&ctx, inner.user_data); } } fn available_slots(&self) -> Option { - None + unsafe { + let inner = &*self.inner.0; + inner.available_slots.and_then(|f| { + let mut available_slots = 0; + f(&mut available_slots, inner.user_data).then_some(available_slots) + }) + } } } @@ -360,7 +468,6 @@ impl CustomSlotSupplier { ByteArrayRef::empty() }, is_sticky: ctx.is_sticky(), - token_src: std::ptr::null_mut(), } } @@ -948,29 +1055,88 @@ pub extern "C" fn temporal_core_worker_replay_push( } } +/// Completes asynchronous slot reservation started by a call to [`CustomSlotSupplierCallbacks::reserve`]. +/// +/// `completion_ctx` must be the same as the one passed to the matching [`reserve`](CustomSlotSupplierCallbacks::reserve) +/// call. `permit_id` is arbitrary, but must be unique among live reservations as it's later used +/// for [`mark_used`](CustomSlotSupplierCallbacks::mark_used) and [`release`](CustomSlotSupplierCallbacks::release) +/// callbacks. +/// +/// This function returns true if the reservation was completed successfully, or false if the +/// reservation was cancelled before completion. If this function returns false, the implementation +/// should call [`temporal_core_complete_async_cancel_reserve`] with the same `completion_ctx`. +/// +/// **Caution:** if this function returns true, `completion_ctx` gets freed. Afterwards, calling +/// either [`temporal_core_complete_async_reserve`] or [`temporal_core_complete_async_cancel_reserve`] +/// with the same `completion_ctx` will cause **memory corruption!** #[unsafe(no_mangle)] pub extern "C" fn temporal_core_complete_async_reserve( - sender: *mut libc::c_void, + completion_ctx: *const SlotReserveCompletionCtx, permit_id: usize, -) { - if !sender.is_null() { - unsafe { - let sender = Box::from_raw(sender as *mut oneshot::Sender); - let permit = SlotSupplierPermit::with_user_data(permit_id); - let _ = sender.send(permit); +) -> bool { + if completion_ctx.is_null() { + panic!("completion_ctx is null"); + } + let permit_id = + NonZero::new(permit_id).expect("permit_id cannot be 0 on successful reservation"); + let prev_state = unsafe { + // Not turning completion_ctx into Arc yet as we only want to deallocate it on success + (*completion_ctx).state.compare_exchange( + SlotReserveOperationState::Pending, + SlotReserveOperationState::Completed(permit_id), + ) + }; + match prev_state { + Ok(_) => { + let completion_ctx = unsafe { Arc::from_raw(completion_ctx) }; + completion_ctx.notify.notify_one(); + true } - } else { - panic!("ReserveSlot sender must not be null!"); + Err(SlotReserveOperationState::Cancelled) => false, + Err(SlotReserveOperationState::Completed(prev_permit_id)) => { + panic!( + "temporal_core_complete_async_reserve called twice for the same reservation - first permit ID {prev_permit_id}, second permit ID {permit_id}" + ) + } + Err(SlotReserveOperationState::Pending) => unreachable!(), } } +/// Completes cancellation of asynchronous slot reservation. +/// +/// Cancellation can only be initiated by Core. It's done by calling [`CustomSlotSupplierCallbacks::cancel_reserve`] +/// after an earlier call to [`CustomSlotSupplierCallbacks::reserve`]. +/// +/// `completion_ctx` must be the same as the one passed to the matching [`cancel_reserve`](CustomSlotSupplierCallbacks::cancel_reserve) +/// call. +/// +/// This function returns true on successful cancellation, or false if cancellation was not +/// requested for the given `completion_ctx`. A false value indicates there's likely a logic bug in +/// the implementation where it doesn't correctly wait for [`cancel_reserve`](CustomSlotSupplierCallbacks::cancel_reserve) +/// callback to be called. +/// +/// **Caution:** if this function returns true, `completion_ctx` gets freed. Afterwards, calling +/// either [`temporal_core_complete_async_reserve`] or [`temporal_core_complete_async_cancel_reserve`] +/// with the same `completion_ctx` will cause **memory corruption!** #[unsafe(no_mangle)] -pub extern "C" fn temporal_core_set_reserve_cancel_target( - ctx: *mut SlotReserveCtx, - token_ptr: *mut libc::c_void, -) { - if let Some(ctx) = unsafe { ctx.as_mut() } { - ctx.token_src = token_ptr; +pub extern "C" fn temporal_core_complete_async_cancel_reserve( + completion_ctx: *const SlotReserveCompletionCtx, +) -> bool { + if completion_ctx.is_null() { + panic!("completion_ctx is null"); + } + let state = unsafe { (*completion_ctx).state.load() }; + match state { + SlotReserveOperationState::Cancelled => { + drop(unsafe { Arc::from_raw(completion_ctx) }); + true + } + SlotReserveOperationState::Pending => false, + SlotReserveOperationState::Completed(permit_id) => { + panic!( + "temporal_core_complete_async_cancel_reserve called on completed reservation - permit ID {permit_id}" + ) + } } } diff --git a/core/Cargo.toml b/core/Cargo.toml index 1a0afea75..c9ceccda7 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -64,12 +64,12 @@ itertools = "0.14" lru = "0.16" mockall = "0.13" opentelemetry = { workspace = true, features = ["metrics"], optional = true } -opentelemetry_sdk = { version = "0.30", features = [ +opentelemetry_sdk = { version = "0.31", features = [ "rt-tokio", "metrics", "spec_unstable_metrics_views", ], optional = true } -opentelemetry-otlp = { version = "0.30", features = [ +opentelemetry-otlp = { version = "0.31", features = [ "tokio", "metrics", "tls", @@ -81,7 +81,7 @@ pid = "4.0" pin-project = "1.1" prometheus = { version = "0.14", optional = true } prost = { workspace = true } -prost-types = { version = "0.6", package = "prost-wkt-types" } +prost-types = { workspace = true } rand = "0.9" reqwest = { version = "0.12", features = [ "json", @@ -116,7 +116,8 @@ tracing-subscriber = { version = "0.3", default-features = false, features = [ ] } url = "2.5" uuid = { version = "1.18", features = ["v4"] } -zip = { version = "4.6", optional = true } +# Only need specific features to decompress zip files for ephemeral server download +zip = { version = "4.6", optional = true, default-features = false, features = ["deflate", "bzip2", "zstd"] } # 1st party local deps [dependencies.temporal-sdk-core-api] diff --git a/core/src/core_tests/updates.rs b/core/src/core_tests/updates.rs index 9beedc46f..9f6bdb954 100644 --- a/core/src/core_tests/updates.rs +++ b/core/src/core_tests/updates.rs @@ -114,10 +114,10 @@ async fn initial_request_sent_back(#[values(false, true)] reject: bool) { .returning(move |mut resp| { let msg = resp.messages.pop().unwrap(); let orig_req = if reject { - let acceptance = msg.body.unwrap().unpack_as(Rejection::default()).unwrap(); + let acceptance = msg.body.unwrap().to_msg::().unwrap(); acceptance.rejected_request.unwrap() } else { - let acceptance = msg.body.unwrap().unpack_as(Acceptance::default()).unwrap(); + let acceptance = msg.body.unwrap().to_msg::().unwrap(); acceptance.accepted_request.unwrap() }; assert_eq!(orig_req, upd_req_body); diff --git a/core/src/internal_flags.rs b/core/src/internal_flags.rs index 4a676da7a..df2fa5afb 100644 --- a/core/src/internal_flags.rs +++ b/core/src/internal_flags.rs @@ -18,6 +18,7 @@ use temporal_sdk_core_protos::temporal::api::{ /// may be removed from the enum. *Importantly*, all variants must be given explicit values, such /// that removing older variants does not create any change in existing values. Removed flag /// variants must be reserved forever (a-la protobuf), and should be called out in a comment. +#[allow(unreachable_pub)] // re-exported in test_help::integ_helpers #[repr(u32)] #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Copy, Clone, Debug, enum_iterator::Sequence)] pub enum CoreInternalFlags { diff --git a/core/src/lib.rs b/core/src/lib.rs index a73eb43ad..e306d5f15 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -62,7 +62,7 @@ use anyhow::bail; use futures_util::Stream; use std::sync::Arc; use std::time::Duration; -use temporal_client::{ConfiguredClient, NamespacedClient, TemporalServiceClientWithMetrics}; +use temporal_client::{ConfiguredClient, NamespacedClient, SharedReplaceableClient}; use temporal_sdk_core_api::{ Worker as WorkerTrait, errors::{CompleteActivityError, PollError}, @@ -89,20 +89,20 @@ pub fn init_worker( where CT: Into, { - let client_inner = *client.into().into_inner(); - let client = init_worker_client( - worker_config.namespace.clone(), - worker_config.client_identity_override.clone(), - client_inner, - ); let namespace = worker_config.namespace.clone(); - if client.namespace() != namespace { - bail!("Passed in client is not bound to the same namespace as the worker"); - } - if client.namespace() == "" { - bail!("Client namespace cannot be empty"); + if namespace.is_empty() { + bail!("Worker namespace cannot be empty"); } - let client_ident = client.get_identity().to_owned(); + + let client = RetryClient::new( + SharedReplaceableClient::new(init_worker_client( + worker_config.namespace.clone(), + worker_config.client_identity_override.clone(), + client, + )), + RetryConfig::default(), + ); + let client_ident = client.identity(); let sticky_q = sticky_q_name_for_worker(&client_ident, worker_config.max_cached_workflows); if client_ident.is_empty() { @@ -142,16 +142,19 @@ where rwi.into_core_worker() } -pub(crate) fn init_worker_client( +pub(crate) fn init_worker_client( namespace: String, client_identity_override: Option, - client: ConfiguredClient, -) -> RetryClient { - let mut client = Client::new(client, namespace); + client: CT, +) -> Client +where + CT: Into, +{ + let mut client = Client::new(*client.into().into_inner(), namespace.clone()); if let Some(ref id_override) = client_identity_override { client.options_mut().identity.clone_from(id_override); } - RetryClient::new(client, RetryConfig::default()) + client } /// Creates a unique sticky queue name for a worker, iff the config allows for 1 or more cached @@ -173,44 +176,57 @@ pub(crate) fn sticky_q_name_for_worker( mod sealed { use super::*; + use temporal_client::{SharedReplaceableClient, TemporalServiceClient}; /// Allows passing different kinds of clients into things that want to be flexible. Motivating /// use-case was worker initialization. /// /// Needs to exist in this crate to avoid blanket impl conflicts. pub struct AnyClient { - pub(crate) inner: Box>, + pub(crate) inner: Box>, } impl AnyClient { - pub(crate) fn into_inner(self) -> Box> { + pub(crate) fn into_inner(self) -> Box> { self.inner } } - impl From>> for AnyClient { - fn from(c: RetryClient>) -> Self { - Self { - inner: Box::new(c.into_inner()), - } + impl From> for AnyClient { + fn from(c: ConfiguredClient) -> Self { + Self { inner: Box::new(c) } + } + } + + impl From for AnyClient { + fn from(c: Client) -> Self { + c.into_inner().into() } } - impl From> for AnyClient { - fn from(c: RetryClient) -> Self { - Self { - inner: Box::new(c.into_inner().into_inner()), - } + + impl From> for AnyClient + where + T: Into, + { + fn from(c: RetryClient) -> Self { + c.into_inner().into() } } - impl From>> for AnyClient { - fn from(c: Arc>) -> Self { - Self { - inner: Box::new(c.get_client().inner().clone()), - } + + impl From> for AnyClient + where + T: Into + Clone + Send + Sync, + { + fn from(c: SharedReplaceableClient) -> Self { + c.inner_clone().into() } } - impl From> for AnyClient { - fn from(c: ConfiguredClient) -> Self { - Self { inner: Box::new(c) } + + impl From> for AnyClient + where + T: Into + Clone, + { + fn from(c: Arc) -> Self { + Arc::unwrap_or_clone(c).into() } } } diff --git a/core/src/protosext/mod.rs b/core/src/protosext/mod.rs index 1a1119033..29250a6da 100644 --- a/core/src/protosext/mod.rs +++ b/core/src/protosext/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod protocol_messages; use crate::{ CompleteActivityError, TaskToken, protosext::protocol_messages::IncomingProtocolMessage, + retry_logic::ValidatedRetryPolicy, worker::{LEGACY_QUERY_ID, LocalActivityExecutionResult}, }; use anyhow::anyhow; @@ -32,7 +33,7 @@ use temporal_sdk_core_protos::{ workflow_completion, }, temporal::api::{ - common::v1::{Payload, RetryPolicy, WorkflowExecution}, + common::v1::{Payload, WorkflowExecution}, enums::v1::EventType, failure::v1::Failure, history::v1::{History, HistoryEvent, MarkerRecordedEventAttributes, history_event}, @@ -135,7 +136,7 @@ impl TryFrom for ValidPollWFTQResponse { _cant_construct_me: (), }) } - _ => Err(anyhow!("Unable to interpret poll response: {value:?}",)), + _ => Err(anyhow!("Unable to interpret poll response: {value:?}")), } } } @@ -318,7 +319,7 @@ pub(crate) struct ValidScheduleLA { pub(crate) arguments: Vec, pub(crate) schedule_to_start_timeout: Option, pub(crate) close_timeouts: LACloseTimeouts, - pub(crate) retry_policy: RetryPolicy, + pub(crate) retry_policy: ValidatedRetryPolicy, pub(crate) local_retry_threshold: Duration, pub(crate) cancellation_type: ActivityCancellationType, pub(crate) user_metadata: Option, @@ -408,7 +409,8 @@ impl ValidScheduleLA { )); } }; - let retry_policy = v.retry_policy.unwrap_or_default(); + let retry_policy = + ValidatedRetryPolicy::from_proto_with_defaults(v.retry_policy.unwrap_or_default()); let local_retry_threshold = v .local_retry_threshold .try_into_or_none() diff --git a/core/src/protosext/protocol_messages.rs b/core/src/protosext/protocol_messages.rs index af2aab47d..e267193db 100644 --- a/core/src/protosext/protocol_messages.rs +++ b/core/src/protosext/protocol_messages.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, bail}; +use anyhow::anyhow; use std::collections::HashMap; use temporal_sdk_core_protos::temporal::api::{ common::v1::Payload, @@ -108,16 +108,9 @@ impl TryFrom> for IncomingProtocolMessageBody { fn try_from(v: Option) -> Result { let v = v.ok_or_else(|| anyhow!("Protocol message body must be populated"))?; - // Undo explicit type url checks when https://github.com/fdeantoni/prost-wkt/issues/48 is - // fixed - Ok(match v.type_url.as_str() { - "type.googleapis.com/temporal.api.update.v1.Request" => { - IncomingProtocolMessageBody::UpdateRequest( - v.unpack_as(update::v1::Request::default())?.try_into()?, - ) - } - o => bail!("Could not understand protocol message type {o}"), - }) + Ok(IncomingProtocolMessageBody::UpdateRequest( + v.to_msg::()?.try_into()?, + )) } } diff --git a/core/src/retry_logic.rs b/core/src/retry_logic.rs index c8f025c40..3a4140cc6 100644 --- a/core/src/retry_logic.rs +++ b/core/src/retry_logic.rs @@ -1,45 +1,87 @@ -use std::time::Duration; -use temporal_sdk_core_protos::{ - temporal::api::{common::v1::RetryPolicy, failure::v1::ApplicationFailureInfo}, - utilities::TryIntoOrNone, +use std::{num::NonZero, time::Duration}; +use temporal_sdk_core_protos::temporal::api::{ + common::v1::RetryPolicy, failure::v1::ApplicationFailureInfo, }; -pub(crate) trait RetryPolicyExt { +/// Represents a retry policy where all fields have valid values. Durations are stored in std type. +/// Upholds the following invariants: +/// - `maximum_interval` >= `initial_interval` +/// - `backoff_coefficient` >= 1 +/// - `maximum_attempts` >= 0 +#[derive(Debug, Clone)] +pub(crate) struct ValidatedRetryPolicy { + initial_interval: Duration, + backoff_coefficient: f64, + maximum_interval: Duration, + maximum_attempts: u32, + non_retryable_error_types: Vec, +} + +impl ValidatedRetryPolicy { + /// Validates and converts retry policy. If some field is invalid, it's replaced with a default value: + /// - `initial_interval`: 1 second + /// - `backoff_coefficient`: 2.0 + /// - `maximum_interval`: 100 * `initial_interval` if missing or inconvertible, 1 * `initial_interval` if too small + /// - `maximum_attempts`: 0 (unlimited) + pub(crate) fn from_proto_with_defaults(retry_policy: RetryPolicy) -> Self { + let initial_interval = retry_policy + .initial_interval + .and_then(|i| i.try_into().ok()) + .unwrap_or_else(|| Duration::from_secs(1)); + + let backoff_coefficient = if retry_policy.backoff_coefficient >= 1.0 { + retry_policy.backoff_coefficient + } else { + 2.0 + }; + + let maximum_interval = if let Some(maximum_interval) = retry_policy + .maximum_interval + .and_then(|i| Duration::try_from(i).ok()) + { + maximum_interval.max(initial_interval) + } else { + let maximum_interval = initial_interval.saturating_mul(100); + // Verifying that serialization to proto will work. It may fail for extremely large + // durations, so in that case we fall back to maximum_interval = initial_interval. + if prost_types::Duration::try_from(maximum_interval).is_ok() { + maximum_interval + } else { + initial_interval + } + }; + + Self { + initial_interval, + backoff_coefficient, + maximum_interval, + maximum_attempts: retry_policy.maximum_attempts.try_into().unwrap_or(0), + non_retryable_error_types: retry_policy.non_retryable_error_types, + } + } + /// Ask this retry policy if a retry should be performed. Caller provides the current attempt /// number - the first attempt should start at 1. /// - /// Returns `None` if it should not, otherwise a duration indicating how long to wait before - /// performing the retry. - /// - /// Applies defaults to missing fields: - /// `initial_interval` - 1 second - /// `maximum_interval` - 100 x initial_interval - /// `backoff_coefficient` - 2.0 - fn should_retry( - &self, - attempt_number: usize, - application_failure: Option<&ApplicationFailureInfo>, - ) -> Option; -} - -impl RetryPolicyExt for RetryPolicy { - fn should_retry( + /// Returns `None` if it should not retry, otherwise returns a duration indicating how long to + /// wait before performing the retry. + pub(crate) fn should_retry( &self, - attempt_number: usize, + attempt_number: NonZero, application_failure: Option<&ApplicationFailureInfo>, ) -> Option { + if self.maximum_attempts > 0 && attempt_number.get() >= self.maximum_attempts { + return None; + } + let non_retryable = application_failure .map(|f| f.non_retryable) .unwrap_or_default(); if non_retryable { return None; } - let err_type_str = application_failure.map_or("", |f| &f.r#type); - let realmax = self.maximum_attempts.max(0); - if realmax > 0 && attempt_number >= realmax as usize { - return None; - } + let err_type_str = application_failure.map_or("", |f| &f.r#type); for pat in &self.non_retryable_error_types { if err_type_str.to_lowercase() == pat.to_lowercase() { return None; @@ -47,46 +89,48 @@ impl RetryPolicyExt for RetryPolicy { } if let Some(explicit_delay) = application_failure.and_then(|af| af.next_retry_delay) { - return explicit_delay.try_into().ok(); + match explicit_delay.try_into() { + Ok(delay) => return Some(delay), + Err(e) => error!( + "Failed to convert retry delay of application failure. Normal delay calculation will be used. Conversion error: `{}`. Application failure: {:?}", + e, application_failure + ), + } } - let converted_interval = self - .initial_interval - .try_into_or_none() - .or(Some(Duration::from_secs(1))); - if attempt_number == 1 { - return converted_interval; + if attempt_number.get() == 1 { + return Some(self.initial_interval); } - let coeff = if self.backoff_coefficient != 0. { - self.backoff_coefficient - } else { - 2.0 - }; - if let Some(interval) = converted_interval { - let max_iv = self - .maximum_interval - .try_into_or_none() - .unwrap_or_else(|| interval.saturating_mul(100)); - let mul_factor = coeff.powi(attempt_number as i32 - 1); - let tried_mul = try_from_secs_f64(mul_factor * interval.as_secs_f64()); - Some(tried_mul.unwrap_or(max_iv).min(max_iv)) - } else { - // No retries if initial interval is not specified - None - } + let delay = i32::try_from(attempt_number.get()) + .ok() + .and_then(|attempt| { + let factor = self.backoff_coefficient.powi(attempt - 1); + Duration::try_from_secs_f64(factor * self.initial_interval.as_secs_f64()).ok() + }) + .map(|interval| interval.min(self.maximum_interval)) + .unwrap_or(self.maximum_interval); + + Some(delay) } } -const NANOS_PER_SEC: u32 = 1_000_000_000; -/// modified from rust stdlib since this feature is currently nightly only -fn try_from_secs_f64(secs: f64) -> Option { - const MAX_NANOS_F64: f64 = ((u64::MAX as u128 + 1) * (NANOS_PER_SEC as u128)) as f64; - let nanos = secs * (NANOS_PER_SEC as f64); - if !nanos.is_finite() || !(0.0..MAX_NANOS_F64).contains(&nanos) { - None - } else { - Some(Duration::from_secs_f64(secs)) +impl Default for ValidatedRetryPolicy { + fn default() -> Self { + Self::from_proto_with_defaults(RetryPolicy::default()) + } +} + +impl From for RetryPolicy { + fn from(value: ValidatedRetryPolicy) -> Self { + // All fields were tested on struct initialization to convert successfully. Unwraps are safe. + Self { + initial_interval: Some(value.initial_interval.try_into().unwrap()), + backoff_coefficient: value.backoff_coefficient, + maximum_interval: Some(value.maximum_interval.try_into().unwrap()), + maximum_attempts: value.maximum_attempts.try_into().unwrap(), + non_retryable_error_types: value.non_retryable_error_types, + } } } @@ -94,84 +138,186 @@ fn try_from_secs_f64(secs: f64) -> Option { mod tests { use super::*; use crate::prost_dur; + use std::{num::NonZero, time::Duration}; + + macro_rules! nz { + ($x:expr) => { + NonZero::new($x).unwrap() + }; + } + + #[test] + fn applies_defaults_to_default_retry_policy() { + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy::default()); + assert_eq!(rp.initial_interval, Duration::from_secs(1)); + assert_eq!(rp.backoff_coefficient, 2.0); + assert_eq!(rp.maximum_interval, Duration::from_secs(100)); + assert_eq!(rp.maximum_attempts, 0); + assert!(rp.non_retryable_error_types.is_empty()); + + let rp = ValidatedRetryPolicy::default(); + assert_eq!(rp.initial_interval, Duration::from_secs(1)); + assert_eq!(rp.backoff_coefficient, 2.0); + assert_eq!(rp.maximum_interval, Duration::from_secs(100)); + assert_eq!(rp.maximum_attempts, 0); + assert!(rp.non_retryable_error_types.is_empty()); + } + + #[test] + fn applies_defaults_to_invalid_fields_only() { + let base_rp = RetryPolicy { + initial_interval: Some(prost_dur!(from_secs(2))), + backoff_coefficient: 1.5, + maximum_interval: Some(prost_dur!(from_secs(4))), + maximum_attempts: 2, + non_retryable_error_types: vec!["error".into()], + }; + let base_values = ValidatedRetryPolicy::from_proto_with_defaults(base_rp.clone()); + assert_eq!(base_values.initial_interval, Duration::from_secs(2)); + assert_eq!(base_values.backoff_coefficient, 1.5); + assert_eq!(base_values.maximum_interval, Duration::from_secs(4)); + assert_eq!(base_values.maximum_attempts, 2); + assert_eq!( + base_values.non_retryable_error_types, + vec!["error".to_owned()] + ); + + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { + initial_interval: Some(prost_types::Duration { + seconds: -5, + nanos: 0, + }), + ..base_rp.clone() + }); + assert_eq!(rp.initial_interval, Duration::from_secs(1)); + assert_eq!(rp.backoff_coefficient, base_values.backoff_coefficient); + assert_eq!(rp.maximum_interval, base_values.maximum_interval); + assert_eq!(rp.maximum_attempts, base_values.maximum_attempts); + assert_eq!( + rp.non_retryable_error_types, + base_values.non_retryable_error_types + ); + + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { + backoff_coefficient: 0.5, + ..base_rp.clone() + }); + assert_eq!(rp.initial_interval, base_values.initial_interval); + assert_eq!(rp.backoff_coefficient, 2.0); + assert_eq!(rp.maximum_interval, base_values.maximum_interval); + assert_eq!(rp.maximum_attempts, base_values.maximum_attempts); + assert_eq!( + rp.non_retryable_error_types, + base_values.non_retryable_error_types + ); + + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { + maximum_interval: Some(prost_types::Duration { + seconds: -5, + nanos: 0, + }), + ..base_rp.clone() + }); + assert_eq!(rp.initial_interval, base_values.initial_interval); + assert_eq!(rp.backoff_coefficient, base_values.backoff_coefficient); + assert_eq!(rp.maximum_interval, 100 * base_values.initial_interval); + assert_eq!(rp.maximum_attempts, base_values.maximum_attempts); + assert_eq!( + rp.non_retryable_error_types, + base_values.non_retryable_error_types + ); + + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { + maximum_interval: Some(prost_dur!(from_secs(1))), // valid but less than initial interval + ..base_rp.clone() + }); + assert_eq!(rp.initial_interval, base_values.initial_interval); + assert_eq!(rp.backoff_coefficient, base_values.backoff_coefficient); + assert_eq!(rp.maximum_interval, base_values.initial_interval); + assert_eq!(rp.maximum_attempts, base_values.maximum_attempts); + assert_eq!( + rp.non_retryable_error_types, + base_values.non_retryable_error_types + ); + + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { + maximum_attempts: -5, + ..base_rp.clone() + }); + assert_eq!(rp.initial_interval, base_values.initial_interval); + assert_eq!(rp.backoff_coefficient, base_values.backoff_coefficient); + assert_eq!(rp.maximum_interval, base_values.maximum_interval); + assert_eq!(rp.maximum_attempts, 0); + assert_eq!( + rp.non_retryable_error_types, + base_values.non_retryable_error_types + ); + + // non_retryable_error_types is always valid + } #[test] fn calcs_backoffs_properly() { - let rp = RetryPolicy { + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(1))), backoff_coefficient: 2.0, maximum_interval: Some(prost_dur!(from_secs(10))), maximum_attempts: 10, non_retryable_error_types: vec![], - }; - let res = rp.should_retry(1, None).unwrap(); - assert_eq!(res.as_millis(), 1_000); - let res = rp.should_retry(2, None).unwrap(); - assert_eq!(res.as_millis(), 2_000); - let res = rp.should_retry(3, None).unwrap(); - assert_eq!(res.as_millis(), 4_000); - let res = rp.should_retry(4, None).unwrap(); - assert_eq!(res.as_millis(), 8_000); - let res = rp.should_retry(5, None).unwrap(); - assert_eq!(res.as_millis(), 10_000); - let res = rp.should_retry(6, None).unwrap(); - assert_eq!(res.as_millis(), 10_000); + }); + assert_eq!(rp.should_retry(nz!(1), None), Some(Duration::from_secs(1))); + assert_eq!(rp.should_retry(nz!(2), None), Some(Duration::from_secs(2))); + assert_eq!(rp.should_retry(nz!(3), None), Some(Duration::from_secs(4))); + assert_eq!(rp.should_retry(nz!(4), None), Some(Duration::from_secs(8))); + assert_eq!(rp.should_retry(nz!(5), None), Some(Duration::from_secs(10))); + assert_eq!(rp.should_retry(nz!(6), None), Some(Duration::from_secs(10))); // Max attempts - no retry - assert!(rp.should_retry(10, None).is_none()); - } - - #[test] - fn no_interval_no_backoff() { - let rp = RetryPolicy { - initial_interval: None, - backoff_coefficient: 0., - maximum_interval: None, - maximum_attempts: 10, - non_retryable_error_types: vec![], - }; - assert!(rp.should_retry(1, None).is_some()); + assert!(rp.should_retry(nz!(10), None).is_none()); } #[test] fn max_attempts_zero_retry_forever() { - let rp = RetryPolicy { + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(1))), backoff_coefficient: 1.2, maximum_interval: None, maximum_attempts: 0, non_retryable_error_types: vec![], - }; - for i in 0..50 { - assert!(rp.should_retry(i, None).is_some()); + }); + for i in 1..50 { + assert!(rp.should_retry(nz!(i), None).is_some()); } } #[test] - fn no_overflows() { - let rp = RetryPolicy { + fn delay_calculation_does_not_overflow() { + let maximum_interval = Duration::from_secs(1000 * 365 * 24 * 60 * 60); + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(1))), backoff_coefficient: 10., - maximum_interval: None, + maximum_interval: Some(maximum_interval.try_into().unwrap()), maximum_attempts: 0, non_retryable_error_types: vec![], - }; - for i in 0..50 { - assert!(rp.should_retry(i, None).is_some()); + }); + for i in 1..50 { + assert!(rp.should_retry(nz!(i), None).unwrap() <= maximum_interval); } + assert_eq!(rp.should_retry(nz!(50), None), Some(maximum_interval)); + assert_eq!(rp.should_retry(nz!(u32::MAX), None), Some(maximum_interval)); } #[test] fn no_retry_err_str_match() { - let rp = RetryPolicy { + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(1))), backoff_coefficient: 2.0, maximum_interval: Some(prost_dur!(from_secs(10))), maximum_attempts: 10, non_retryable_error_types: vec!["no retry".to_string()], - }; + }); assert!( rp.should_retry( - 1, + nz!(1), Some(&ApplicationFailureInfo { r#type: "no retry".to_string(), non_retryable: false, @@ -184,16 +330,16 @@ mod tests { #[test] fn no_non_retryable_application_failure() { - let rp = RetryPolicy { + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(1))), backoff_coefficient: 2.0, maximum_interval: Some(prost_dur!(from_secs(10))), maximum_attempts: 10, non_retryable_error_types: vec![], - }; + }); assert!( rp.should_retry( - 1, + nz!(1), Some(&ApplicationFailureInfo { r#type: "".to_string(), non_retryable: true, @@ -206,19 +352,21 @@ mod tests { #[test] fn explicit_delay_is_used() { - let rp = RetryPolicy { + let rp = ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(1))), backoff_coefficient: 2.0, maximum_attempts: 2, ..Default::default() - }; + }); let afi = &ApplicationFailureInfo { r#type: "".to_string(), next_retry_delay: Some(prost_dur!(from_secs(50))), ..Default::default() }; - let res = rp.should_retry(1, Some(afi)).unwrap(); - assert_eq!(res.as_millis(), 50_000); - assert!(rp.should_retry(2, Some(afi)).is_none()); + assert_eq!( + rp.should_retry(nz!(1), Some(afi)), + Some(Duration::from_secs(50)) + ); + assert!(rp.should_retry(nz!(2), Some(afi)).is_none()); } } diff --git a/core/src/test_help/integ_helpers.rs b/core/src/test_help/integ_helpers.rs index 2f3cded78..085f44e07 100644 --- a/core/src/test_help/integ_helpers.rs +++ b/core/src/test_help/integ_helpers.rs @@ -660,7 +660,6 @@ pub fn build_mock_pollers(mut cfg: MockPollCfg) -> MocksHolder { ) .is_err() { - dbg!("Exiting mock WFT task because rcv half of stream was dropped"); break; } } diff --git a/core/src/worker/activities.rs b/core/src/worker/activities.rs index a4edb4f2c..071c8d3ec 100644 --- a/core/src/worker/activities.rs +++ b/core/src/worker/activities.rs @@ -333,7 +333,14 @@ impl WorkerActivityTasks { if let Some(jh) = act_info.local_timeouts_task { jh.abort() }; - self.heartbeat_manager.evict(task_token.clone()).await; + let should_flush = !known_not_found + && !matches!( + &status, + aer::Status::Completed(_) | aer::Status::WillCompleteAsync(_) + ); + self.heartbeat_manager + .evict(task_token.clone(), should_flush) + .await; // No need to report activities which we already know the server doesn't care about if !known_not_found { @@ -400,8 +407,6 @@ impl WorkerActivityTasks { } }; - self.complete_notify.notify_waiters(); - if let Some(e) = maybe_net_err { if e.code() == tonic::Code::NotFound { warn!(task_token=?task_token, details=?e, "Activity not found on \ @@ -418,6 +423,8 @@ impl WorkerActivityTasks { &task_token ); } + + self.complete_notify.notify_waiters(); } /// Attempt to record an activity heartbeat diff --git a/core/src/worker/activities/activity_heartbeat_manager.rs b/core/src/worker/activities/activity_heartbeat_manager.rs index 44e426ba9..583584fd8 100644 --- a/core/src/worker/activities/activity_heartbeat_manager.rs +++ b/core/src/worker/activities/activity_heartbeat_manager.rs @@ -42,6 +42,7 @@ enum HeartbeatAction { Evict { token: TaskToken, on_complete: Arc, + should_flush: bool, }, CompleteReport(TaskToken), CompleteThrottle(TaskToken), @@ -118,7 +119,7 @@ impl ActivityHeartbeatManager { HeartbeatAction::SendHeartbeat(hb) => hb_states.record(hb), HeartbeatAction::CompleteReport(tt) => hb_states.handle_report_completed(tt), HeartbeatAction::CompleteThrottle(tt) => hb_states.handle_throttle_completed(tt), - HeartbeatAction::Evict{ token, on_complete } => hb_states.evict(token, on_complete), + HeartbeatAction::Evict{ token, on_complete, should_flush } => hb_states.evict(token, on_complete, should_flush), }, hb_states, )) @@ -230,13 +231,14 @@ impl ActivityHeartbeatManager { } /// Tell the heartbeat manager we are done forever with a certain task, so it may be forgotten. - /// This will also force-flush the most recently provided details. + /// If should_flush is true, will also force-flush the most recently provided details. /// Record *should* not be called with the same TaskToken after calling this. - pub(super) async fn evict(&self, task_token: TaskToken) { + pub(super) async fn evict(&self, task_token: TaskToken, should_flush: bool) { let completed = Arc::new(Notify::new()); let _ = self.heartbeat_tx.send(HeartbeatAction::Evict { token: task_token, on_complete: completed.clone(), + should_flush, }); completed.notified().await; } @@ -397,12 +399,15 @@ impl HeartbeatStreamState { &mut self, tt: TaskToken, on_complete: Arc, + should_flush: bool, ) -> Option { if let Some(state) = self.tt_to_state.remove(&tt) { if let Some(cancel_tok) = state.throttled_cancellation_token { cancel_tok.cancel(); } - if let Some(last_deets) = state.last_recorded_details { + if let Some(last_deets) = state.last_recorded_details + && should_flush + { self.tt_needs_flush.insert(tt.clone(), on_complete); return Some(HeartbeatExecutorAction::Report { task_token: tt, @@ -524,7 +529,7 @@ mod test { record_heartbeat(&hm, fake_task_token.clone(), 0, Duration::from_millis(100)); // Let it propagate sleep(Duration::from_millis(10)).await; - hm.evict(fake_task_token.clone().into()).await; + hm.evict(fake_task_token.clone().into(), true).await; record_heartbeat(&hm, fake_task_token, 0, Duration::from_millis(100)); // Let it propagate sleep(Duration::from_millis(10)).await; @@ -543,7 +548,38 @@ mod test { let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); let fake_task_token = vec![1, 2, 3]; record_heartbeat(&hm, fake_task_token.clone(), 0, Duration::from_millis(100)); - hm.evict(fake_task_token.clone().into()).await; + hm.evict(fake_task_token.clone().into(), true).await; + hm.shutdown().await; + } + + #[tokio::test] + async fn no_flush_on_successful_completion() { + let mut mock_client = mock_worker_client(); + // Should only expect 1 heartbeat call, not 2 (the second would be from evict flushing) + mock_client + .expect_record_activity_heartbeat() + .returning(|_, _| Ok(RecordActivityTaskHeartbeatResponse::default())) + .times(1); + let (cancel_tx, _cancel_rx) = unbounded_channel(); + let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); + let fake_task_token = vec![1, 2, 3]; + + // Record initial heartbeat - this should be sent immediately + record_heartbeat(&hm, fake_task_token.clone(), 0, Duration::from_millis(100)); + + // Wait a bit for initial heartbeat to process and enter throttling phase + sleep(Duration::from_millis(50)).await; + + // Record another heartbeat while throttled - this should be stored in last_recorded_details + record_heartbeat(&hm, fake_task_token.clone(), 1, Duration::from_millis(100)); + + // Wait a bit to ensure the second heartbeat is recorded but not sent + sleep(Duration::from_millis(10)).await; + + // Evict the activity with should_flush false + // This should NOT send the stored heartbeat details since the activity completed successfully + hm.evict(fake_task_token.into(), false).await; + hm.shutdown().await; } diff --git a/core/src/worker/activities/local_activities.rs b/core/src/worker/activities/local_activities.rs index bdfc1b5e4..106b7d49a 100644 --- a/core/src/worker/activities/local_activities.rs +++ b/core/src/worker/activities/local_activities.rs @@ -2,7 +2,6 @@ use crate::{ MetricsContext, TaskToken, abstractions::{MeteredPermitDealer, OwnedMeteredSemPermit, UsedMeteredSemPermit, dbg_panic}, protosext::ValidScheduleLA, - retry_logic::RetryPolicyExt, telemetry::metrics::{activity_type, should_record_failure_metric, workflow_type}, worker::workflow::HeartbeatTimeoutMsg, }; @@ -13,6 +12,7 @@ use parking_lot::{Mutex, MutexGuard}; use std::{ collections::{HashMap, hash_map::Entry}, fmt::{Debug, Formatter}, + num::NonZero, pin::Pin, task::{Context, Poll}, time::{Duration, Instant, SystemTime}, @@ -490,7 +490,7 @@ impl LocalActivityManager { dispatch_time: Instant::now(), attempt, _permit: permit.into_used(LocalActivitySlotInfo { - activity_type: new_la.workflow_type.clone(), + activity_type: sa.activity_type.clone(), }), }, ); @@ -525,7 +525,7 @@ impl LocalActivityManager { .or(schedule_to_close) .and_then(|t| t.try_into().ok()), heartbeat_timeout: None, - retry_policy: Some(sa.retry_policy), + retry_policy: Some(sa.retry_policy.into()), priority: Some(Default::default()), is_local: true, })), @@ -570,7 +570,7 @@ impl LocalActivityManager { macro_rules! calc_backoff { ($fail: ident) => { info.la_info.schedule_cmd.retry_policy.should_retry( - info.attempt as usize, + info.attempt.try_into().unwrap_or(NonZero::::MIN), $fail .failure .as_ref() @@ -976,7 +976,7 @@ impl Drop for TimeoutBag { #[cfg(test)] mod tests { use super::*; - use crate::{prost_dur, protosext::LACloseTimeouts}; + use crate::{prost_dur, protosext::LACloseTimeouts, retry_logic::ValidatedRetryPolicy}; use futures_util::FutureExt; use temporal_sdk_core_protos::temporal::api::{ common::v1::RetryPolicy, @@ -1111,13 +1111,13 @@ mod tests { seq: 1, activity_id: 1.to_string(), attempt: 5, - retry_policy: RetryPolicy { + retry_policy: ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(1))), backoff_coefficient: 10.0, maximum_interval: Some(prost_dur!(from_secs(10))), maximum_attempts: 10, non_retryable_error_types: vec![], - }, + }), local_retry_threshold: Duration::from_secs(5), ..Default::default() }, @@ -1146,13 +1146,13 @@ mod tests { seq: 1, activity_id: "1".to_string(), attempt: 1, - retry_policy: RetryPolicy { + retry_policy: ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(1))), backoff_coefficient: 10.0, maximum_interval: Some(prost_dur!(from_secs(10))), maximum_attempts: 10, non_retryable_error_types: vec!["TestError".to_string()], - }, + }), local_retry_threshold: Duration::from_secs(5), ..Default::default() }, @@ -1190,13 +1190,13 @@ mod tests { seq: 1, activity_id: 1.to_string(), attempt: 5, - retry_policy: RetryPolicy { + retry_policy: ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_secs(10))), backoff_coefficient: 1.0, maximum_interval: Some(prost_dur!(from_secs(10))), maximum_attempts: 10, non_retryable_error_types: vec![], - }, + }), local_retry_threshold: Duration::from_secs(500), ..Default::default() }, @@ -1239,11 +1239,11 @@ mod tests { seq: 1, activity_id: 1.to_string(), attempt: 5, - retry_policy: RetryPolicy { + retry_policy: ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_millis(10))), backoff_coefficient: 1.0, ..Default::default() - }, + }), local_retry_threshold: Duration::from_secs(500), ..Default::default() }, @@ -1276,11 +1276,11 @@ mod tests { seq: 1, activity_id: 1.to_string(), attempt: 5, - retry_policy: RetryPolicy { + retry_policy: ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_millis(10))), backoff_coefficient: 1.0, ..Default::default() - }, + }), local_retry_threshold: Duration::from_secs(500), schedule_to_start_timeout: Some(timeout), ..Default::default() @@ -1319,12 +1319,12 @@ mod tests { seq: 1, activity_id: 1.to_string(), attempt: 5, - retry_policy: RetryPolicy { + retry_policy: ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_millis(10))), backoff_coefficient: 1.0, maximum_attempts: 1, ..Default::default() - }, + }), local_retry_threshold: Duration::from_secs(500), close_timeouts, ..Default::default() @@ -1400,11 +1400,11 @@ mod tests { schedule_cmd: ValidScheduleLA { seq: 1, activity_id: 1.to_string(), - retry_policy: RetryPolicy { + retry_policy: ValidatedRetryPolicy::from_proto_with_defaults(RetryPolicy { initial_interval: Some(prost_dur!(from_millis(1))), backoff_coefficient: 1.0, ..Default::default() - }, + }), local_retry_threshold: Duration::from_secs(500), ..Default::default() }, diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs index 578976d0e..2fae0309a 100644 --- a/core/src/worker/client.rs +++ b/core/src/worker/client.rs @@ -2,14 +2,14 @@ pub(crate) mod mocks; use crate::protosext::legacy_query_failure; -use parking_lot::{Mutex, RwLock}; +use parking_lot::Mutex; use prost_types::Duration as PbDuration; use std::collections::HashMap; use std::time::SystemTime; use std::{sync::Arc, time::Duration}; use temporal_client::{ Client, ClientWorkerSet, IsWorkerTaskLongPoll, Namespace, NamespacedClient, NoRetryOnMatching, - RetryClient, WorkflowService, + RetryClient, SharedReplaceableClient, WorkflowService, }; use temporal_sdk_core_api::worker::WorkerVersioningStrategy; use temporal_sdk_core_protos::temporal::api::enums::v1::WorkerStatus; @@ -49,7 +49,7 @@ pub enum LegacyQueryResult { /// Contains everything a worker needs to interact with the server pub(crate) struct WorkerClientBag { - replaceable_client: RwLock>, + client: RetryClient>, namespace: String, identity: String, worker_versioning_strategy: WorkerVersioningStrategy, @@ -58,13 +58,13 @@ pub(crate) struct WorkerClientBag { impl WorkerClientBag { pub(crate) fn new( - client: RetryClient, + client: RetryClient>, namespace: String, identity: String, worker_versioning_strategy: WorkerVersioningStrategy, ) -> Self { Self { - replaceable_client: RwLock::new(client), + client, namespace, identity, worker_versioning_strategy, @@ -72,10 +72,6 @@ impl WorkerClientBag { } } - fn cloned_client(&self) -> RetryClient { - self.replaceable_client.read().clone() - } - fn default_capabilities(&self) -> Capabilities { self.capabilities().unwrap_or_default() } @@ -231,7 +227,7 @@ pub trait WorkerClient: Sync + Send { ) -> Result; /// Replace the underlying client - fn replace_client(&self, new_client: RetryClient); + fn replace_client(&self, new_client: Client); /// Return server capabilities fn capabilities(&self) -> Option; /// Return workers using this client @@ -312,7 +308,8 @@ impl WorkerClient for WorkerClientBag { } Ok(self - .cloned_client() + .client + .clone() .poll_workflow_task_queue(request) .await? .into_inner()) @@ -349,7 +346,8 @@ impl WorkerClient for WorkerClientBag { } Ok(self - .cloned_client() + .client + .clone() .poll_activity_task_queue(request) .await? .into_inner()) @@ -383,7 +381,8 @@ impl WorkerClient for WorkerClientBag { } Ok(self - .cloned_client() + .client + .clone() .poll_nexus_task_queue(request) .await? .into_inner()) @@ -433,8 +432,9 @@ impl WorkerClient for WorkerClientBag { deployment_options: self.deployment_options(), }; Ok(self - .cloned_client() - .respond_workflow_task_completed(request) + .client + .clone() + .respond_workflow_task_completed(request.into_request()) .await? .into_inner()) } @@ -445,7 +445,8 @@ impl WorkerClient for WorkerClientBag { result: Option, ) -> Result { Ok(self - .cloned_client() + .client + .clone() .respond_activity_task_completed( #[allow(deprecated)] // want to list all fields explicitly RespondActivityTaskCompletedRequest { @@ -457,7 +458,8 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - }, + } + .into_request(), ) .await? .into_inner()) @@ -469,13 +471,17 @@ impl WorkerClient for WorkerClientBag { response: nexus::v1::Response, ) -> Result { Ok(self - .cloned_client() - .respond_nexus_task_completed(RespondNexusTaskCompletedRequest { - namespace: self.namespace.clone(), - identity: self.identity.clone(), - task_token: task_token.0, - response: Some(response), - }) + .client + .clone() + .respond_nexus_task_completed( + RespondNexusTaskCompletedRequest { + namespace: self.namespace.clone(), + identity: self.identity.clone(), + task_token: task_token.0, + response: Some(response), + } + .into_request(), + ) .await? .into_inner()) } @@ -486,13 +492,17 @@ impl WorkerClient for WorkerClientBag { details: Option, ) -> Result { Ok(self - .cloned_client() - .record_activity_task_heartbeat(RecordActivityTaskHeartbeatRequest { - task_token: task_token.0, - details, - identity: self.identity.clone(), - namespace: self.namespace.clone(), - }) + .client + .clone() + .record_activity_task_heartbeat( + RecordActivityTaskHeartbeatRequest { + task_token: task_token.0, + details, + identity: self.identity.clone(), + namespace: self.namespace.clone(), + } + .into_request(), + ) .await? .into_inner()) } @@ -503,7 +513,8 @@ impl WorkerClient for WorkerClientBag { details: Option, ) -> Result { Ok(self - .cloned_client() + .client + .clone() .respond_activity_task_canceled( #[allow(deprecated)] // want to list all fields explicitly RespondActivityTaskCanceledRequest { @@ -515,7 +526,8 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - }, + } + .into_request(), ) .await? .into_inner()) @@ -527,7 +539,8 @@ impl WorkerClient for WorkerClientBag { failure: Option, ) -> Result { Ok(self - .cloned_client() + .client + .clone() .respond_activity_task_failed( #[allow(deprecated)] // want to list all fields explicitly RespondActivityTaskFailedRequest { @@ -541,7 +554,8 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - }, + } + .into_request(), ) .await? .into_inner()) @@ -568,8 +582,9 @@ impl WorkerClient for WorkerClientBag { deployment_options: self.deployment_options(), }; Ok(self - .cloned_client() - .respond_workflow_task_failed(request) + .client + .clone() + .respond_workflow_task_failed(request.into_request()) .await? .into_inner()) } @@ -580,13 +595,17 @@ impl WorkerClient for WorkerClientBag { error: nexus::v1::HandlerError, ) -> Result { Ok(self - .cloned_client() - .respond_nexus_task_failed(RespondNexusTaskFailedRequest { - namespace: self.namespace.clone(), - identity: self.identity.clone(), - task_token: task_token.0, - error: Some(error), - }) + .client + .clone() + .respond_nexus_task_failed( + RespondNexusTaskFailedRequest { + namespace: self.namespace.clone(), + identity: self.identity.clone(), + task_token: task_token.0, + error: Some(error), + } + .into_request(), + ) .await? .into_inner()) } @@ -598,16 +617,20 @@ impl WorkerClient for WorkerClientBag { page_token: Vec, ) -> Result { Ok(self - .cloned_client() - .get_workflow_execution_history(GetWorkflowExecutionHistoryRequest { - namespace: self.namespace.clone(), - execution: Some(WorkflowExecution { - workflow_id, - run_id: run_id.unwrap_or_default(), - }), - next_page_token: page_token, - ..Default::default() - }) + .client + .clone() + .get_workflow_execution_history( + GetWorkflowExecutionHistoryRequest { + namespace: self.namespace.clone(), + execution: Some(WorkflowExecution { + workflow_id, + run_id: run_id.unwrap_or_default(), + }), + next_page_token: page_token, + ..Default::default() + } + .into_request(), + ) .await? .into_inner()) } @@ -630,25 +653,32 @@ impl WorkerClient for WorkerClientBag { let (_, completed_type, query_result, error_message) = query_result.into_components(); Ok(self - .cloned_client() - .respond_query_task_completed(RespondQueryTaskCompletedRequest { - task_token: task_token.into(), - completed_type: completed_type as i32, - query_result, - error_message, - namespace: self.namespace.clone(), - failure, - cause: cause.into(), - }) + .client + .clone() + .respond_query_task_completed( + RespondQueryTaskCompletedRequest { + task_token: task_token.into(), + completed_type: completed_type as i32, + query_result, + error_message, + namespace: self.namespace.clone(), + failure, + cause: cause.into(), + } + .into_request(), + ) .await? .into_inner()) } async fn describe_namespace(&self) -> Result { Ok(self - .cloned_client() + .client + .clone() .describe_namespace( - Namespace::Name(self.namespace.clone()).into_describe_namespace_request(), + Namespace::Name(self.namespace.clone()) + .into_describe_namespace_request() + .into_request(), ) .await? .into_inner()) @@ -673,7 +703,7 @@ impl WorkerClient for WorkerClientBag { }; Ok( - WorkflowService::shutdown_worker(&mut self.cloned_client(), request) + WorkflowService::shutdown_worker(&mut self.client.clone(), request.into_request()) .await? .into_inner(), ) @@ -690,25 +720,28 @@ impl WorkerClient for WorkerClientBag { worker_heartbeat, }; Ok(self - .cloned_client() - .record_worker_heartbeat(request) + .client + .clone() + .record_worker_heartbeat(request.into_request()) .await? .into_inner()) } - fn replace_client(&self, new_client: RetryClient) { - let mut replaceable_client = self.replaceable_client.write(); - *replaceable_client = new_client; + fn replace_client(&self, new_client: Client) { + self.client.get_client().replace_client(new_client); } fn capabilities(&self) -> Option { - let client = self.replaceable_client.read(); - client.get_client().inner().capabilities().cloned() + self.client + .get_client() + .inner_cow() + .inner() + .capabilities() + .cloned() } fn workers(&self) -> Arc { - let client = self.replaceable_client.read(); - client.get_client().inner().workers() + self.client.get_client().inner_cow().inner().workers() } fn is_mock(&self) -> bool { @@ -716,8 +749,8 @@ impl WorkerClient for WorkerClientBag { } fn sdk_name_and_version(&self) -> (String, String) { - let lock = self.replaceable_client.read(); - let opts = lock.get_client().inner().options(); + let inner = self.client.get_client().inner_cow(); + let opts = inner.options(); (opts.client_name.clone(), opts.client_version.clone()) } @@ -726,17 +759,14 @@ impl WorkerClient for WorkerClientBag { } fn worker_grouping_key(&self) -> Uuid { - self.replaceable_client - .read() - .get_client() - .worker_grouping_key() + self.client.get_client().inner_cow().worker_grouping_key() } fn set_heartbeat_client_fields(&self, heartbeat: &mut WorkerHeartbeat) { if let Some(host_info) = heartbeat.host_info.as_mut() { host_info.process_key = self.worker_grouping_key().to_string(); } - heartbeat.worker_identity = self.identity(); + heartbeat.worker_identity = WorkerClient::identity(self); let sdk_name_and_ver = self.sdk_name_and_version(); heartbeat.sdk_name = sdk_name_and_ver.0; heartbeat.sdk_version = sdk_name_and_ver.1; @@ -778,12 +808,12 @@ impl WorkerClient for WorkerClientBag { } impl NamespacedClient for WorkerClientBag { - fn namespace(&self) -> &str { - &self.namespace + fn namespace(&self) -> String { + self.namespace.clone() } - fn get_identity(&self) -> &str { - &self.identity + fn identity(&self) -> String { + self.identity.clone() } } diff --git a/core/src/worker/client/mocks.rs b/core/src/worker/client/mocks.rs index b86334e4b..26a1b18da 100644 --- a/core/src/worker/client/mocks.rs +++ b/core/src/worker/client/mocks.rs @@ -163,7 +163,7 @@ mockall::mock! { heartbeat: Vec ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; - fn replace_client(&self, new_client: RetryClient); + fn replace_client(&self, new_client: Client); fn capabilities(&self) -> Option; fn workers(&self) -> Arc; fn is_mock(&self) -> bool; diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index f3dbf276a..eed8602f1 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -19,6 +19,8 @@ pub(crate) use activities::{ NewLocalAct, }; pub(crate) use wft_poller::WFTPollerShared; + +#[allow(unreachable_pub)] // re-exported in test_help::integ_helpers pub use workflow::LEGACY_QUERY_ID; use crate::telemetry::WorkerHeartbeatMetrics; @@ -29,6 +31,7 @@ use crate::{ errors::CompleteWfError, pollers::{BoxedActPoller, BoxedNexusPoller}, protosext::validate_activity_completion, + sealed::AnyClient, telemetry::{ TelemetryInstance, metrics::{ @@ -67,10 +70,8 @@ use std::{ }, time::Duration, }; +use temporal_client::SharedNamespaceWorkerTrait; use temporal_client::{ClientWorker, HeartbeatCallback, Slot as SlotTrait}; -use temporal_client::{ - ConfiguredClient, SharedNamespaceWorkerTrait, TemporalServiceClientWithMetrics, -}; use temporal_sdk_core_api::telemetry::metrics::TemporalMeter; use temporal_sdk_core_api::worker::{ ActivitySlotKind, LocalActivitySlotKind, NexusSlotKind, SlotKind, WorkflowSlotKind, @@ -336,7 +337,7 @@ impl Worker { ) } - /// Replace client and return a new client. + /// Replace client. /// /// For eager workflow purposes, this new client will now apply to future eager start requests /// and the older client will not. Note, if this registration fails, the worker heartbeat will @@ -345,10 +346,10 @@ impl Worker { /// For worker heartbeat, this will remove an existing shared worker if it is the last worker of /// the old client and create a new nexus worker if it's the first client of the namespace on /// the new client. - pub fn replace_client( - &self, - new_client: ConfiguredClient, - ) -> Result<(), anyhow::Error> { + pub fn replace_client(&self, new_client: CT) -> Result<(), anyhow::Error> + where + CT: Into, + { // Unregister worker from current client, register in new client at the end let client_worker = self .client diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs index 0639d425e..46b3f5344 100644 --- a/core/src/worker/nexus.rs +++ b/core/src/worker/nexus.rs @@ -32,7 +32,11 @@ use temporal_sdk_core_protos::{ CancelNexusTask, NexusTask, NexusTaskCancelReason, nexus_task, nexus_task_completion, }, }, - temporal::api::nexus::v1::{request::Variant, response, start_operation_response}, + temporal::api::nexus::{ + self, + v1::{request::Variant, response, start_operation_response}, + }, + utilities::normalize_http_headers, }; use tokio::{ join, @@ -42,7 +46,7 @@ use tokio::{ use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; -static REQUEST_TIMEOUT_HEADER: &str = "Request-Timeout"; +static REQUEST_TIMEOUT_HEADER: &str = "request-timeout"; /// Centralizes all state related to received nexus tasks pub(super) struct NexusManager { @@ -245,11 +249,18 @@ where .filter_map(move |t| { let res = match t { TaskStreamInput::Poll(t) => match *t { - Ok(t) => { + Ok(mut t) => { if let Some(dur) = t.resp.sched_to_start() { self.metrics.nexus_task_sched_to_start_latency(dur); }; + if let Some(ref mut req) = t.resp.request { + req.header = normalize_http_headers(std::mem::take(&mut req.header)); + if let Some(nexus::v1::request::Variant::StartOperation(ref mut sor)) = req.variant { + sor.callback_header = normalize_http_headers(std::mem::take(&mut sor.callback_header)); + } + } + let tt = TaskToken(t.resp.task_token.clone()); let mut timeout_task = None; if let Some(timeout_str) = t diff --git a/core/src/worker/workflow/machines/patch_state_machine.rs b/core/src/worker/workflow/machines/patch_state_machine.rs index 8548a0c5d..8c4312674 100644 --- a/core/src/worker/workflow/machines/patch_state_machine.rs +++ b/core/src/worker/workflow/machines/patch_state_machine.rs @@ -46,7 +46,7 @@ use temporal_sdk_core_protos::{ RecordMarkerCommandAttributes, UpsertWorkflowSearchAttributesCommandAttributes, }, common::v1::SearchAttributes, - enums::v1::CommandType, + enums::v1::{CommandType, IndexedValueType}, }, }; @@ -127,7 +127,7 @@ pub(super) fn has_change<'a>( // Produce an upsert SA command for this patch. let mut all_ids = BTreeSet::from_iter(existing_patch_ids); all_ids.insert(machine.shared_state.patch_id.as_str()); - let serialized = all_ids + let mut serialized = all_ids .as_json_payload() .context("Could not serialize search attribute value for patch machine") .map_err(|e| WFMachinesError::Fatal(e.to_string()))?; @@ -141,6 +141,13 @@ pub(super) fn has_change<'a>( ); vec![] } else { + serialized.metadata.insert( + "type".to_string(), + IndexedValueType::KeywordList + .as_str_name() + .as_bytes() + .to_vec(), + ); let indexed_fields = { let mut m = HashMap::new(); m.insert(VERSION_SEARCH_ATTR_KEY.to_string(), serialized); diff --git a/core/src/worker/workflow/mod.rs b/core/src/worker/workflow/mod.rs index 24aa59f11..c8103ba11 100644 --- a/core/src/worker/workflow/mod.rs +++ b/core/src/worker/workflow/mod.rs @@ -99,6 +99,7 @@ use tracing::{Span, Subscriber}; /// Id used by server for "legacy" queries. IE: Queries that come in the `query` rather than /// `queries` field of a WFT, and are responded to on the separate `respond_query_task_completed` /// rpc. +#[allow(unreachable_pub)] // re-exported in supermodule pub const LEGACY_QUERY_ID: &str = "legacy_query"; /// What percentage of a WFT timeout we are willing to wait before sending a WFT heartbeat when /// necessary. diff --git a/core/src/worker/workflow/workflow_stream.rs b/core/src/worker/workflow/workflow_stream.rs index b85e08953..0cce1ee90 100644 --- a/core/src/worker/workflow/workflow_stream.rs +++ b/core/src/worker/workflow/workflow_stream.rs @@ -578,7 +578,7 @@ impl WFStream { #[derive(derive_more::From, Debug)] enum WFStreamInput { NewWft(Box), - Local(LocalInput), + Local(Box), /// The stream given to us which represents the poller (or a mock) terminated. PollerDead, /// The stream given to us which represents the poller (or a mock) encountered a non-retryable @@ -590,6 +590,11 @@ enum WFStreamInput { auto_reply_fail_tt: Option, }, } +impl From for WFStreamInput { + fn from(input: LocalInput) -> Self { + WFStreamInput::Local(Box::new(input)) + } +} /// A non-poller-received input to the [WFStream] #[derive(derive_more::Debug)] @@ -672,10 +677,10 @@ impl From for WFStreamInput { paginator, update, span, - } => WFStreamInput::Local(LocalInput { + } => WFStreamInput::Local(Box::new(LocalInput { input: LocalInputs::FetchedPageCompletion { paginator, update }, span, - }), + })), } } } diff --git a/fsm/rustfsm_procmacro/tests/trybuild/dupe_transitions_fail.stderr b/fsm/rustfsm_procmacro/tests/trybuild/dupe_transitions_fail.stderr index 3e0ee1f40..f16c35878 100644 --- a/fsm/rustfsm_procmacro/tests/trybuild/dupe_transitions_fail.stderr +++ b/fsm/rustfsm_procmacro/tests/trybuild/dupe_transitions_fail.stderr @@ -1,11 +1,11 @@ error: Duplicate transitions are not allowed! --> $DIR/dupe_transitions_fail.rs:5:1 | -5 | / fsm! { -6 | | name SimpleMachine; command SimpleMachineCommand; error Infallible; -7 | | -8 | | One --(A)--> Two; -9 | | One --(A)--> Two; + 5 | / fsm! { + 6 | | name SimpleMachine; command SimpleMachineCommand; error Infallible; + 7 | | + 8 | | One --(A)--> Two; + 9 | | One --(A)--> Two; 10 | | } | |_^ | diff --git a/sdk-core-protos/Cargo.toml b/sdk-core-protos/Cargo.toml index d9a77052d..4ce0bbffa 100644 --- a/sdk-core-protos/Cargo.toml +++ b/sdk-core-protos/Cargo.toml @@ -20,19 +20,18 @@ anyhow = "1.0" base64 = "0.22" derive_more = { workspace = true } prost = { workspace = true } -prost-wkt = "0.6" -prost-wkt-types = "0.6" +prost-wkt = "0.7" +prost-types = { workspace = true } rand = { version = "0.9", optional = true } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = { workspace = true } tonic = { workspace = true } +tonic-prost = { workspace = true } uuid = { version = "1.18", features = ["v4"], optional = true } [build-dependencies] -tonic-build = { workspace = true } -prost-build = "0.13" -prost-wkt-build = "0.6" +tonic-prost-build = { workspace = true } [lints] workspace = true diff --git a/sdk-core-protos/build.rs b/sdk-core-protos/build.rs index 68b964fd1..3f2fe8a3c 100644 --- a/sdk-core-protos/build.rs +++ b/sdk-core-protos/build.rs @@ -1,5 +1,7 @@ use std::{env, path::PathBuf}; +use tonic_prost_build::Config; + static ALWAYS_SERDE: &str = "#[cfg_attr(not(feature = \"serde_serialize\"), \ derive(::serde::Serialize, ::serde::Deserialize))]"; @@ -7,7 +9,7 @@ fn main() -> Result<(), Box> { println!("cargo:rerun-if-changed=./protos"); let out = PathBuf::from(env::var("OUT_DIR").unwrap()); let descriptor_file = out.join("descriptors.bin"); - tonic_build::configure() + tonic_prost_build::configure() // We don't actually want to build the grpc definitions - we don't need them (for now). // Just build the message structs. .build_server(false) @@ -96,29 +98,14 @@ fn main() -> Result<(), Box> { "coresdk.external_data.LocalActivityMarkerData.backoff", "#[serde(with = \"opt_duration\")]", ) - .extern_path( - ".google.protobuf.Any", - "::prost_wkt_types::Any" - ) - .extern_path( - ".google.protobuf.Timestamp", - "::prost_wkt_types::Timestamp" - ) - .extern_path( - ".google.protobuf.Duration", - "::prost_wkt_types::Duration" - ) - .extern_path( - ".google.protobuf.Value", - "::prost_wkt_types::Value" - ) - .extern_path( - ".google.protobuf.FieldMask", - "::prost_wkt_types::FieldMask" - ) .file_descriptor_set_path(descriptor_file) - .skip_debug("temporal.api.common.v1.Payload") - .compile_protos( + .skip_debug(["temporal.api.common.v1.Payload"]) + .compile_with_config( + { + let mut c = Config::new(); + c.enable_type_names(); + c + }, &[ "./protos/local/temporal/sdk/core/core_interface.proto", "./protos/api_upstream/temporal/api/workflowservice/v1/service.proto", diff --git a/sdk-core-protos/src/history_builder.rs b/sdk-core-protos/src/history_builder.rs index 9aef52737..c197dd684 100644 --- a/sdk-core-protos/src/history_builder.rs +++ b/sdk-core-protos/src/history_builder.rs @@ -23,7 +23,7 @@ use crate::{ }, }; use anyhow::bail; -use prost_wkt_types::Timestamp; +use prost_types::Timestamp; use std::{ collections::HashMap, time::{Duration, SystemTime}, diff --git a/sdk-core-protos/src/lib.rs b/sdk-core-protos/src/lib.rs index 4b03515c3..8b9242c3e 100644 --- a/sdk-core-protos/src/lib.rs +++ b/sdk-core-protos/src/lib.rs @@ -45,7 +45,7 @@ pub mod coresdk { use crate::{ ENCODING_PAYLOAD_KEY, JSON_ENCODING_VAL, temporal::api::{ - common::v1::{Payload, Payloads, WorkflowExecution}, + common::v1::{Payload, Payloads, RetryPolicy, WorkflowExecution}, enums::v1::{ ApplicationErrorCategory, TimeoutType, VersioningBehavior, WorkflowTaskFailedCause, }, @@ -396,7 +396,7 @@ pub mod coresdk { } pub mod external_data { - use prost_wkt_types::{Duration, Timestamp}; + use prost_types::{Duration, Timestamp}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; tonic::include_proto!("coresdk.external_data"); @@ -481,6 +481,7 @@ pub mod coresdk { FromPayloadsExt, activity_result::{ActivityResolution, activity_resolution}, common::NamespacedWorkflowExecution, + fix_retry_policy, workflow_activation::remove_from_cache::EvictionReason, }, temporal::api::{ @@ -493,7 +494,7 @@ pub mod coresdk { query::v1::WorkflowQuery, }, }; - use prost_wkt_types::Timestamp; + use prost_types::Timestamp; use std::fmt::{Display, Formatter}; tonic::include_proto!("coresdk.workflow_activation"); @@ -745,7 +746,7 @@ pub mod coresdk { continued_failure: attrs.continued_failure, last_completion_result: attrs.last_completion_result, first_execution_run_id: attrs.first_execution_run_id, - retry_policy: attrs.retry_policy, + retry_policy: attrs.retry_policy.map(fix_retry_policy), attempt: attrs.attempt, cron_schedule: attrs.cron_schedule, workflow_execution_expiration_time: attrs.workflow_execution_expiration_time, @@ -1298,7 +1299,7 @@ pub mod coresdk { schedule_to_close_timeout: r.schedule_to_close_timeout, start_to_close_timeout: r.start_to_close_timeout, heartbeat_timeout: r.heartbeat_timeout, - retry_policy: r.retry_policy, + retry_policy: r.retry_policy.map(fix_retry_policy), priority: r.priority, is_local: false, }, @@ -1578,6 +1579,15 @@ pub mod coresdk { } } } + + /// If initial_interval is missing, fills it with zero value to prevent crashes + /// (lang assumes that RetryPolicy always has initial_interval set). + fn fix_retry_policy(mut retry_policy: RetryPolicy) -> RetryPolicy { + if retry_policy.initial_interval.is_none() { + retry_policy.initial_interval = Default::default(); + } + retry_policy + } } // No need to lint these @@ -2141,8 +2151,7 @@ pub mod temporal { enums::v1::EventType, history::v1::history_event::Attributes, }; use anyhow::bail; - use prost::alloc::fmt::Formatter; - use std::fmt::Display; + use std::fmt::{Display, Formatter}; tonic::include_proto!("temporal.api.history.v1"); @@ -2683,8 +2692,8 @@ pub mod temporal { } fn elapsed_between_prost_times( - from: prost_wkt_types::Timestamp, - to: prost_wkt_types::Timestamp, + from: prost_types::Timestamp, + to: prost_types::Timestamp, ) -> Option> { let from: Result = from.try_into(); let to: Result = to.try_into(); diff --git a/sdk-core-protos/src/utilities.rs b/sdk-core-protos/src/utilities.rs index 2f269299e..9a8decd9d 100644 --- a/sdk-core-protos/src/utilities.rs +++ b/sdk-core-protos/src/utilities.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use prost::{EncodeError, Message}; pub trait TryIntoOrNone { @@ -18,11 +20,18 @@ where /// Use to encode an message into a proto `Any`. /// /// Delete this once `prost_wkt_types` supports `prost` `0.12.x` which has built-in any packing. -pub fn pack_any( - type_url: String, - msg: &T, -) -> Result { +pub fn pack_any(type_url: String, msg: &T) -> Result { let mut value = Vec::new(); Message::encode(msg, &mut value)?; - Ok(prost_wkt_types::Any { type_url, value }) + Ok(prost_types::Any { type_url, value }) +} + +/// Given a header map, lowercase all the keys and return it as a new map. +/// Any keys that are duplicated after lowercasing will clobber each other in undefined ordering. +pub fn normalize_http_headers(headers: HashMap) -> HashMap { + let mut new_headers = HashMap::new(); + for (header_key, val) in headers.into_iter() { + new_headers.insert(header_key.to_lowercase(), val); + } + new_headers } diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index 97039c2cb..f6f70de5a 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -16,9 +16,15 @@ anyhow = "1.0" derive_more = { workspace = true } futures-util = { version = "0.3", default-features = false } parking_lot = { version = "0.12", features = ["send_guard"] } -prost-types = { version = "0.6", package = "prost-wkt-types" } +prost-types = { workspace = true } serde = "1.0" -tokio = { version = "1.47", features = ["rt", "rt-multi-thread", "parking_lot", "time", "fs"] } +tokio = { version = "1.47", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "time", + "fs", +] } tokio-util = { version = "0.7" } tokio-stream = "0.1" tracing = "0.1" diff --git a/sdk/src/activity_context.rs b/sdk/src/activity_context.rs index 804c5b87e..c9d25f4c5 100644 --- a/sdk/src/activity_context.rs +++ b/sdk/src/activity_context.rs @@ -56,7 +56,7 @@ pub struct ActivityInfo { impl ActContext { /// Construct new Activity Context, returning the context and the first argument to the activity /// (which may be a default [Payload]). - pub(crate) fn new( + pub fn new( worker: Arc, app_data: Arc, cancellation_token: CancellationToken, diff --git a/sdk/src/app_data.rs b/sdk/src/app_data.rs index bffeb8199..977d18879 100644 --- a/sdk/src/app_data.rs +++ b/sdk/src/app_data.rs @@ -6,7 +6,7 @@ use std::{ /// A Wrapper Type for workflow and activity app data #[derive(Default)] -pub(crate) struct AppData { +pub struct AppData { map: HashMap>, } diff --git a/tests/c_bridge_smoke_test.c b/tests/c_bridge_smoke_test.c new file mode 100644 index 000000000..c38e3f01e --- /dev/null +++ b/tests/c_bridge_smoke_test.c @@ -0,0 +1,10 @@ +#include "temporal-sdk-core-c-bridge.h" +#include + +int main(void) { + // Just do something simple to confirm the bridge works + struct TemporalCoreCancellationToken *tok = temporal_core_cancellation_token_new(); + temporal_core_cancellation_token_free(tok); + printf("C bridge smoke test passed!\n"); + return 0; +} \ No newline at end of file diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 94785b9e6..4933372ce 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -66,6 +66,7 @@ use temporal_sdk_core_protos::{ }, }; use tokio::{sync::OnceCell, task::AbortHandle}; +use tonic::IntoRequest; use tracing::{debug, warn}; use url::Url; use uuid::Uuid; @@ -82,6 +83,11 @@ pub(crate) const OTEL_URL_ENV_VAR: &str = "TEMPORAL_INTEG_OTEL_URL"; pub(crate) const PROM_ENABLE_ENV_VAR: &str = "TEMPORAL_INTEG_PROM_PORT"; /// This should match the prometheus port exposed in docker-compose-ci.yaml pub(crate) const PROMETHEUS_QUERY_API: &str = "http://localhost:9090/api/v1/query"; +/// If set, integ tests will use this specific version of CLI for starting dev server +pub(crate) const CLI_VERSION_OVERRIDE_ENV_VAR: &str = "CLI_VERSION_OVERRIDE"; +pub(crate) const INTEG_CLIENT_IDENTITY: &str = "integ_tester"; +pub(crate) const INTEG_CLIENT_NAME: &str = "temporal-core"; +pub(crate) const INTEG_CLIENT_VERSION: &str = "0.1.0"; /// Create a worker instance which will use the provided test name to base the task queue and wf id /// upon. Returns the instance. @@ -256,8 +262,7 @@ impl CoreWfStarter { .get_client() .inner() .workflow_svc() - .clone() - .get_cluster_info(GetClusterInfoRequest::default()) + .get_cluster_info(GetClusterInfoRequest::default().into_request()) .await; let srv_ver = semver::Version::parse( &clustinfo @@ -599,8 +604,7 @@ impl TestWorker { .client .as_ref() .map(|c| c.namespace()) - .unwrap_or(NAMESPACE) - .to_owned(), + .unwrap_or(NAMESPACE.to_owned()), workflow_id: wf_id.into(), run_id, }); @@ -752,10 +756,10 @@ pub(crate) fn get_integ_server_options() -> ClientOptions { .unwrap_or_else(|_| "http://localhost:7233".to_owned()); let url = Url::try_from(&*temporal_server_address).unwrap(); let mut cb = ClientOptionsBuilder::default(); - cb.identity("integ_tester".to_string()) + cb.identity(INTEG_CLIENT_IDENTITY.to_string()) .target_url(url) - .client_name("temporal-core".to_string()) - .client_version("0.1.0".to_string()); + .client_name(INTEG_CLIENT_NAME.to_string()) + .client_version(INTEG_CLIENT_VERSION.to_string()); if let Ok(key_file) = env::var(INTEG_API_KEY) { let content = std::fs::read_to_string(key_file).unwrap(); cb.api_key(Some(content)); @@ -991,3 +995,53 @@ impl Drop for ActivationAssertionsInterceptor { } } } + +#[cfg(feature = "ephemeral-server")] +use temporal_sdk_core::ephemeral_server::{ + EphemeralExe, EphemeralExeVersion, TemporalDevServerConfigBuilder, default_cached_download, +}; + +#[cfg(feature = "ephemeral-server")] +pub(crate) fn integ_dev_server_config( + mut extra_args: Vec, +) -> TemporalDevServerConfigBuilder { + let cli_version = if let Ok(ver_override) = env::var(CLI_VERSION_OVERRIDE_ENV_VAR) { + EphemeralExe::CachedDownload { + version: EphemeralExeVersion::Fixed(ver_override.to_owned()), + dest_dir: None, + ttl: None, + } + } else { + default_cached_download() + }; + extra_args.extend( + [ + // TODO: Delete when temporalCLI enables it by default. + "--dynamic-config-value".to_string(), + "system.enableEagerWorkflowStart=true".to_string(), + "--dynamic-config-value".to_string(), + "system.enableNexus=true".to_string(), + "--dynamic-config-value".to_owned(), + "frontend.workerVersioningWorkflowAPIs=true".to_owned(), + "--dynamic-config-value".to_owned(), + "frontend.workerVersioningDataAPIs=true".to_owned(), + "--dynamic-config-value".to_owned(), + "system.enableDeploymentVersions=true".to_owned(), + "--dynamic-config-value".to_owned(), + "component.nexusoperations.recordCancelRequestCompletionEvents=true".to_owned(), + "--dynamic-config-value".to_owned(), + "frontend.WorkerHeartbeatsEnabled=true".to_owned(), + "--dynamic-config-value".to_owned(), + "frontend.ListWorkersEnabled=true".to_owned(), + "--search-attribute".to_string(), + format!("{SEARCH_ATTR_TXT}=Text"), + "--search-attribute".to_string(), + format!("{SEARCH_ATTR_INT}=Int"), + ] + .map(Into::into), + ); + + let mut config = TemporalDevServerConfigBuilder::default(); + config.exe(cli_version).extra_args(extra_args); + config +} diff --git a/tests/integ_tests/client_tests.rs b/tests/integ_tests/client_tests.rs index 9a4473b35..4dfa6a4ec 100644 --- a/tests/integ_tests/client_tests.rs +++ b/tests/integ_tests/client_tests.rs @@ -32,7 +32,7 @@ use tokio::{ sync::{mpsc::UnboundedSender, oneshot}, }; use tonic::{ - Code, Request, Status, + Code, IntoRequest, Request, Status, body::Body, codegen::{Service, http::Response}, server::NamedService, @@ -56,10 +56,13 @@ async fn can_use_retry_raw_client() { let opts = get_integ_server_options(); let mut client = opts.connect_no_namespace(None).await.unwrap(); client - .describe_namespace(DescribeNamespaceRequest { - namespace: NAMESPACE.to_string(), - ..Default::default() - }) + .describe_namespace( + DescribeNamespaceRequest { + namespace: NAMESPACE.to_string(), + ..Default::default() + } + .into_request(), + ) .await .unwrap(); } @@ -79,10 +82,13 @@ async fn per_call_timeout_respected_whole_client() { hm.insert("grpc-timeout".to_string(), "0S".to_string()); raw_client.get_client().set_headers(hm).unwrap(); let err = raw_client - .describe_namespace(DescribeNamespaceRequest { - namespace: NAMESPACE.to_string(), - ..Default::default() - }) + .describe_namespace( + DescribeNamespaceRequest { + namespace: NAMESPACE.to_string(), + ..Default::default() + } + .into_request(), + ) .await .unwrap_err(); assert_matches!(err.code(), Code::DeadlineExceeded | Code::Cancelled); @@ -409,12 +415,15 @@ async fn cloud_ops_test() { hm.insert("temporal-cloud-api-version".to_string(), api_version); hm }); - let mut client = opts.connect_no_namespace(None).await.unwrap().into_inner(); - let cloud_client = client.cloud_svc_mut(); + let client = opts.connect_no_namespace(None).await.unwrap().into_inner(); + let mut cloud_client = client.cloud_svc(); let res = cloud_client - .get_namespace(GetNamespaceRequest { - namespace: namespace.clone(), - }) + .get_namespace( + GetNamespaceRequest { + namespace: namespace.clone(), + } + .into_request(), + ) .await .unwrap(); assert_eq!(res.into_inner().namespace.unwrap().namespace, namespace); diff --git a/tests/integ_tests/ephemeral_server_tests.rs b/tests/integ_tests/ephemeral_server_tests.rs index 27abacf40..1b9b612a3 100644 --- a/tests/integ_tests/ephemeral_server_tests.rs +++ b/tests/integ_tests/ephemeral_server_tests.rs @@ -1,4 +1,4 @@ -use crate::common::NAMESPACE; +use crate::common::{INTEG_CLIENT_IDENTITY, INTEG_CLIENT_NAME, INTEG_CLIENT_VERSION, NAMESPACE}; use futures_util::{TryStreamExt, stream}; use std::time::{SystemTime, UNIX_EPOCH}; use temporal_client::{ClientOptionsBuilder, TestService, WorkflowService}; @@ -7,6 +7,7 @@ use temporal_sdk_core::ephemeral_server::{ default_cached_download, }; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::DescribeNamespaceRequest; +use tonic::IntoRequest; use url::Url; #[tokio::test] @@ -137,27 +138,30 @@ fn fixed_cached_download(version: &str) -> EphemeralExe { async fn assert_ephemeral_server(server: &EphemeralServer) { // Connect and describe namespace let mut client = ClientOptionsBuilder::default() - .identity("integ_tester".to_string()) + .identity(INTEG_CLIENT_IDENTITY.to_string()) .target_url(Url::try_from(&*format!("http://{}", server.target)).unwrap()) - .client_name("temporal-core".to_string()) - .client_version("0.1.0".to_string()) + .client_name(INTEG_CLIENT_NAME.to_string()) + .client_version(INTEG_CLIENT_VERSION.to_string()) .build() .unwrap() .connect_no_namespace(None) .await .unwrap(); let resp = client - .describe_namespace(DescribeNamespaceRequest { - namespace: NAMESPACE.to_string(), - ..Default::default() - }) + .describe_namespace( + DescribeNamespaceRequest { + namespace: NAMESPACE.to_string(), + ..Default::default() + } + .into_request(), + ) .await .unwrap(); assert!(resp.into_inner().namespace_info.unwrap().name == "default"); // If it has test service, make sure we can use it too if server.has_test_service { - let resp = client.get_current_time(()).await.unwrap(); + let resp = client.get_current_time(().into_request()).await.unwrap(); // Make sure it's within 5 mins of now let resp_seconds = resp.get_ref().time.as_ref().unwrap().seconds as u64; let curr_seconds = SystemTime::now() diff --git a/tests/integ_tests/metrics_tests.rs b/tests/integ_tests/metrics_tests.rs index 901e7c178..97bd191ff 100644 --- a/tests/integ_tests/metrics_tests.rs +++ b/tests/integ_tests/metrics_tests.rs @@ -71,6 +71,7 @@ use temporal_sdk_core_protos::{ }, }; use tokio::{join, sync::Barrier}; +use tonic::IntoRequest; use url::Url; pub(crate) async fn get_text(endpoint: String) -> String { @@ -107,7 +108,7 @@ async fn prometheus_metrics_exported( assert!(raw_client.get_client().capabilities().is_some()); let _ = raw_client - .list_namespaces(ListNamespacesRequest::default()) + .list_namespaces(ListNamespacesRequest::default().into_request()) .await .unwrap(); @@ -542,7 +543,7 @@ fn runtime_new() { .unwrap(); assert!(raw_client.get_client().capabilities().is_some()); let _ = raw_client - .list_namespaces(ListNamespacesRequest::default()) + .list_namespaces(ListNamespacesRequest::default().into_request()) .await .unwrap(); let body = get_text(format!("http://{addr}/metrics")).await; @@ -636,9 +637,12 @@ async fn request_fail_codes() { .unwrap(); // Describe namespace w/ invalid argument (unset namespace field) - WorkflowService::describe_namespace(&mut client, DescribeNamespaceRequest::default()) - .await - .unwrap_err(); + WorkflowService::describe_namespace( + &mut client, + DescribeNamespaceRequest::default().into_request(), + ) + .await + .unwrap_err(); let body = get_text(format!("http://{addr}/metrics")).await; let matching_line = body @@ -681,9 +685,12 @@ async fn request_fail_codes_otel() { for _ in 0..10 { // Describe namespace w/ invalid argument (unset namespace field) - WorkflowService::describe_namespace(&mut client, DescribeNamespaceRequest::default()) - .await - .unwrap_err(); + WorkflowService::describe_namespace( + &mut client, + DescribeNamespaceRequest::default().into_request(), + ) + .await + .unwrap_err(); tokio::time::sleep(Duration::from_secs(1)).await; } diff --git a/tests/integ_tests/polling_tests.rs b/tests/integ_tests/polling_tests.rs index aaa69ec16..e502cc9b4 100644 --- a/tests/integ_tests/polling_tests.rs +++ b/tests/integ_tests/polling_tests.rs @@ -1,18 +1,34 @@ use crate::{ - common::{CoreWfStarter, init_core_and_create_wf, init_integ_telem, integ_worker_config}, + common::{ + CoreWfStarter, INTEG_CLIENT_NAME, INTEG_CLIENT_VERSION, get_integ_server_options, + init_core_and_create_wf, init_integ_telem, integ_dev_server_config, integ_worker_config, + }, integ_tests::activity_functions::echo, }; use assert_matches::assert_matches; -use std::{sync::Arc, time::Duration}; +use futures_util::{FutureExt, StreamExt, future::join_all}; +use std::{ + process::Stdio, + sync::{ + Arc, Mutex, + atomic::{AtomicBool, Ordering}, + }, + time::Duration, +}; use temporal_client::{WfClientExt, WorkflowClientTrait, WorkflowOptions}; use temporal_sdk::{ActivityOptions, WfContext}; use temporal_sdk_core::{ - ClientOptionsBuilder, + ClientOptionsBuilder, CoreRuntime, RuntimeOptionsBuilder, ephemeral_server::{TemporalDevServerConfigBuilder, default_cached_download}, init_worker, - test_help::{WorkerTestHelpers, drain_pollers_and_shutdown}, + telemetry::CoreLogStreamConsumer, + test_help::{NAMESPACE, WorkerTestHelpers, drain_pollers_and_shutdown}, +}; +use temporal_sdk_core_api::{ + Worker, + telemetry::{Logger, TelemetryOptionsBuilder}, + worker::PollerBehavior, }; -use temporal_sdk_core_api::{Worker, worker::PollerBehavior}; use temporal_sdk_core_protos::{ coresdk::{ AsJsonPayloadExt, IntoCompletion, @@ -25,7 +41,7 @@ use temporal_sdk_core_protos::{ temporal::api::enums::v1::EventType, test_utils::schedule_activity_cmd, }; -use tokio::time::timeout; +use tokio::{sync::Notify, time::timeout}; use tracing::info; use url::Url; @@ -119,103 +135,119 @@ async fn switching_worker_client_changes_poll() { ]) .build() .unwrap(); - let mut server1 = server_config.start_server().await.unwrap(); - let mut server2 = server_config.start_server().await.unwrap(); - - // Connect clients to both servers - info!("Connecting clients"); - let mut client_common_config = ClientOptionsBuilder::default(); - client_common_config - .identity("integ_tester".to_owned()) - .client_name("temporal-core".to_owned()) - .client_version("0.1.0".to_owned()); - let client1 = client_common_config - .clone() - .target_url(Url::parse(&format!("http://{}", server1.target)).unwrap()) - .build() - .unwrap() - .connect("default", None) + let mut server1 = server_config + .start_server_with_output(Stdio::null(), Stdio::null()) .await .unwrap(); - let client2 = client_common_config - .clone() - .target_url(Url::parse(&format!("http://{}", server2.target)).unwrap()) - .build() - .unwrap() - .connect("default", None) + let mut server2 = server_config + .start_server_with_output(Stdio::null(), Stdio::null()) .await .unwrap(); - // Start a workflow on both servers - info!("Starting workflows"); - let wf1 = client1 - .start_workflow( - vec![], - "my-task-queue".to_owned(), - "my-workflow-1".to_owned(), - "my-workflow-type".to_owned(), - None, - WorkflowOptions::default(), - ) - .await - .unwrap(); - let wf2 = client2 - .start_workflow( - vec![], - "my-task-queue".to_owned(), - "my-workflow-2".to_owned(), - "my-workflow-type".to_owned(), - None, - WorkflowOptions::default(), + let result = std::panic::AssertUnwindSafe(async { + // Connect clients to both servers + info!("Connecting clients"); + let mut client_common_config = ClientOptionsBuilder::default(); + client_common_config + .identity("integ_tester".to_owned()) + .client_name("temporal-core".to_owned()) + .client_version("0.1.0".to_owned()); + let client1 = client_common_config + .clone() + .target_url(Url::parse(&format!("http://{}", server1.target)).unwrap()) + .build() + .unwrap() + .connect("default", None) + .await + .unwrap(); + let client2 = client_common_config + .clone() + .target_url(Url::parse(&format!("http://{}", server2.target)).unwrap()) + .build() + .unwrap() + .connect("default", None) + .await + .unwrap(); + + // Start a workflow on both servers + info!("Starting workflows"); + let wf1 = client1 + .start_workflow( + vec![], + "my-task-queue".to_owned(), + "my-workflow-1".to_owned(), + "my-workflow-type".to_owned(), + None, + WorkflowOptions::default(), + ) + .await + .unwrap(); + let wf2 = client2 + .start_workflow( + vec![], + "my-task-queue".to_owned(), + "my-workflow-2".to_owned(), + "my-workflow-type".to_owned(), + None, + WorkflowOptions::default(), + ) + .await + .unwrap(); + + // Create a worker only on the first server + let worker = init_worker( + init_integ_telem().unwrap(), + integ_worker_config("my-task-queue") + // We want a cache so we don't get extra remove-job activations + .max_cached_workflows(100_usize) + .build() + .unwrap(), + client1.clone(), ) - .await .unwrap(); - // Create a worker only on the first server - let worker = init_worker( - init_integ_telem().unwrap(), - integ_worker_config("my-task-queue") - // We want a cache so we don't get extra remove-job activations - .max_cached_workflows(100_usize) - .build() - .unwrap(), - client1.clone(), - ) - .unwrap(); + // Poll for first task, confirm it's first wf, complete, and wait for complete + info!("Doing initial poll"); + let act1 = worker.poll_workflow_activation().await.unwrap(); + assert_eq!(wf1.run_id, act1.run_id); + worker.complete_execution(&act1.run_id).await; + worker.handle_eviction().await; + info!("Waiting on first workflow complete"); + client1 + .get_untyped_workflow_handle("my-workflow-1", wf1.run_id) + .get_workflow_result(Default::default()) + .await + .unwrap(); - // Poll for first task, confirm it's first wf, complete, and wait for complete - info!("Doing initial poll"); - let act1 = worker.poll_workflow_activation().await.unwrap(); - assert_eq!(wf1.run_id, act1.run_id); - worker.complete_execution(&act1.run_id).await; - worker.handle_eviction().await; - info!("Waiting on first workflow complete"); - client1 - .get_untyped_workflow_handle("my-workflow-1", wf1.run_id) - .get_workflow_result(Default::default()) - .await - .unwrap(); + // Swap client, poll for next task, confirm it's second wf, and respond w/ empty + info!("Replacing client and polling again"); + worker + .replace_client(client2.get_client().inner().clone()) + .unwrap(); + let act2 = worker.poll_workflow_activation().await.unwrap(); + assert_eq!(wf2.run_id, act2.run_id); + worker.complete_execution(&act2.run_id).await; + worker.handle_eviction().await; + info!("Waiting on second workflow complete"); + client2 + .get_untyped_workflow_handle("my-workflow-2", wf2.run_id) + .get_workflow_result(Default::default()) + .await + .unwrap(); - // Swap client, poll for next task, confirm it's second wf, and respond w/ empty - info!("Replacing client and polling again"); - worker - .replace_client(client2.get_client().inner().clone()) - .unwrap(); - let act2 = worker.poll_workflow_activation().await.unwrap(); - assert_eq!(wf2.run_id, act2.run_id); - worker.complete_execution(&act2.run_id).await; - worker.handle_eviction().await; - info!("Waiting on second workflow complete"); - client2 - .get_untyped_workflow_handle("my-workflow-2", wf2.run_id) - .get_workflow_result(Default::default()) - .await - .unwrap(); + // Shutdown workers and servers + drain_pollers_and_shutdown(&(Arc::new(worker) as Arc)).await; + }) + .catch_unwind() + .await; - // Shutdown workers and servers - drain_pollers_and_shutdown(&(Arc::new(worker) as Arc)).await; - server1.shutdown().await.unwrap(); - server2.shutdown().await.unwrap(); + let shutdown_results = join_all([server1.shutdown(), server2.shutdown()]).await; + if let Err(e) = result { + std::panic::resume_unwind(e); + } + for r in shutdown_results { + r.unwrap(); + } } #[rstest::rstest] @@ -298,3 +330,189 @@ async fn small_workflow_slots_and_pollers(#[values(false, true)] use_autoscaling .any(|e| e.event_type() == EventType::WorkflowTaskTimedOut); assert!(!any_task_timeouts); } + +#[tokio::test] +async fn replace_client_works_after_polling_failure() { + let (log_consumer, mut log_rx) = CoreLogStreamConsumer::new(100); + let telem_opts = TelemetryOptionsBuilder::default() + .logging(Logger::Push { + filter: "OFF,temporal_client=DEBUG".into(), + consumer: Arc::new(log_consumer), + }) + .build() + .unwrap(); + let runtime_opts = RuntimeOptionsBuilder::default() + .telemetry_options(telem_opts) + .build() + .unwrap(); + let rt = Arc::new(CoreRuntime::new_assume_tokio(runtime_opts).unwrap()); + + // Spawning background task to read logs and notify the test when polling failure occurs. + let look_for_poll_failure_log = Arc::new(AtomicBool::new(false)); + let poll_retry_log_found = Arc::new(Notify::new()); + let log_reader_join_handle = tokio::spawn({ + let look_for_poll_retry_log = look_for_poll_failure_log.clone(); + let poll_retry_log_found = poll_retry_log_found.clone(); + async move { + let mut enabled = false; + loop { + let Some(log) = log_rx.next().await else { + break; + }; + if !enabled { + enabled = look_for_poll_retry_log.load(Ordering::Acquire); + } + if enabled + && (log + .message + .starts_with("gRPC call poll_workflow_task_queue failed") + || log + .message + .starts_with("gRPC call poll_workflow_task_queue retried")) + { + poll_retry_log_found.notify_one(); + break; + } + } + } + }); + let abort_handles = Arc::new(Mutex::new(vec![log_reader_join_handle.abort_handle()])); + + // Starting a second dev server for the worker to connect to initially. Later this server will be shut down + // and the worker client replaced with a client connected to the main integration test server. + let initial_server_config = integ_dev_server_config(vec![]).build().unwrap(); + let initial_server = Arc::new(Mutex::new(Some( + initial_server_config + .start_server_with_output(Stdio::null(), Stdio::null()) + .await + .unwrap(), + ))); + + let result = { + let initial_server = initial_server.clone(); + let abort_handles = abort_handles.clone(); + std::panic::AssertUnwindSafe(async move { + let initial_server_target = format!( + "http://{}", + initial_server.lock().unwrap().as_ref().unwrap().target + ); + let client_for_initial_server = ClientOptionsBuilder::default() + .identity("client_for_initial_server".to_string()) + .target_url(Url::parse(&initial_server_target).unwrap()) + .client_name(INTEG_CLIENT_NAME.to_string()) + .client_version(INTEG_CLIENT_VERSION.to_string()) + .build() + .unwrap() + .connect(NAMESPACE, rt.telemetry().get_temporal_metric_meter()) + .await + .unwrap(); + + let wf_name = "replace_client_works_after_polling_failure"; + let task_queue = format!("{wf_name}_tq"); + + let worker = Arc::new( + init_worker( + &rt, + integ_worker_config(&task_queue) + .max_cached_workflows(100_usize) + .build() + .unwrap(), + client_for_initial_server.clone(), + ) + .unwrap(), + ); + + // Polling the initial server the first time is successful. + let wf_1 = client_for_initial_server + .start_workflow( + vec![], + task_queue.clone(), + wf_name.into(), + wf_name.into(), + None, + WorkflowOptions::default(), + ) + .await + .unwrap(); + let act_1 = + tokio::time::timeout(Duration::from_secs(60), worker.poll_workflow_activation()) + .await + .unwrap() + .unwrap(); + assert_eq!(act_1.run_id, wf_1.run_id); + + // Initial server is shut down. + let mut server = initial_server.lock().unwrap().take().unwrap(); + server.shutdown().await.unwrap(); + + // Start polling in a background task. + look_for_poll_failure_log.store(true, Ordering::Release); + let poll_join_handle = tokio::spawn({ + let worker = worker.clone(); + async move { worker.poll_workflow_activation().await } + }); + abort_handles + .try_lock() + .unwrap() + .push(poll_join_handle.abort_handle()); + + // Wait until polling failure is detected. + tokio::time::timeout(Duration::from_secs(60), poll_retry_log_found.notified()) + .await + .unwrap(); + + // Start a new WF on main integration server. + let client_for_integ_server = get_integ_server_options() + .connect(NAMESPACE, rt.telemetry().get_temporal_metric_meter()) + .await + .unwrap(); + let wf_2 = client_for_integ_server + .start_workflow( + vec![], + task_queue, + wf_name.into(), + wf_name.into(), + None, + WorkflowOptions { + execution_timeout: Some(Duration::from_secs(60)), + ..Default::default() + }, + ) + .await + .unwrap(); + + // Switch worker over to the main integration server. + // The polling started on the initial server should complete with a task from the new server. + worker.replace_client(client_for_integ_server).unwrap(); + let act_2 = tokio::time::timeout(Duration::from_secs(60), poll_join_handle) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(act_2.run_id, wf_2.run_id); + }) + } + .catch_unwind() + .await; + + // Cleaning up spawned background tasks if they're still running. + for handle in &*abort_handles.lock().unwrap() { + handle.abort(); + } + + // If the test panicked, we may or may not need to shut down the server here. + // If the test succeeded, the server should always be shut down by this point. + let server = initial_server.lock().unwrap().take(); + if let Some(mut server) = server { + let _ = server.shutdown().await; + assert_matches!( + result, + Err(_), + "Server should have been shut down during the test" + ); + } + + if let Err(e) = result { + std::panic::resume_unwind(e); + } +} diff --git a/tests/integ_tests/update_tests.rs b/tests/integ_tests/update_tests.rs index f3314c56c..748541f03 100644 --- a/tests/integ_tests/update_tests.rs +++ b/tests/integ_tests/update_tests.rs @@ -42,6 +42,7 @@ use temporal_sdk_core_protos::{ test_utils::start_timer_cmd, }; use tokio::{join, sync::Barrier}; +use tonic::IntoRequest; use uuid::Uuid; #[derive(Clone, Copy)] @@ -116,7 +117,7 @@ async fn reapplied_updates_due_to_reset() { Arc::make_mut(&mut client_mut), #[allow(deprecated)] ResetWorkflowExecutionRequest { - namespace: client.namespace().into(), + namespace: client.namespace(), workflow_execution: Some(WorkflowExecution { workflow_id: workflow_id.into(), run_id: pre_reset_run_id.clone(), @@ -125,7 +126,8 @@ async fn reapplied_updates_due_to_reset() { reset_reapply_type: ResetReapplyType::AllEligible as i32, request_id: Uuid::new_v4().to_string(), ..Default::default() - }, + } + .into_request(), ) .await .unwrap() diff --git a/tests/integ_tests/worker_heartbeat_tests.rs b/tests/integ_tests/worker_heartbeat_tests.rs index 1184cc505..c79ff03f3 100644 --- a/tests/integ_tests/worker_heartbeat_tests.rs +++ b/tests/integ_tests/worker_heartbeat_tests.rs @@ -30,6 +30,7 @@ use temporal_sdk_core_protos::temporal::api::workflowservice::v1::DescribeWorker use temporal_sdk_core_protos::temporal::api::workflowservice::v1::ListWorkersRequest; use tokio::sync::Notify; use tokio::time::sleep; +use tonic::IntoRequest; use url::Url; fn within_two_minutes_ts(ts: Timestamp) -> bool { @@ -70,7 +71,8 @@ async fn list_worker_heartbeats( page_size: 200, next_page_token: Vec::new(), query: query.into(), - }, + } + .into_request(), ) .await .unwrap() @@ -186,7 +188,8 @@ async fn docker_worker_heartbeat_basic(#[values("otel", "prom", "no_metrics")] b page_size: 100, next_page_token: Vec::new(), query: String::new(), - }, + } + .into_request(), ) .await .unwrap() @@ -225,7 +228,8 @@ async fn docker_worker_heartbeat_basic(#[values("otel", "prom", "no_metrics")] b page_size: 100, next_page_token: Vec::new(), query: String::new(), - }, + } + .into_request(), ) .await .unwrap() @@ -319,7 +323,8 @@ async fn docker_worker_heartbeat_tuner() { page_size: 100, next_page_token: Vec::new(), query: String::new(), - }, + } + .into_request(), ) .await .unwrap() @@ -711,7 +716,8 @@ async fn worker_heartbeat_multiple_workers() { DescribeWorkerRequest { namespace: client.namespace().to_owned(), worker_instance_key: worker_a_key.to_string(), - }, + } + .into_request(), ) .await .unwrap() @@ -731,7 +737,8 @@ async fn worker_heartbeat_multiple_workers() { DescribeWorkerRequest { namespace: client.namespace().to_owned(), worker_instance_key: worker_b_key.to_string(), - }, + } + .into_request(), ) .await .unwrap() @@ -817,7 +824,8 @@ async fn worker_heartbeat_failure_metrics() { page_size: 100, next_page_token: Vec::new(), query: String::new(), - }, + } + .into_request(), ) .await .unwrap() @@ -861,7 +869,8 @@ async fn worker_heartbeat_failure_metrics() { page_size: 100, next_page_token: Vec::new(), query: String::new(), - }, + } + .into_request(), ) .await .unwrap() @@ -961,7 +970,8 @@ async fn worker_heartbeat_no_runtime_heartbeat() { page_size: 100, next_page_token: Vec::new(), query: String::new(), - }, + } + .into_request(), ) .await .unwrap() @@ -1021,7 +1031,8 @@ async fn worker_heartbeat_skip_client_worker_set_check() { page_size: 100, next_page_token: Vec::new(), query: String::new(), - }, + } + .into_request(), ) .await .unwrap() diff --git a/tests/integ_tests/worker_tests.rs b/tests/integ_tests/worker_tests.rs index 920e2606d..2a0ed3c79 100644 --- a/tests/integ_tests/worker_tests.rs +++ b/tests/integ_tests/worker_tests.rs @@ -8,15 +8,17 @@ use futures_util::FutureExt; use std::{ cell::Cell, sync::{ - Arc, + Arc, Mutex, atomic::{AtomicBool, Ordering::Relaxed}, }, time::Duration, }; use temporal_client::WorkflowOptions; -use temporal_sdk::{ActivityOptions, WfContext, interceptors::WorkerInterceptor}; +use temporal_sdk::{ + ActivityOptions, LocalActivityOptions, WfContext, interceptors::WorkerInterceptor, +}; use temporal_sdk_core::{ - CoreRuntime, ResourceBasedTuner, ResourceSlotOptions, init_worker, + CoreRuntime, ResourceBasedTuner, ResourceSlotOptions, TunerBuilder, init_worker, test_help::{ FakeWfResponses, MockPollCfg, ResponseType, build_mock_pollers, drain_pollers_and_shutdown, hist_to_poll_resp, mock_worker, mock_worker_client, @@ -25,7 +27,11 @@ use temporal_sdk_core::{ use temporal_sdk_core_api::{ Worker, errors::WorkerValidationError, - worker::{PollerBehavior, WorkerConfigBuilder, WorkerVersioningStrategy}, + worker::{ + ActivitySlotKind, LocalActivitySlotKind, PollerBehavior, SlotInfo, SlotInfoTrait, + SlotMarkUsedContext, SlotReleaseContext, SlotReservationContext, SlotSupplier, + SlotSupplierPermit, WorkerConfigBuilder, WorkerVersioningStrategy, WorkflowSlotKind, + }, }; use temporal_sdk_core_protos::{ DEFAULT_WORKFLOW_TYPE, TestHistoryBuilder, canned_histories, @@ -574,3 +580,282 @@ async fn sets_build_id_from_wft_complete() { .unwrap(); worker.run_until_done().await.unwrap(); } + +#[derive(Debug, Clone)] +enum SlotEvent { + ReserveSlot { + slot_type: &'static str, + }, + TryReserveSlot { + slot_type: &'static str, + }, + MarkSlotUsed { + slot_type: &'static str, + is_sticky: bool, + workflow_type: Option, + activity_type: Option, + }, + ReleaseSlot { + slot_type: &'static str, + }, +} + +struct TrackingSlotSupplier { + events: Arc>>, + slot_type: &'static str, + _phantom: std::marker::PhantomData, +} + +impl TrackingSlotSupplier { + fn new(slot_type: &'static str) -> Self { + Self { + events: Arc::new(Mutex::new(Vec::new())), + slot_type, + _phantom: std::marker::PhantomData, + } + } + + fn get_events(&self) -> Vec { + self.events.lock().unwrap().clone() + } + + fn add_event(&self, event: SlotEvent) { + self.events.lock().unwrap().push(event); + } + + fn extract_slot_info(info: &dyn SlotInfoTrait) -> (bool, Option, Option) { + match info.downcast() { + SlotInfo::Workflow(w) => (w.is_sticky, Some(w.workflow_type.clone()), None), + SlotInfo::Activity(a) => (false, None, Some(a.activity_type.clone())), + SlotInfo::LocalActivity(a) => (false, None, Some(a.activity_type.clone())), + SlotInfo::Nexus(_) => (false, None, None), + } + } +} + +#[async_trait::async_trait] +impl SlotSupplier for TrackingSlotSupplier +where + SK: temporal_sdk_core_api::worker::SlotKind + Send + Sync, + SK::Info: SlotInfoTrait, +{ + type SlotKind = SK; + + async fn reserve_slot(&self, _ctx: &dyn SlotReservationContext) -> SlotSupplierPermit { + self.add_event(SlotEvent::ReserveSlot { + slot_type: self.slot_type, + }); + SlotSupplierPermit::with_user_data(()) + } + + fn try_reserve_slot(&self, _ctx: &dyn SlotReservationContext) -> Option { + self.add_event(SlotEvent::TryReserveSlot { + slot_type: self.slot_type, + }); + Some(SlotSupplierPermit::with_user_data(())) + } + + fn mark_slot_used(&self, ctx: &dyn SlotMarkUsedContext) { + let (is_sticky, workflow_type, activity_type) = Self::extract_slot_info(ctx.info()); + self.add_event(SlotEvent::MarkSlotUsed { + slot_type: self.slot_type, + is_sticky, + workflow_type, + activity_type, + }); + } + + fn release_slot(&self, _ctx: &dyn SlotReleaseContext) { + self.add_event(SlotEvent::ReleaseSlot { + slot_type: self.slot_type, + }); + } +} + +#[tokio::test] +async fn test_custom_slot_supplier_simple() { + let wf_supplier = Arc::new(TrackingSlotSupplier::::new("workflow")); + let activity_supplier = Arc::new(TrackingSlotSupplier::::new("activity")); + let local_activity_supplier = Arc::new(TrackingSlotSupplier::::new( + "local_activity", + )); + + let mut starter = CoreWfStarter::new("test_custom_slot_supplier_simple"); + starter.worker_config.clear_max_outstanding_opts(); + + let mut tb = TunerBuilder::default(); + tb.workflow_slot_supplier(wf_supplier.clone()); + tb.activity_slot_supplier(activity_supplier.clone()); + tb.local_activity_slot_supplier(local_activity_supplier.clone()); + starter.worker_config.tuner(Arc::new(tb.build())); + + let mut worker = starter.worker().await; + + worker.register_activity( + "SlotSupplierActivity", + |_: temporal_sdk::ActContext, _: ()| async move { Ok(()) }, + ); + worker.register_wf( + "SlotSupplierWorkflow".to_owned(), + |ctx: WfContext| async move { + let _result = ctx + .activity(ActivityOptions { + activity_type: "SlotSupplierActivity".to_string(), + start_to_close_timeout: Some(Duration::from_secs(10)), + ..Default::default() + }) + .await; + let _result = ctx + .local_activity(LocalActivityOptions { + activity_type: "SlotSupplierActivity".to_string(), + start_to_close_timeout: Some(Duration::from_secs(10)), + ..Default::default() + }) + .await; + Ok(().into()) + }, + ); + + worker + .submit_wf( + "test-wf".to_owned(), + "SlotSupplierWorkflow".to_owned(), + vec![], + Default::default(), + ) + .await + .unwrap(); + + worker.run_until_done().await.unwrap(); + + // Collect all events + let wf_events = wf_supplier.get_events(); + let activity_events = activity_supplier.get_events(); + let local_activity_events = local_activity_supplier.get_events(); + + // Verify workflow slot events - should have reserve, mark used, and release events + assert!(wf_events.iter().any( + |e| matches!(e, SlotEvent::ReserveSlot { slot_type, .. } if *slot_type == "workflow") + )); + assert!(wf_events.iter().any( + |e| matches!(e, SlotEvent::MarkSlotUsed { slot_type, .. } if *slot_type == "workflow") + )); + assert!( + wf_events + .iter() + .any(|e| matches!(e, SlotEvent::ReleaseSlot { slot_type } if *slot_type == "workflow")) + ); + + // Verify activity slot events - should have reserve, try_reserve (for eager execution), mark + // used, and release + assert!(activity_events.iter().any( + |e| matches!(e, SlotEvent::ReserveSlot { slot_type, .. } if *slot_type == "activity") + )); + assert!( + activity_events.iter().any( + |e| matches!(e, SlotEvent::TryReserveSlot { slot_type } if *slot_type == "activity") + ) + ); + assert!(activity_events.iter().any( + |e| matches!(e, SlotEvent::MarkSlotUsed { slot_type, .. } if *slot_type == "activity") + )); + assert!( + activity_events + .iter() + .any(|e| matches!(e, SlotEvent::ReleaseSlot { slot_type } if *slot_type == "activity")) + ); + + // Verify local activity slot events + assert!(local_activity_events.iter().any( + |e| matches!(e, SlotEvent::ReserveSlot { slot_type, .. } if *slot_type == "local_activity") + )); + assert!(local_activity_events.iter().any( + |e| matches!(e, SlotEvent::MarkSlotUsed { slot_type, .. } if *slot_type == "local_activity") + )); + assert!(local_activity_events.iter().any( + |e| matches!(e, SlotEvent::ReleaseSlot { slot_type } if *slot_type == "local_activity") + )); + + assert!( + wf_events + .iter() + .any(|e| matches!(e, SlotEvent::MarkSlotUsed { + slot_type: "workflow", + workflow_type: Some(wf_type), + .. + } if wf_type == "SlotSupplierWorkflow")) + ); + assert!( + activity_events + .iter() + .any(|e| matches!(e, SlotEvent::MarkSlotUsed { + slot_type: "activity", + activity_type: Some(act_type), + .. + } if act_type == "SlotSupplierActivity")) + ); + assert!( + local_activity_events + .iter() + .any(|e| matches!(e, SlotEvent::MarkSlotUsed { + slot_type: "local_activity", + activity_type: Some(act_type), + .. + } if act_type == "SlotSupplierActivity")) + ); + assert!(wf_events.iter().any(|e| matches!( + e, + SlotEvent::MarkSlotUsed { + slot_type: "workflow", + is_sticky: false, + .. + } + ))); + + // Verify that the number of reserve/try_reserve events matches the number of release events + let total_reserves = wf_events + .iter() + .filter(|e| { + matches!( + e, + SlotEvent::ReserveSlot { .. } | SlotEvent::TryReserveSlot { .. } + ) + }) + .count() + + activity_events + .iter() + .filter(|e| { + matches!( + e, + SlotEvent::ReserveSlot { .. } | SlotEvent::TryReserveSlot { .. } + ) + }) + .count() + + local_activity_events + .iter() + .filter(|e| { + matches!( + e, + SlotEvent::ReserveSlot { .. } | SlotEvent::TryReserveSlot { .. } + ) + }) + .count(); + + let total_releases = wf_events + .iter() + .filter(|e| matches!(e, SlotEvent::ReleaseSlot { .. })) + .count() + + activity_events + .iter() + .filter(|e| matches!(e, SlotEvent::ReleaseSlot { .. })) + .count() + + local_activity_events + .iter() + .filter(|e| matches!(e, SlotEvent::ReleaseSlot { .. })) + .count(); + + assert_eq!( + total_reserves, total_releases, + "Number of reserves should equal number of releases" + ); +} diff --git a/tests/integ_tests/worker_versioning_tests.rs b/tests/integ_tests/worker_versioning_tests.rs index f5ca7f416..f532dafe4 100644 --- a/tests/integ_tests/worker_versioning_tests.rs +++ b/tests/integ_tests/worker_versioning_tests.rs @@ -23,6 +23,7 @@ use temporal_sdk_core_protos::{ }, }; use tokio::join; +use tonic::IntoRequest; #[rstest::rstest] #[tokio::test] @@ -76,10 +77,13 @@ async fn sets_deployment_info_on_task_responses(#[values(true, false)] use_defau client .get_client() .clone() - .describe_worker_deployment(DescribeWorkerDeploymentRequest { - namespace: client.namespace().to_string(), - deployment_name: deploy_name.clone(), - }) + .describe_worker_deployment( + DescribeWorkerDeploymentRequest { + namespace: client.namespace(), + deployment_name: deploy_name.clone(), + } + .into_request(), + ) .await }, Duration::from_secs(5), @@ -92,13 +96,16 @@ async fn sets_deployment_info_on_task_responses(#[values(true, false)] use_defau client .get_client() .clone() - .set_worker_deployment_current_version(SetWorkerDeploymentCurrentVersionRequest { - namespace: client.namespace().to_owned(), - deployment_name: deploy_name.clone(), - version: format!("{deploy_name}.1.0"), - conflict_token: desc_resp.conflict_token, - ..Default::default() - }) + .set_worker_deployment_current_version( + SetWorkerDeploymentCurrentVersionRequest { + namespace: client.namespace(), + deployment_name: deploy_name.clone(), + version: format!("{deploy_name}.1.0"), + conflict_token: desc_resp.conflict_token, + ..Default::default() + } + .into_request(), + ) .await .unwrap(); @@ -178,10 +185,13 @@ async fn activity_has_deployment_stamp() { client .get_client() .clone() - .describe_worker_deployment(DescribeWorkerDeploymentRequest { - namespace: client.namespace().to_string(), - deployment_name: deploy_name.clone(), - }) + .describe_worker_deployment( + DescribeWorkerDeploymentRequest { + namespace: client.namespace(), + deployment_name: deploy_name.clone(), + } + .into_request(), + ) .await }, Duration::from_secs(50), @@ -194,13 +204,16 @@ async fn activity_has_deployment_stamp() { client .get_client() .clone() - .set_worker_deployment_current_version(SetWorkerDeploymentCurrentVersionRequest { - namespace: client.namespace().to_owned(), - deployment_name: deploy_name.clone(), - version: format!("{deploy_name}.1.0"), - conflict_token: desc_resp.conflict_token, - ..Default::default() - }) + .set_worker_deployment_current_version( + SetWorkerDeploymentCurrentVersionRequest { + namespace: client.namespace(), + deployment_name: deploy_name.clone(), + version: format!("{deploy_name}.1.0"), + conflict_token: desc_resp.conflict_token, + ..Default::default() + } + .into_request(), + ) .await .unwrap(); diff --git a/tests/integ_tests/workflow_tests/activities.rs b/tests/integ_tests/workflow_tests/activities.rs index 090a7b3f1..5274e8820 100644 --- a/tests/integ_tests/workflow_tests/activities.rs +++ b/tests/integ_tests/workflow_tests/activities.rs @@ -1,7 +1,7 @@ use crate::{ common::{ - ActivationAssertionsInterceptor, CoreWfStarter, build_fake_sdk, init_core_and_create_wf, - mock_sdk, mock_sdk_cfg, + ActivationAssertionsInterceptor, CoreWfStarter, INTEG_CLIENT_IDENTITY, build_fake_sdk, + eventually, init_core_and_create_wf, mock_sdk, mock_sdk_cfg, }, integ_tests::activity_functions::echo, }; @@ -25,7 +25,7 @@ use temporal_sdk_core_protos::{ DEFAULT_ACTIVITY_TYPE, DEFAULT_WORKFLOW_TYPE, TaskToken, TestHistoryBuilder, canned_histories, coresdk::{ ActivityHeartbeat, ActivityTaskCompletion, AsJsonPayloadExt, FromJsonPayloadExt, - IntoCompletion, + IntoCompletion, IntoPayloadsExt, activity_result::{ self, ActivityExecutionResult, ActivityResolution, activity_resolution as act_res, }, @@ -206,7 +206,7 @@ async fn activity_non_retryable_failure() { }), scheduled_event_id: 5, started_event_id: 6, - identity: "integ_tester".to_owned(), + identity: INTEG_CLIENT_IDENTITY.to_owned(), retry_state: RetryState::NonRetryableFailure as i32, })), ..Default::default() @@ -273,7 +273,7 @@ async fn activity_non_retryable_failure_with_error() { }), scheduled_event_id: 5, started_event_id: 6, - identity: "integ_tester".to_owned(), + identity: INTEG_CLIENT_IDENTITY.to_owned(), retry_state: RetryState::NonRetryableFailure as i32, })), ..Default::default() @@ -798,6 +798,86 @@ async fn activity_cancelled_after_heartbeat_times_out() { .unwrap(); } +#[ignore] // Currently skipped because of https://github.com/temporalio/temporal/issues/8376 +#[tokio::test] +async fn activity_heartbeat_not_flushed_on_success() { + let mut starter = init_core_and_create_wf("activity_heartbeat_not_flushed_on_success").await; + let core = starter.get_worker().await; + let task_q = starter.get_task_queue().to_string(); + let activity_id = "act-1"; + let task = core.poll_workflow_activation().await.unwrap(); + core.complete_workflow_activation(WorkflowActivationCompletion::from_cmd( + task.run_id, + ScheduleActivity { + seq: 0, + activity_id: activity_id.to_string(), + activity_type: "dontcare".to_string(), + task_queue: task_q.clone(), + schedule_to_close_timeout: Some(prost_dur!(from_secs(60))), + heartbeat_timeout: Some(prost_dur!(from_secs(10))), + retry_policy: Some(RetryPolicy { + maximum_attempts: 2, + initial_interval: Some(prost_dur!(from_secs(5))), + ..Default::default() + }), + ..Default::default() + } + .into(), + )) + .await + .unwrap(); + // Poll activity and verify that it's been scheduled + let task = core.poll_activity_task().await.unwrap(); + assert_matches!(task.variant, Some(act_task::Variant::Start(_))); + // heartbeat 1 (will send immediately) + core.record_activity_heartbeat(ActivityHeartbeat { + task_token: task.task_token.clone(), + details: vec!["one".into()], + }); + // heartbeat 2 (would be throttled if not flushed) + core.record_activity_heartbeat(ActivityHeartbeat { + task_token: task.task_token.clone(), + details: vec!["two".into()], + }); + // Complete activity with fail + let failure = Failure::application_failure("activity failed".to_string(), false); + core.complete_activity_task(ActivityTaskCompletion { + task_token: task.task_token, + result: Some(ActivityExecutionResult::fail(failure)), + }) + .await + .unwrap(); + // The activity is still in the pending state since it has retries left + let client = starter.get_client().await; + eventually( + || async { + // Verify pending details has the flushed heartbeat + let details = client + .describe_workflow_execution(starter.get_wf_id().to_string(), None) + .await + .unwrap(); + let last_deets = details + .pending_activities + .into_iter() + .find(|i| i.activity_id == activity_id) + .and_then(|i| i.heartbeat_details); + if last_deets == ["two".into()].into_payloads() { + Ok(()) + } else { + Err("details don't yet match") + } + }, + Duration::from_secs(5), + ) + .await + .unwrap(); + client + .terminate_workflow_execution(task_q, None) + .await + .unwrap(); + drain_pollers_and_shutdown(&core).await; +} + #[tokio::test] async fn one_activity_abandon_cancelled_before_started() { let wf_name = "one_activity_abandon_cancelled_before_started"; diff --git a/tests/integ_tests/workflow_tests/nexus.rs b/tests/integ_tests/workflow_tests/nexus.rs index f3200c3f2..91eb8fc08 100644 --- a/tests/integ_tests/workflow_tests/nexus.rs +++ b/tests/integ_tests/workflow_tests/nexus.rs @@ -258,6 +258,16 @@ async fn nexus_async( let client = starter.get_client().await.get_client().clone(); let nexus_task_handle = async { let mut nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); + // Verify request header key for timeout exists and is lowercase + if outcome == Outcome::Timeout { + assert!( + nt.request + .as_ref() + .unwrap() + .header + .contains_key("request-timeout") + ); + } let start_req = assert_matches!( nt.request.unwrap().variant.unwrap(), request::Variant::StartOperation(sr) => sr diff --git a/tests/integ_tests/workflow_tests/patches.rs b/tests/integ_tests/workflow_tests/patches.rs index c812802fb..a66193d31 100644 --- a/tests/integ_tests/workflow_tests/patches.rs +++ b/tests/integ_tests/workflow_tests/patches.rs @@ -24,7 +24,7 @@ use temporal_sdk_core_protos::{ UpsertWorkflowSearchAttributesCommandAttributes, command::Attributes, }, common::v1::ActivityType, - enums::v1::{CommandType, EventType}, + enums::v1::{CommandType, EventType, IndexedValueType}, history::v1::{ ActivityTaskCompletedEventAttributes, ActivityTaskScheduledEventAttributes, ActivityTaskStartedEventAttributes, TimerFiredEventAttributes, @@ -507,6 +507,14 @@ async fn v2_and_v3_changes( && decode_change_marker_details(details).unwrap().1 == dep_flag_expected ); if expected_num_cmds == 3 { + let mut as_payload = [MY_PATCH_ID].as_json_payload().unwrap(); + as_payload.metadata.insert( + "type".to_string(), + IndexedValueType::KeywordList + .as_str_name() + .as_bytes() + .to_vec(), + ); assert_matches!( commands.pop_front().unwrap().attributes.as_ref().unwrap(), Attributes::UpsertWorkflowSearchAttributesCommandAttributes( @@ -514,7 +522,7 @@ async fn v2_and_v3_changes( { search_attributes: Some(attrs) } ) if attrs.indexed_fields.get(VERSION_SEARCH_ATTR_KEY).unwrap() - == &[MY_PATCH_ID].as_json_payload().unwrap() + == &as_payload ); } // The only time the "old" timer should fire is in v2, replaying, without a marker. diff --git a/tests/integ_tests/workflow_tests/resets.rs b/tests/integ_tests/workflow_tests/resets.rs index 6b356da9d..cd6fec88a 100644 --- a/tests/integ_tests/workflow_tests/resets.rs +++ b/tests/integ_tests/workflow_tests/resets.rs @@ -19,6 +19,7 @@ use temporal_sdk_core_protos::{ }, }; use tokio::sync::Notify; +use tonic::IntoRequest; const POST_RESET_SIG: &str = "post-reset"; @@ -65,17 +66,20 @@ async fn reset_workflow() { notify.notified().await; // Do the reset client - .reset_workflow_execution(ResetWorkflowExecutionRequest { - namespace: NAMESPACE.to_owned(), - workflow_execution: Some(WorkflowExecution { - workflow_id: wf_name.to_owned(), - run_id: run_id.clone(), - }), - // End of first WFT - workflow_task_finish_event_id: 4, - request_id: "test-req-id".to_owned(), - ..Default::default() - }) + .reset_workflow_execution( + ResetWorkflowExecutionRequest { + namespace: NAMESPACE.to_owned(), + workflow_execution: Some(WorkflowExecution { + workflow_id: wf_name.to_owned(), + run_id: run_id.clone(), + }), + // End of first WFT + workflow_task_finish_event_id: 4, + request_id: "test-req-id".to_owned(), + ..Default::default() + } + .into_request(), + ) .await .unwrap(); @@ -191,16 +195,19 @@ async fn reset_randomseed() { notify.notified().await; // Reset the workflow to be after first timer has fired client - .reset_workflow_execution(ResetWorkflowExecutionRequest { - namespace: NAMESPACE.to_owned(), - workflow_execution: Some(WorkflowExecution { - workflow_id: wf_name.to_owned(), - run_id: run_id.clone(), - }), - workflow_task_finish_event_id: 14, - request_id: "test-req-id".to_owned(), - ..Default::default() - }) + .reset_workflow_execution( + ResetWorkflowExecutionRequest { + namespace: NAMESPACE.to_owned(), + workflow_execution: Some(WorkflowExecution { + workflow_id: wf_name.to_owned(), + run_id: run_id.clone(), + }), + workflow_task_finish_event_id: 14, + request_id: "test-req-id".to_owned(), + ..Default::default() + } + .into_request(), + ) .await .unwrap(); diff --git a/tests/main.rs b/tests/main.rs index d48a68bd5..06eb96972 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -40,6 +40,7 @@ mod integ_tests { operatorservice::v1::CreateNexusEndpointRequest, workflowservice::v1::ListNamespacesRequest, }; + use tonic::IntoRequest; // Create a worker like a bridge would (unwraps aside) #[tokio::test] @@ -68,27 +69,32 @@ mod integ_tests { // Do things with worker or client let _ = retrying_client - .list_namespaces(ListNamespacesRequest::default()) + .list_namespaces(ListNamespacesRequest::default().into_request()) .await; } pub(crate) async fn mk_nexus_endpoint(starter: &mut CoreWfStarter) -> String { let client = starter.get_client().await; let endpoint = format!("mycoolendpoint-{}", rand_6_chars()); - let mut op_client = client.get_client().inner().operator_svc().clone(); + let mut op_client = client.get_client().inner().operator_svc(); op_client - .create_nexus_endpoint(CreateNexusEndpointRequest { - spec: Some(EndpointSpec { - name: endpoint.to_owned(), - description: None, - target: Some(EndpointTarget { - variant: Some(endpoint_target::Variant::Worker(endpoint_target::Worker { - namespace: client.namespace().to_owned(), - task_queue: starter.get_task_queue().to_owned(), - })), + .create_nexus_endpoint( + CreateNexusEndpointRequest { + spec: Some(EndpointSpec { + name: endpoint.to_owned(), + description: None, + target: Some(EndpointTarget { + variant: Some(endpoint_target::Variant::Worker( + endpoint_target::Worker { + namespace: client.namespace(), + task_queue: starter.get_task_queue().to_owned(), + }, + )), + }), }), - }), - }) + } + .into_request(), + ) .await .unwrap(); // Endpoint creation can (as of server 1.25.2 at least) return before they are actually usable. diff --git a/tests/runner.rs b/tests/runner.rs index af763cb4a..5d46fa309 100644 --- a/tests/runner.rs +++ b/tests/runner.rs @@ -2,18 +2,16 @@ #[allow(dead_code)] mod common; +use crate::common::integ_dev_server_config; use anyhow::{anyhow, bail}; use clap::Parser; -use common::{INTEG_SERVER_TARGET_ENV_VAR, SEARCH_ATTR_INT, SEARCH_ATTR_TXT}; +use common::INTEG_SERVER_TARGET_ENV_VAR; use std::{ env, path::{Path, PathBuf}, process::Stdio, }; -use temporal_sdk_core::ephemeral_server::{ - EphemeralExe, EphemeralExeVersion, TemporalDevServerConfigBuilder, TestServerConfigBuilder, - default_cached_download, -}; +use temporal_sdk_core::ephemeral_server::{TestServerConfigBuilder, default_cached_download}; use tokio::{self, process::Command}; /// This env var is set (to any value) if temporal CLI dev server is in use @@ -96,44 +94,10 @@ async fn main() -> Result<(), anyhow::Error> { let (server, envs) = match server_kind { ServerKind::TemporalCLI => { - let cli_version = if let Some(ver_override) = option_env!("CLI_VERSION_OVERRIDE") { - EphemeralExe::CachedDownload { - version: EphemeralExeVersion::Fixed(ver_override.to_owned()), - dest_dir: None, - ttl: None, - } - } else { - default_cached_download() - }; - let config = TemporalDevServerConfigBuilder::default() - .exe(cli_version) - .extra_args(vec![ - // TODO: Delete when temporalCLI enables it by default. - "--dynamic-config-value".to_string(), - "system.enableEagerWorkflowStart=true".to_string(), - "--dynamic-config-value".to_string(), - "system.enableNexus=true".to_string(), - "--dynamic-config-value".to_owned(), - "frontend.workerVersioningWorkflowAPIs=true".to_owned(), - "--dynamic-config-value".to_owned(), - "frontend.workerVersioningDataAPIs=true".to_owned(), - "--dynamic-config-value".to_owned(), - "system.enableDeploymentVersions=true".to_owned(), - "--dynamic-config-value".to_owned(), - "component.nexusoperations.recordCancelRequestCompletionEvents=true".to_owned(), - "--dynamic-config-value".to_owned(), - "frontend.WorkerHeartbeatsEnabled=true".to_owned(), - "--dynamic-config-value".to_owned(), - "frontend.ListWorkersEnabled=true".to_owned(), - "--http-port".to_string(), - "7243".to_string(), - "--search-attribute".to_string(), - format!("{SEARCH_ATTR_TXT}=Text"), - "--search-attribute".to_string(), - format!("{SEARCH_ATTR_INT}=Int"), - ]) - .ui(true) - .build()?; + let config = + integ_dev_server_config(vec!["--http-port".to_string(), "7243".to_string()]) + .ui(true) + .build()?; println!("Using temporal CLI: {config:?}"); ( Some( From eee45b8331e0106acc91e225ee49d331bd4cffaf Mon Sep 17 00:00:00 2001 From: Andrew Yuan Date: Fri, 17 Oct 2025 13:00:02 -0700 Subject: [PATCH 4/5] Forgot to commit format fixes --- core/src/lib.rs | 7 +++++-- core/src/worker/client.rs | 4 ---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/lib.rs b/core/src/lib.rs index 6c1dc74a8..e306d5f15 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -142,8 +142,11 @@ where rwi.into_core_worker() } -pub(crate) fn init_worker_client(namespace: String, - client_identity_override: Option,, client: CT) -> Client +pub(crate) fn init_worker_client( + namespace: String, + client_identity_override: Option, + client: CT, +) -> Client where CT: Into, { diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs index fe67d486a..2fae0309a 100644 --- a/core/src/worker/client.rs +++ b/core/src/worker/client.rs @@ -709,10 +709,6 @@ impl WorkerClient for WorkerClientBag { ) } - fn replace_client(&self, new_client: Client) { - self.client.get_client().replace_client(new_client); - } - async fn record_worker_heartbeat( &self, namespace: String, From c30e25e830d97bcb9e88056c8d0a038d80a29d6a Mon Sep 17 00:00:00 2001 From: Andrew Yuan Date: Mon, 20 Oct 2025 10:52:28 -0700 Subject: [PATCH 5/5] Fix test --- client/src/raw.rs | 6 ++++++ client/src/worker_registry/mod.rs | 1 + 2 files changed, 7 insertions(+) diff --git a/client/src/raw.rs b/client/src/raw.rs index 0f2c09f6b..655ad3b56 100644 --- a/client/src/raw.rs +++ b/client/src/raw.rs @@ -1601,6 +1601,7 @@ mod tests { operatorservice::v1::DeleteNamespaceRequest, workflowservice::v1::ListNamespacesRequest, }; use tonic::IntoRequest; + use uuid::Uuid; // Just to help make sure some stuff compiles. Not run. #[allow(dead_code)] @@ -1854,6 +1855,11 @@ mod tests { mock_provider .expect_deployment_options() .return_const(Some(deployment_opts.clone())); + mock_provider.expect_heartbeat_enabled().return_const(false); + let uuid = Uuid::new_v4(); + mock_provider + .expect_worker_instance_key() + .return_const(uuid); let client_worker_set = Arc::new(ClientWorkerSet::new()); client_worker_set diff --git a/client/src/worker_registry/mod.rs b/client/src/worker_registry/mod.rs index 58dc83a11..67b034c22 100644 --- a/client/src/worker_registry/mod.rs +++ b/client/src/worker_registry/mod.rs @@ -461,6 +461,7 @@ mod tests { mock_provider .expect_worker_instance_key() .return_const(worker_instance_key); + mock_provider.expect_deployment_options().return_const(None); if heartbeat_enabled { mock_provider