diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 7470df0ead..7281318619 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -34,11 +34,11 @@ def _get_onnx_opset_version(model: ir.Model) -> int | None: return model_version1 or model_version2 -def _set_onnx_opset_version(model: ir.Model, version: int) -> None: - """Set the ONNX opset version imported by the model.""" - if "ai.onnx" in model.opset_imports: - del model.opset_imports["ai.onnx"] - model.opset_imports[""] = version +def _set_onnx_opset_version(model_or_function: ir.Model | ir.Function, version: int) -> None: + """Set the ONNX opset version imported by the model or function.""" + if "ai.onnx" in model_or_function.opset_imports: + del model_or_function.opset_imports["ai.onnx"] + model_or_function.opset_imports[""] = version class VersionConverterError(RuntimeError): @@ -334,6 +334,7 @@ def visit_model(self, model: ir.Model) -> None: self.visit_graph_or_function(model.graph) for function in model.functions.values(): self.visit_graph_or_function(function) + _set_onnx_opset_version(function, self._target_version) _set_onnx_opset_version(model, self._target_version) diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index 2b615a8f7f..bf481313f4 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -237,8 +237,11 @@ def test_version_convert_function_nodes(self): version_converter.convert_version(model, target_version=target_version) self.assertEqual(model.opset_imports[""], target_version) - # Verify that nodes inside the function were version-converted + # Verify that the function's opset_imports are updated func = model.functions[("pkg.custom", "dft_func", "")] + self.assertEqual(func.opset_imports[""], target_version) + + # Verify that nodes inside the function were version-converted self.assertEqual(func[0].op_type, "Constant") self.assertEqual(func[0].version, 20) self.assertEqual(func[1].op_type, "Reshape") @@ -293,8 +296,12 @@ def test_version_convert_function_with_control_flow_subgraph(self): version_converter.convert_version(model, target_version=target_version) self.assertEqual(model.opset_imports[""], target_version) - # Verify nodes inside the function's If node subgraphs were version-converted + # Verify that the function's opset_imports are updated func = model.functions[("pkg.custom", "conditional_dft", "")] + self.assertEqual(func.opset_imports[""], target_version) + + # Verify nodes inside the function's If node subgraphs were version-converted + # Verify nodes inside the function's If node subgraphs were version-converted if_node = func[0] self.assertEqual(if_node.op_type, "If") self.assertEqual(if_node.version, 20)