From 12531734b59742d6b881aa3e4acb7e042b3c4a8e Mon Sep 17 00:00:00 2001 From: cl <2020334843@qq.com> Date: Thu, 12 Mar 2026 10:23:42 +0800 Subject: [PATCH] Use ChangePrefixMatcher for torch.diagflat --- paconvert/api_mapping.json | 11 +---------- tests/test_diagflat.py | 25 +++++++++++++++++++++---- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 75f94adda..f41725b35 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -4248,16 +4248,7 @@ "Matcher": "ChangePrefixMatcher" }, "torch.diagflat": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.diagflat", - "min_input_args": 1, - "args_list": [ - "input", - "offset" - ], - "kwargs_change": { - "input": "x" - } + "Matcher": "ChangePrefixMatcher" }, "torch.diagonal": { "Matcher": "ChangePrefixMatcher" diff --git a/tests/test_diagflat.py b/tests/test_diagflat.py index 8678d599b..008cb9de9 100644 --- a/tests/test_diagflat.py +++ b/tests/test_diagflat.py @@ -60,13 +60,19 @@ def test_case_4(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([[-0.4264, 0.0255,-0.1064], - [ 0.8795,-0.2429, 0.1374], - [ 0.1029,-0.6482,-1.6300]]) + x = torch.tensor([1, 2, 3]) result = torch.diagflat(input=x, offset=-3) """ ) - obj.run(pytorch_code, ["result"]) + expect_paddle_code = textwrap.dedent( + """ + import paddle + + x = paddle.tensor([1, 2, 3]) + result = paddle.diagflat(input=x, offset=-3) + """ + ) + obj.run(pytorch_code, expect_paddle_code=expect_paddle_code) # generated by validate_unittest autofix, based on test_case_4 @@ -81,3 +87,14 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor(7) + result = torch.diagflat(input=x) + """ + ) + obj.run(pytorch_code, ["result"])