diff --git a/src/sharp/models/encoders/spn_encoder.py b/src/sharp/models/encoders/spn_encoder.py index 324a3595..d0e6c078 100644 --- a/src/sharp/models/encoders/spn_encoder.py +++ b/src/sharp/models/encoders/spn_encoder.py @@ -255,7 +255,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: patch_intermediate_features[self.patch_intermediate_features_ids[0]] # type:ignore[index] ) x_latent0_features = merge( - x_latent0_encodings[: batch_size * x0_tile_size], + x_latent0_encodings[: x0_tile_size], batch_size=batch_size, padding=padding, ) @@ -264,7 +264,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: patch_intermediate_features[self.patch_intermediate_features_ids[1]] # type:ignore[index] ) x_latent1_features = merge( - x_latent1_encodings[: batch_size * x0_tile_size], + x_latent1_encodings[: x0_tile_size], batch_size=batch_size, padding=padding, )