@@ -191,6 +191,7 @@ async def wait(
191191 task : Union [str , DashScopeAPIResponse ],
192192 api_key : str = None ,
193193 workspace : str = None ,
194+ wait_timeout : int = - 1 ,
194195 ** kwargs ,
195196 ) -> DashScopeAPIResponse :
196197 """Wait for async task completion and return task result.
@@ -199,6 +200,12 @@ async def wait(
199200 task (Union[str, DashScopeAPIResponse]): The task_id, or
200201 async_call response.
201202 api_key (str, optional): The api_key. Defaults to None.
203+ workspace (str, optional): The dashscope workspace id.
204+ wait_timeout (int, optional): The maximum seconds to wait
205+ for the task to complete. Default is -1, which means no
206+ timeout. When set to a value > 0, if the task does not
207+ complete within this time, a timeout error response will
208+ be returned.
202209
203210 Returns:
204211 DashScopeAPIResponse: The async task information.
@@ -208,6 +215,7 @@ async def wait(
208215 max_wait_seconds = 5
209216 increment_steps = 3
210217 step = 0
218+ start_time = time .time ()
211219 while True :
212220 step += 1
213221 # we start by querying once every second, and double
@@ -217,6 +225,21 @@ async def wait(
217225 # (server side return immediately when ready)
218226 if wait_seconds < max_wait_seconds and step % increment_steps == 0 :
219227 wait_seconds = min (wait_seconds * 2 , max_wait_seconds )
228+ if 0 < wait_timeout <= (time .time () - start_time ):
229+ logger .warning (
230+ "Wait task: %s timeout after %s seconds." ,
231+ task_id ,
232+ wait_timeout ,
233+ )
234+ return DashScopeAPIResponse (
235+ request_id = task_id ,
236+ status_code = HTTPStatus .REQUEST_TIMEOUT ,
237+ code = "WaitTaskTimeout" ,
238+ message = (
239+ f"Wait task: { task_id } timeout after "
240+ f"{ wait_timeout } seconds."
241+ ),
242+ )
220243 rsp = await cls ._get (
221244 task_id ,
222245 api_key ,
@@ -600,6 +623,10 @@ def call(
600623 ** kwargs ,
601624 ) -> DashScopeAPIResponse :
602625 """Call service and get result."""
626+ wait_timeout = - 1
627+ if "wait_timeout" in kwargs :
628+ wait_timeout = kwargs .pop ("wait_timeout" )
629+
603630 task_response = cls .async_call ( # type: ignore[misc]
604631 * args ,
605632 api_key = api_key ,
@@ -610,6 +637,7 @@ def call(
610637 task_response ,
611638 api_key = api_key ,
612639 workspace = workspace ,
640+ wait_timeout = wait_timeout ,
613641 )
614642 return response
615643
@@ -767,6 +795,7 @@ def wait(
767795 task : Union [str , DashScopeAPIResponse ],
768796 api_key : str = None ,
769797 workspace : str = None ,
798+ wait_timeout : int = - 1 ,
770799 ** kwargs ,
771800 ) -> DashScopeAPIResponse :
772801 """Wait for async task completion and return task result.
@@ -775,6 +804,12 @@ def wait(
775804 task (Union[str, DashScopeAPIResponse]): The task_id, or
776805 async_call response.
777806 api_key (str, optional): The api_key. Defaults to None.
807+ workspace (str, optional): The dashscope workspace id.
808+ wait_timeout (int, optional): The maximum seconds to wait
809+ for the task to complete. Default is -1, which means no
810+ timeout. When set to a value > 0, if the task does not
811+ complete within this time, a timeout error response will
812+ be returned.
778813
779814 Returns:
780815 DashScopeAPIResponse: The async task information.
@@ -784,6 +819,7 @@ def wait(
784819 max_wait_seconds = 5
785820 increment_steps = 3
786821 step = 0
822+ start_time = time .time ()
787823 while True :
788824 step += 1
789825 # we start by querying once every second, and double
@@ -794,6 +830,21 @@ def wait(
794830 # (server side return immediately when ready)
795831 if wait_seconds < max_wait_seconds and step % increment_steps == 0 :
796832 wait_seconds = min (wait_seconds * 2 , max_wait_seconds )
833+ if 0 < wait_timeout <= (time .time () - start_time ):
834+ logger .warning (
835+ "Wait task: %s timeout after %s seconds." ,
836+ task_id ,
837+ wait_timeout ,
838+ )
839+ return DashScopeAPIResponse (
840+ request_id = task_id ,
841+ status_code = HTTPStatus .REQUEST_TIMEOUT ,
842+ code = "WaitTaskTimeout" ,
843+ message = (
844+ f"Wait task: { task_id } timeout after "
845+ f"{ wait_timeout } seconds."
846+ ),
847+ )
797848 rsp = cls ._get (task_id , api_key , workspace = workspace , ** kwargs )
798849 if rsp .status_code == HTTPStatus .OK :
799850 if rsp .output is None :
0 commit comments