diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 236aa3614..0abe4b06a 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -10833,16 +10833,7 @@ } }, "torch.poisson": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.poisson", - "min_input_args": 1, - "args_list": [ - "input", - "generator" - ], - "kwargs_change": { - "input": "x" - } + "Matcher": "ChangePrefixMatcher" }, "torch.poisson_nll_loss": {}, "torch.polar": { diff --git a/tests/test_poisson.py b/tests/test_poisson.py index cbf658441..49a274011 100644 --- a/tests/test_poisson.py +++ b/tests/test_poisson.py @@ -93,3 +93,27 @@ def test_case_7(): """ ) obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + rates = torch.linspace(0.5, 4.5, 5, dtype=torch.float64) + result = torch.poisson(rates) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + rates = torch.rand(2, 3) * 10 + gen = torch.Generator() + args = (rates, gen) + result = torch.poisson(*args) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False)