2121from ..types import HistoricalInputQuery , InferenceMode , InferenceResult , ModelOutput , SchedulerParams
2222from ._conversions import convert_array_to_model_output , convert_to_model_input , convert_to_model_output
2323from ._utils import get_abi , get_bin , run_with_retry
24- from .exceptions import OpenGradientError
2524
2625# How much time we wait for txn to be included in chain
2726INFERENCE_TX_TIMEOUT = 120
@@ -91,7 +90,7 @@ def infer(
9190 model_output (Dict[str, np.ndarray]): Output of the ONNX model
9291
9392 Raises:
94- OpenGradientError : If the inference fails.
93+ RuntimeError : If the inference fails.
9594 """
9695
9796 def execute_transaction ():
@@ -106,7 +105,7 @@ def execute_transaction():
106105 tx_hash , tx_receipt = self ._send_tx_with_revert_handling (run_function )
107106 parsed_logs = contract .events .InferenceResult ().process_receipt (tx_receipt , errors = DISCARD )
108107 if len (parsed_logs ) < 1 :
109- raise OpenGradientError ("InferenceResult event not found in transaction logs" )
108+ raise RuntimeError ("InferenceResult event not found in transaction logs" )
110109
111110 # TODO: This should return a ModelOutput class object
112111 model_output = convert_to_model_output (parsed_logs [0 ]["args" ])
@@ -184,7 +183,7 @@ def _get_inference_result_from_node(self, inference_id: str, inference_mode: Inf
184183 Dict: The inference result as returned by the node
185184
186185 Raises:
187- OpenGradientError : If the request fails or returns an error
186+ RuntimeError : If the request fails or returns an error
188187 """
189188 try :
190189 encoded_id = urllib .parse .quote (inference_id , safe = "" )
@@ -199,50 +198,50 @@ def _get_inference_result_from_node(self, inference_id: str, inference_mode: Inf
199198 decoded_string = decoded_bytes .decode ("utf-8" )
200199 output = json .loads (decoded_string ).get ("InferenceResult" , {})
201200 if output is None :
202- raise OpenGradientError ("Missing InferenceResult in inference output" )
201+ raise RuntimeError ("Missing InferenceResult in inference output" )
203202
204203 match inference_mode :
205204 case InferenceMode .VANILLA :
206205 if "VanillaResult" not in output :
207- raise OpenGradientError ("Missing VanillaResult in inference output" )
206+ raise RuntimeError ("Missing VanillaResult in inference output" )
208207 if "model_output" not in output ["VanillaResult" ]:
209- raise OpenGradientError ("Missing model_output in VanillaResult" )
208+ raise RuntimeError ("Missing model_output in VanillaResult" )
210209 return {"output" : output ["VanillaResult" ]["model_output" ]}
211210
212211 case InferenceMode .TEE :
213212 if "TeeNodeResult" not in output :
214- raise OpenGradientError ("Missing TeeNodeResult in inference output" )
213+ raise RuntimeError ("Missing TeeNodeResult in inference output" )
215214 if "Response" not in output ["TeeNodeResult" ]:
216- raise OpenGradientError ("Missing Response in TeeNodeResult" )
215+ raise RuntimeError ("Missing Response in TeeNodeResult" )
217216 if "VanillaResponse" in output ["TeeNodeResult" ]["Response" ]:
218217 if "model_output" not in output ["TeeNodeResult" ]["Response" ]["VanillaResponse" ]:
219- raise OpenGradientError ("Missing model_output in VanillaResponse" )
218+ raise RuntimeError ("Missing model_output in VanillaResponse" )
220219 return {"output" : output ["TeeNodeResult" ]["Response" ]["VanillaResponse" ]["model_output" ]}
221220
222221 else :
223- raise OpenGradientError ("Missing VanillaResponse in TeeNodeResult Response" )
222+ raise RuntimeError ("Missing VanillaResponse in TeeNodeResult Response" )
224223
225224 case InferenceMode .ZKML :
226225 if "ZkmlResult" not in output :
227- raise OpenGradientError ("Missing ZkmlResult in inference output" )
226+ raise RuntimeError ("Missing ZkmlResult in inference output" )
228227 if "model_output" not in output ["ZkmlResult" ]:
229- raise OpenGradientError ("Missing model_output in ZkmlResult" )
228+ raise RuntimeError ("Missing model_output in ZkmlResult" )
230229 return {"output" : output ["ZkmlResult" ]["model_output" ]}
231230
232231 case _:
233- raise OpenGradientError (f"Invalid inference mode: { inference_mode } " )
232+ raise ValueError (f"Invalid inference mode: { inference_mode } " )
234233 else :
235234 return None
236235
237236 else :
238- raise OpenGradientError (f"Failed to get inference result: HTTP { response .status_code } " )
237+ raise RuntimeError (f"Failed to get inference result: HTTP { response .status_code } " )
239238
240239 except requests .RequestException as e :
241- raise OpenGradientError (f"Failed to get inference result: { str (e )} " )
242- except OpenGradientError :
240+ raise RuntimeError (f"Failed to get inference result: { str (e )} " )
241+ except ( RuntimeError , ValueError ) :
243242 raise
244243 except Exception as e :
245- raise OpenGradientError (f"Failed to get inference result: { str (e )} " )
244+ raise RuntimeError (f"Failed to get inference result: { str (e )} " )
246245
247246 def new_workflow (
248247 self ,
0 commit comments