Staircase effect observed during Shape VAE reconstruction #127
Description
Hello,
I am currently testing Shape VAE reconstruction with the TRELLIS-2 models, but I am encountering a staircase (blocky) effect on the generated surfaces.

Specifically, I compared two pipelines:

- `mesh -> o-voxel -> mesh`: direct reconstruction works perfectly and produces a smooth surface.
- `mesh -> o-voxel -> VAE -> mesh`: passing the data through the VAE encoder and decoder results in severe staircase artifacts, as shown in the attached image.

Here is the inference code I am currently using:

```python
...

## Load Encoder
# Download config and safetensors from HF
encoder_cfg = hf_hub_download(repo_id="microsoft/TRELLIS.2-4B", filename="ckpts/shape_enc_next_dc_f16c32_fp16.json")
encoder_weights = hf_hub_download(repo_id="microsoft/TRELLIS.2-4B", filename="ckpts/shape_enc_next_dc_f16c32_fp16.safetensors")
with open(encoder_cfg) as f:
    encoder_args = json.load(f)["args"]
encoder = FlexiDualGridVaeEncoder(**encoder_args)
state_dict = load_file(encoder_weights)
encoder.load_state_dict(state_dict)
encoder.eval()
encoder.cuda()

## Load Decoder
decoder_cfg = hf_hub_download(repo_id="microsoft/TRELLIS.2-4B", filename="ckpts/shape_dec_next_dc_f16c32_fp16.json")
decoder_weights = hf_hub_download(repo_id="microsoft/TRELLIS.2-4B", filename="ckpts/shape_dec_next_dc_f16c32_fp16.safetensors")
with open(decoder_cfg) as f:
    decoder_args = json.load(f)["args"]
# decoder_args["use_fp16"] = False
decoder = FlexiDualGridVaeDecoder(**decoder_args)
state_dict = load_file(decoder_weights)
decoder.load_state_dict(state_dict)
decoder.eval()
decoder.cuda()

RES = 512
decoder.set_resolution(RES)

...

asset = trimesh.load(gt_path)
mesh = asset.to_mesh()
mesh = normalize_mesh(mesh)

# Convert to o-voxel
vertices = torch.from_numpy(mesh.vertices).float()
faces = torch.from_numpy(mesh.faces).long()
voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
    vertices, faces,
    grid_size=RES,                                # Resolution
    aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],   # Axis-aligned bounding box
    face_weight=1.0,                              # Face term weight in QEF
    boundary_weight=0.2,                          # Boundary term weight in QEF
    regularization_weight=1e-2,                   # Regularization term weight in QEF
    timing=True
)
coords = voxel_indices
local_dual_vertices = dual_vertices

# Inference VAE
batch_idx = torch.zeros((coords.shape[0], 1), dtype=torch.int32, device='cuda')
sp_coords = torch.cat([batch_idx, coords.cuda()], dim=1)
sp_vertices = sp.SparseTensor(
    feats=local_dual_vertices.cuda().float(),
    coords=sp_coords
)
sp_intersected = sp.SparseTensor(
    feats=intersected.cuda().float(),
    coords=sp_coords
)
with torch.no_grad():
    latent = encoder(sp_vertices, sp_intersected, sample_posterior=False)
    recon_output = decoder(latent)

output_mesh = recon_output[0]
rec_verts = output_mesh.vertices
rec_faces = output_mesh.faces
recon_mesh = trimesh.Trimesh(vertices=rec_verts.detach().cpu().numpy(), faces=rec_faces.detach().cpu().numpy())
recon_mesh.export(output_path / f"{gt_path.stem}.obj")
```

Could you please let me know if I am missing something, or if there are any specific modifications needed in my inference code (such as tensor precision, coordinate types, or scaling) to resolve this issue?
Thank you for your time and for sharing this amazing work!