@@ -1428,16 +1428,18 @@ def kabsch_rotation(
14281428 Returns:
14291429 r (torch.Tensor): rotation matrix with shape `(3, 3)`
14301430 """
1431- assert x .shape == y .shape
1431+ assert x .shape == y .shape and x . shape [ - 1 ] == 3
14321432
14331433 if exists (mask ):
1434+ assert len (x .shape ) == 2
14341435 x , y = x [mask > 0 , :], y [mask > 0 , :]
14351436
14361437 with autocast (enabled = False ):
14371438 x , y = x .float (), y .float ()
14381439
14391440 # optimal rotation matrix via SVD of the convariance matrix {x.T * y}
1440- v , _ , w = torch .linalg .svd (x .T @ y )
1441+ # v, _, w = torch.linalg.svd(x.T @ y)
1442+ v , _ , w = torch .linalg .svd (torch .einsum ('... i c,... i d -> ... c d' , x , y ))
14411443
14421444 # determinant sign for direction correction
14431445 d = torch .sign (torch .det (v ) * torch .det (w ))
@@ -1460,14 +1462,14 @@ def kabsch_transform(
14601462 """
14611463 assert x .shape == y .shape
14621464
1463- R = kabsch_rotation (x , y , mask = mask ) # pylint: disable=invalid-name
1464-
14651465 if exists (mask ):
1466- x_center = masked_mean (value = x , mask = mask , dim = - 2 , keepdim = True )
1467- y_center = masked_mean (value = y , mask = mask , dim = - 2 , keepdim = True )
1466+ x_center = masked_mean (value = x , mask = mask [..., None ] , dim = - 2 , keepdim = True )
1467+ y_center = masked_mean (value = y , mask = mask [..., None ] , dim = - 2 , keepdim = True )
14681468 else :
14691469 x_center = torch .mean (x , dim = - 2 , keepdim = True )
14701470 y_center = torch .mean (y , dim = - 2 , keepdim = True )
1471+
1472+ R = kabsch_rotation (x - x_center , y - y_center , mask = mask ) # pylint: disable=invalid-name
14711473 t = x_center - torch .einsum ('... h w, ... w -> ... h' , R , y_center )
14721474
14731475 return R , t
@@ -1478,8 +1480,8 @@ def kabsch_align(x: torch.Tensor, y: torch.Tensor, mask: Optional[torch.Tensor]
14781480 """
14791481 # center x and y to the origin
14801482 if exists (mask ):
1481- x_ = x - masked_mean (value = x , mask = mask , dim = - 2 , keepdim = True )
1482- y_ = y - masked_mean (value = y , mask = mask , dim = - 2 , keepdim = True )
1483+ x_ = x - masked_mean (value = x , mask = mask [..., None ] , dim = - 2 , keepdim = True )
1484+ y_ = y - masked_mean (value = y , mask = mask [..., None ] , dim = - 2 , keepdim = True )
14831485 else :
14841486 x_ = x - x .mean (dim = - 2 , keepdim = True )
14851487 y_ = y - y .mean (dim = - 2 , keepdim = True )
0 commit comments