diff --git a/supar/structs/fn.py b/supar/structs/fn.py
index 4bc689e2..3c44d12a 100644
--- a/supar/structs/fn.py
+++ b/supar/structs/fn.py
@@ -297,7 +297,7 @@ class Logsumexp(Function):
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
+    @torch.amp.custom_fwd(cast_inputs=torch.float, device_type='cuda')
     def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
         output = x.logsumexp(dim)
         ctx.dim = dim
@@ -305,7 +305,7 @@ def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
         return output.clone()
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]:
         x, output, dim = *ctx.saved_tensors, ctx.dim
         g, output = g.unsqueeze(dim), output.unsqueeze(dim)
@@ -317,14 +317,14 @@ def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]:
 class Logaddexp(Function):
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
+    @torch.amp.custom_fwd(cast_inputs=torch.float, device_type='cuda')
     def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         output = torch.logaddexp(x, y)
         ctx.save_for_backward(x, y, output)
         return output.clone()
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]:
         x, y, output = ctx.saved_tensors
         mask = g.eq(0)
@@ -337,14 +337,14 @@ def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]:
 class SampledLogsumexp(Function):
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
+    @torch.amp.custom_fwd(cast_inputs=torch.float, device_type='cuda')
     def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
         ctx.dim = dim
         ctx.save_for_backward(x)
         return x.logsumexp(dim=dim)
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]:
         from torch.distributions import OneHotCategorical
         (x, ), dim = ctx.saved_tensors, ctx.dim
@@ -354,7 +354,7 @@ def backward(ctx, g: torch.Tensor) -> Union[torch.Tensor, None]:
 class Sparsemax(Function):
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float)
+    @torch.amp.custom_fwd(cast_inputs=torch.float, device_type='cuda')
     def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
         ctx.dim = dim
         sorted_x, _ = x.sort(dim, True)
@@ -367,7 +367,7 @@ def forward(ctx, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
         return p
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, g: torch.Tensor) -> Tuple[torch.Tensor, None]:
         k, p, dim = *ctx.saved_tensors, ctx.dim
         grad = g.masked_fill(p.eq(0), 0)