Issues with inference of the sparse-structure flow model after training #126
Description
Hi TRELLIS Team,
Thank you for sharing this impressive project! I was trying to train the sparse-structure flow model using a subset of 1000 samples from the ABO dataset. I have encountered two points where I would appreciate some guidance:
- After training, the process produces three .pt files:
denoiser_ema0.9999_step0050000.pt (4.9G)
denoiser_step0050000.pt (4.9G)
misc_step0050000.pt (9.7G)
However, the provided ss_flow_img_dit_1_3B_64_bf16.safetensors checkpoint is only 2.5G. Could you please provide instructions on how to properly convert these .pt files into the optimized .safetensors format used in the inference pipeline?
- I attempted to convert denoiser_ema0.9999_step0050000.pt to a .safetensors file and re-run inference, but I hit the following error during the "Sampling shape SLat" stage: "RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument." The error is raised in trellis2/modules/sparse/basic.py in __cal_shape, at the call coords[:, 0].max(). Does this mean the coordinates tensor is empty at this stage?
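To test the empty-coordinates hypothesis, here is a minimal sketch of a guard that could sit between the sparse-structure stage and the shape SLat stage. It assumes the sparse-structure decoder returns occupancy logits of shape (B, 1, R, R, R) that are thresholded at 0. The function name `occupied_coords` and that tensor layout are my assumptions, not names taken from the TRELLIS.2 code.

```python
import torch


def occupied_coords(occupancy_logits: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    """Return (N, 4) int coords [batch, x, y, z] of voxels above `threshold`.

    `occupancy_logits` is assumed (hypothetically) to have shape (B, 1, R, R, R).
    """
    occupied = occupancy_logits > threshold
    # argwhere yields (N, 5) indices; drop the channel column (index 1).
    coords = torch.argwhere(occupied)[:, [0, 2, 3, 4]].int()
    if coords.shape[0] == 0:
        # Fail early with a readable message instead of letting a later
        # coords[:, 0].max() raise on an empty tensor.
        raise ValueError(
            "No occupied voxels after thresholding; "
            f"max logit was {occupancy_logits.max().item():.4f}"
        )
    return coords
```

If this raises on my converted checkpoint but not on the released one, the problem would be in the sparse-structure stage (the weights or how they were loaded) rather than in the shape SLat sampler.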
I would appreciate any insights into why this might be happening or if there are specific steps I missed when preparing the trained weights for inference.
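For reference, this is roughly the conversion I have in mind, as a sketch under stated assumptions rather than the project's actual export procedure. It assumes the .pt file holds a flat state dict of tensors and that the released checkpoint simply stores the same weights in bfloat16, as the "bf16" in its filename suggests. The size gap fits that reading: for about 1.3B parameters, 4 bytes each is about 4.84 GiB and 2 bytes each is about 2.42 GiB. Both function names below are hypothetical.

```python
def checkpoint_size_gib(num_params: int, bytes_per_param: int) -> float:
    """Approximate on-disk size of the raw weights in GiB (metadata ignored)."""
    return num_params * bytes_per_param / 2**30


def convert_pt_to_bf16_safetensors(src_path: str, dst_path: str) -> None:
    """Load a .pt state dict, cast floating tensors to bfloat16, save as .safetensors.

    Assumes the file holds a flat {name: tensor} dict (not verified against
    the TRELLIS.2 trainer output).
    """
    # Imported here so the size helper above works without these packages.
    import torch
    from safetensors.torch import save_file

    state_dict = torch.load(src_path, map_location="cpu")
    converted = {
        # Leave integer/bool buffers untouched; safetensors needs contiguous tensors.
        name: (t.to(torch.bfloat16) if t.is_floating_point() else t).contiguous()
        for name, t in state_dict.items()
    }
    save_file(converted, dst_path)
```

If the saved file contains nested dicts or keys with a prefix that the inference model does not expect, this sketch would not be enough on its own, which may be one of the steps I missed.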
Thank you for your help!
Sampling sparse structure: 100%|███████████████████████████████████████████████████████████████████████████████| 12/12 [00:07<00:00, 1.57it/s]
Sampling shape SLat: 0%| | 0/12 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/TRELLIS.2/inference.py", line 28, in <module>
mesh = pipeline.run(image)[0]
File "/envs/trellis_v2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/TRELLIS.2/trellis2/pipelines/trellis2_image_to_3d.py", line 574, in run
shape_slat, res = self.sample_shape_slat_cascade(
File "/TRELLIS.2/trellis2/pipelines/trellis2_image_to_3d.py", line 311, in sample_shape_slat_cascade
slat = self.shape_slat_sampler.sample(
File "envs/trellis_v2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/TRELLIS.2/trellis2/pipelines/samplers/flow_euler.py", line 208, in sample
return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, guidance_interval=guidance_interval, **kwargs)
File "/envs/trellis_v2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/TRELLIS.2/trellis2/pipelines/samplers/flow_euler.py", line 121, in sample
out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
File "/envs/trellis_v2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/TRELLIS.2/trellis2/pipelines/samplers/flow_euler.py", line 79, in sample_once
pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
File "/TRELLIS.2/trellis2/pipelines/samplers/flow_euler.py", line 49, in _get_model_prediction
pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
File "/TRELLIS.2/trellis2/pipelines/samplers/guidance_interval_mixin.py", line 11, in _inference_model
return super()._inference_model(model, x_t, t, cond, guidance_strength=guidance_strength, **kwargs)
File "/TRELLIS.2/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py", line 15, in _inference_model
pred_pos = super()._inference_model(model, x_t, t, cond, **kwargs)
File "/TRELLIS.2/trellis2/pipelines/samplers/flow_euler.py", line 45, in _inference_model
t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
File "/TRELLIS.2/trellis2/modules/sparse/basic.py", line 479, in shape
self._shape = self.__cal_shape(self.feats, self.coords)
File "/TRELLIS.2/trellis2/modules/sparse/basic.py", line 463, in __cal_shape
shape.append(coords[:, 0].max().item() + 1)
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.