
Staircase effect observed during Shape VAE reconstruction #127

@EJShim


Hello,

I am currently testing Shape VAE reconstruction with the TRELLIS.2 models, but the generated surfaces show a staircase (blocky) effect.

Specifically, I compared two pipelines:

mesh -> O-Voxel -> mesh: Direct reconstruction works perfectly and produces a smooth surface.
[Image: direct O-Voxel reconstruction]

mesh -> O-Voxel -> VAE -> mesh: Passing the data through the VAE encoder and decoder produces severe staircase artifacts, as shown in the attached image.
[Image: VAE reconstruction with staircase artifacts]
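To compare the two pipelines with a number rather than only by eye, one rough proxy for staircasing is the fraction of surface area whose face normal is nearly parallel to a coordinate axis: a voxel-aligned staircase is dominated by such faces, while a smooth reconstruction of a curved shape is not. The sketch below uses only NumPy; the function name and the 0.99 threshold are illustrative choices of mine, not part of TRELLIS.2 or trimesh.

```python
import numpy as np


def axis_aligned_area_fraction(vertices, faces, cos_threshold=0.99):
    """Fraction of total surface area whose unit face normal lies within
    acos(cos_threshold) of one of the +/-x, +/-y, +/-z axes.

    Hypothetical staircase diagnostic: values near 1.0 suggest a blocky,
    voxel-aligned surface.
    """
    v = np.asarray(vertices, dtype=np.float64)
    f = np.asarray(faces, dtype=np.int64)
    # Unnormalized face normals; their length equals twice the triangle area.
    cross = np.cross(v[f[:, 1]] - v[f[:, 0]], v[f[:, 2]] - v[f[:, 0]])
    double_area = np.linalg.norm(cross, axis=1)
    valid = double_area > 0  # skip degenerate triangles
    normals = cross[valid] / double_area[valid, None]
    # A unit normal is axis-aligned when its largest |component| is close to 1.
    aligned = np.abs(normals).max(axis=1) >= cos_threshold
    total = double_area[valid].sum()
    if total == 0:
        return 0.0
    return float(double_area[valid][aligned].sum() / total)
```

For example, `axis_aligned_area_fraction(recon_mesh.vertices, recon_mesh.faces)` versus the same call on `mesh` would be expected to come out noticeably higher for the staircased VAE output, unless the input shape is itself mostly axis-aligned (a box, for instance), in which case the metric says little.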

Here is the inference code I am currently using:

.
.
.
    ## Load Encoder
    # Download config and safetensors from hf
    encoder_cfg = hf_hub_download(repo_id="microsoft/TRELLIS.2-4B", filename="ckpts/shape_enc_next_dc_f16c32_fp16.json")
    encoder_weights = hf_hub_download(repo_id="microsoft/TRELLIS.2-4B", filename="ckpts/shape_enc_next_dc_f16c32_fp16.safetensors")
    with open(encoder_cfg) as f:
        encoder_args = json.load(f)["args"]
    encoder = FlexiDualGridVaeEncoder(**encoder_args)
    state_dict = load_file(encoder_weights)
    encoder.load_state_dict(state_dict)
    encoder.eval()
    encoder.cuda()

    ## Load Decoder
    decoder_cfg = hf_hub_download(repo_id="microsoft/TRELLIS.2-4B", filename="ckpts/shape_dec_next_dc_f16c32_fp16.json")
    decoder_weights = hf_hub_download(repo_id="microsoft/TRELLIS.2-4B", filename="ckpts/shape_dec_next_dc_f16c32_fp16.safetensors")
    with open(decoder_cfg) as f:
        decoder_args = json.load(f)["args"]
    # decoder_args["use_fp16"] = False

    decoder = FlexiDualGridVaeDecoder(**decoder_args)
    state_dict = load_file(decoder_weights)
    decoder.load_state_dict(state_dict)
    decoder.eval()
    decoder.cuda()


    RES = 512
    decoder.set_resolution(RES)

.
.
.
      asset = trimesh.load(gt_path)
      mesh = asset.to_mesh()
      mesh = normalize_mesh(mesh)

      # Convert to o-voxel
      vertices = torch.from_numpy(mesh.vertices).float()
      faces = torch.from_numpy(mesh.faces).long()

      voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
          vertices, faces,
          grid_size=RES,                              # Resolution
          aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]],      # Axis-aligned bounding box
          face_weight=1.0,                            # Face term weight in QEF
          boundary_weight=0.2,                        # Boundary term weight in QEF
          regularization_weight=1e-2,                 # Regularization term weight in QEF
          timing=True
      )

      coords = voxel_indices

      local_dual_vertices = dual_vertices

      #  Inference VAE
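      # Prepend a batch-index column so each coordinate row is (batch, x, y, z);
      # cast the voxel indices to int32 so both tensors share one dtype.
      batch_idx = torch.zeros((coords.shape[0], 1), dtype=torch.int32, device='cuda')
      sp_coords = torch.cat([batch_idx, coords.cuda().int()], dim=1)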

      sp_vertices = sp.SparseTensor(
          feats=local_dual_vertices.cuda().float(),
          coords=sp_coords
      )
      sp_intersected = sp.SparseTensor(
          feats=intersected.cuda().float(),
          coords=sp_coords
      )
      with torch.no_grad():
          latent = encoder(sp_vertices, sp_intersected, sample_posterior=False)
          recon_output = decoder(latent)

      output_mesh = recon_output[0]

      rec_verts = output_mesh.vertices
      rec_faces = output_mesh.faces

      recon_mesh = trimesh.Trimesh(vertices=rec_verts.detach().cpu().numpy(), faces=rec_faces.detach().cpu().numpy())
      recon_mesh.export(output_path / f"{gt_path.stem}.obj")
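In the script, `local_dual_vertices = dual_vertices` passes the output of `mesh_to_flexible_dual_grid` to the encoder unchanged, although the variable name suggests per-voxel local coordinates were intended. Nothing here confirms which convention the TRELLIS.2 shape encoder expects, so the following is only a sketch of a conversion one might test, under two assumptions: (a) `dual_vertices` are world-space points inside `aabb`, and (b) the encoder wants offsets in voxel units measured from each voxel's minimum corner. Both helper names are hypothetical, and if the converter already returns local offsets (or the encoder uses a voxel-center origin instead) this conversion would be wrong.

```python
import numpy as np


def world_to_local_offsets(dual_vertices, voxel_indices, grid_size, aabb):
    """Hypothetical helper: express world-space dual vertices as offsets, in
    voxel units, from the minimum corner of the voxel that owns them.

    Assumes voxel i spans [aabb_min + i * h, aabb_min + (i + 1) * h) on each
    axis, with h = (aabb_max - aabb_min) / grid_size.
    """
    aabb = np.asarray(aabb, dtype=np.float64)
    voxel_size = (aabb[1] - aabb[0]) / grid_size  # per-axis voxel width
    corner = aabb[0] + np.asarray(voxel_indices, dtype=np.float64) * voxel_size
    return (np.asarray(dual_vertices, dtype=np.float64) - corner) / voxel_size


def local_to_world(local_offsets, voxel_indices, grid_size, aabb):
    """Inverse of world_to_local_offsets, under the same assumptions."""
    aabb = np.asarray(aabb, dtype=np.float64)
    voxel_size = (aabb[1] - aabb[0]) / grid_size
    corner = aabb[0] + np.asarray(voxel_indices, dtype=np.float64) * voxel_size
    return corner + np.asarray(local_offsets, dtype=np.float64) * voxel_size
```

A cheap first check is the value range of `dual_vertices.cpu().numpy()`: world-space points for this AABB fall in [-0.5, 0.5], while per-voxel offsets fall roughly in [0, 1]. Seeing which range the converter produces narrows down whether any conversion is needed at all.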

Could you please let me know if I am missing something or if there are any specific modifications needed in my inference code (such as tensor precision, coordinate types, or scaling) to resolve this issue?
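On the precision question, a back-of-the-envelope check shows why absolute coordinates would be fragile in half precision at this resolution. With `RES = 512` over the unit-width AABB, one voxel is 2^-9 wide, while float16 values between 0.25 and 0.5 in magnitude are spaced 2^-12 apart, so a world-space coordinate stored as float16 can only take about 8 positions per voxel. A local offset in voxel units near 0.9 has float16 spacing 2^-11 relative to a voxel width of 1, i.e. about 2048 positions. Whether the `fp16` encoder or decoder actually carries absolute coordinates in half precision is an assumption this sketch does not verify; it only illustrates the arithmetic with NumPy.

```python
import numpy as np

RES = 512
voxel_size = 1.0 / RES  # the AABB [-0.5, 0.5] is one unit wide, as in the script


def positions_per_voxel(value, voxel_width, dtype):
    """How many representable values of `dtype` fit in one voxel near `value`."""
    step = float(np.spacing(dtype(value)))  # gap to the next representable value
    return voxel_width / step


# World-space coordinate near 0.3, where one voxel is 1/512 wide.
world_fp16 = positions_per_voxel(0.3, voxel_size, np.float16)
world_fp32 = positions_per_voxel(0.3, voxel_size, np.float32)
# Per-voxel local offset near 0.9, where one voxel is 1.0 wide in these units.
local_fp16 = positions_per_voxel(0.9, 1.0, np.float16)

print(world_fp16, world_fp32, local_fp16)  # 8.0 65536.0 2048.0
```

If this turned out to be the cause, the commented-out `decoder_args["use_fp16"] = False` line in the script (plus the equivalent change for the encoder, if its config exposes the same key) would be a natural first experiment.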

Thank you for your time and for sharing this amazing work!
