In flow_matching/tests/solver/test_ode_solver.py, at line 23:
class ConstantVelocityModel(ModelWrapper):
    def __init__(self):
        super().__init__(None)
        self.a = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        return x * 0.0 + self.a
To improve readability, I suggest changing
    return x * 0.0 + self.a
to
    return torch.ones_like(x) * self.a
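
A quick check of the two expressions, as a minimal sketch outside the test file. The standalone parameter `a` and the sample inputs are made up for illustration and stand in for `self.a` and the real test data. For finite inputs both forms give the same values, shape, and dtype; they differ when `x` contains NaN or inf, because `x * 0.0` propagates those values while `torch.ones_like(x)` does not.

```python
import torch

# Stand-in for self.a in ConstantVelocityModel (illustrative only).
a = torch.nn.Parameter(torch.tensor(1.0))
x = torch.randn(4, 3)

current = x * 0.0 + a              # existing expression
proposed = torch.ones_like(x) * a  # suggested expression

same_values = torch.equal(current, proposed)
same_shape = current.shape == proposed.shape
same_dtype = current.dtype == proposed.dtype

# Difference: NaN * 0.0 is NaN, so the existing form keeps the NaN.
x_bad = torch.tensor([float("nan"), 1.0])
current_bad = x_bad * 0.0 + a
proposed_bad = torch.ones_like(x_bad) * a
```

One more thing to keep in mind: `x * 0.0 + self.a` stays connected to `x` in the autograd graph (with zero gradient), whereas `torch.ones_like(x) * self.a` does not depend on `x` at all. If any test differentiates the output with respect to `x`, the change may need `allow_unused=True` or a similar adjustment.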