File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -63,7 +63,7 @@ def _placement_axis_in_mesh(
6363 mesh : jax .sharding .Mesh | jax .sharding .AbstractMesh | None ,
6464 placement : str ,
6565) -> bool :
66- """Checks if a clients axis is present in the mesh."""
66+ """Checks if a placement axis is present in the mesh."""
6767 if mesh is None :
6868 return False
6969 placement_is_in_mesh = placement in mesh .axis_names
@@ -146,7 +146,7 @@ def broadcast_to_placement(
146146 if _placement_axis_in_mesh (mesh , placement ):
147147 pspec = P (placement , * ([P .UNCONSTRAINED ] * len (arg .shape )))
148148 else :
149- # Without a clients axis in the mesh, we simply explicitly tell the
149+ # Without a placement axis in the mesh, we simply explicitly tell the
150150 # compiler that there are no constraints on this tensor. This will leave
151151 # the choices in the hands of the compiler.
152152 pspec = P (* ([P .UNCONSTRAINED ] * (len (arg .shape ) + 1 )))
You can’t perform that action at this time.
0 commit comments