diff --git a/README.md b/README.md
index 59ae4ef8..ed5f1213 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+# This is a fork of the TRELLIS project. I do not own any of this codebase or copyright. A novel HTTP API was implemented in flaskserver.py
+
diff --git a/flaskclient.py b/flaskclient.py
new file mode 100644
index 00000000..0070a1d7
--- /dev/null
+++ b/flaskclient.py
@@ -0,0 +1,267 @@
+import threading
+import time
+
+import requests
+import os
+import uuid
+from tqdm import tqdm
+
class Trellis3DClient:
    """HTTP client for the TRELLIS 3D generation Flask server (flaskserver.py).

    Thin wrappers around the server's generation/extraction endpoints, plus
    convenience helpers that run the full generate + download workflow for
    the preview video and GLB mesh.
    """

    def __init__(self, base_url='http://localhost:5000'):
        """
        Args:
            base_url (str): Root URL of the server, e.g. 'http://localhost:5000'.
        """
        self.base_url = base_url

    def generate_from_single_image(self, image_path, params=None):
        """
        Generate 3D model from a single image

        Args:
            image_path (str): Path to the input image
            params (dict): Optional parameters including:
                - seed (int): Random seed
                - ss_guidance_strength (float): Guidance strength for sparse structure generation
                - ss_sampling_steps (int): Sampling steps for sparse structure generation
                - slat_guidance_strength (float): Guidance strength for structured latent generation
                - slat_sampling_steps (int): Sampling steps for structured latent generation

        Returns:
            dict: Response containing session_id and download URLs
        """
        if params is None:
            params = {}

        url = f"{self.base_url}/generate_from_single_image"
        # Context manager fixes a file-handle leak: the handle previously
        # stayed open after the request completed (or raised).
        with open(image_path, 'rb') as image_file:
            response = requests.post(url, files={'image': image_file}, data=params)
        return response.json()

    def generate_from_multiple_images(self, image_paths, params=None):
        """
        Generate 3D model from multiple images

        Args:
            image_paths (str or list): Either:
                - A list of paths to input images, or
                - A directory path containing images (will load all image files from the directory)
            params (dict): Optional parameters including:
                - seed (int): Random seed
                - ss_guidance_strength (float): Guidance strength for sparse structure generation
                - ss_sampling_steps (int): Sampling steps for sparse structure generation
                - slat_guidance_strength (float): Guidance strength for structured latent generation
                - slat_sampling_steps (int): Sampling steps for structured latent generation
                - multiimage_algo (str): Algorithm for multi-image generation ('stochastic' or 'multidiffusion')

        Returns:
            dict: Response containing session_id and download URLs

        Raises:
            ValueError: If a directory contains no image files, or image_paths
                is neither a list nor a directory path.
        """
        if params is None:
            params = {}

        # If image_paths is a directory, collect every image file inside it.
        if isinstance(image_paths, str) and os.path.isdir(image_paths):
            directory = image_paths
            image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif')
            # Sorted for deterministic upload order (os.listdir order is
            # filesystem-dependent).
            image_paths = [
                os.path.join(directory, f)
                for f in sorted(os.listdir(directory))
                if f.lower().endswith(image_extensions)
            ]
            if not image_paths:
                # Report the directory path, not the empty list the variable
                # was just reassigned to (original bug printed '[]').
                raise ValueError(f"No image files found in directory: {directory}")

        # Ensure image_paths is a list at this point
        if not isinstance(image_paths, list):
            raise ValueError("image_paths must be either a list of image paths or a directory path")

        url = f"{self.base_url}/generate_from_multiple_images"
        files = [('images', (os.path.basename(path), open(path, 'rb'))) for path in image_paths]

        try:
            response = requests.post(url, files=files, data=params)
            return response.json()
        finally:
            # Ensure all files are closed after the request
            for _, (_, file_obj) in files:
                file_obj.close()

    def extract_glb(self, session_id, params=None):
        """
        Extract GLB file from generated 3D model

        Args:
            session_id (str): Session ID from generation step
            params (dict): Optional parameters including:
                - mesh_simplify (float): Mesh simplification factor (0.9-0.98)
                - texture_size (int): Texture resolution (512, 1024, 1536, or 2048)

        Returns:
            dict: Response containing GLB download URL
        """
        if params is None:
            params = {}

        url = f"{self.base_url}/extract_glb"
        data = {'session_id': session_id, **params}

        response = requests.post(url, data=data)
        return response.json()

    def download_file(self, url, save_path=None):
        """
        Download a file from the server

        Args:
            url (str): Full URL to download (from previous responses)
            save_path (str): Optional path to save the file

        Returns:
            str: Path where file was saved

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        if save_path is None:
            save_path = os.path.basename(url)

        response = requests.get(url, stream=True)
        # Fail loudly instead of silently saving an error body to disk.
        response.raise_for_status()
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        return save_path

    def generate_and_download_from_single_image(self, image_path, target_dir=None, params=None):
        """
        Generate 3D model from a single image and download all artifacts

        Args:
            image_path (str): Path to the input image
            target_dir (str): Optional target directory to save files (defaults to /tmp/random_uuid)
            params (dict): Optional generation parameters

        Returns:
            dict: Paths to downloaded files
        """
        if target_dir is None:
            target_dir = f"/tmp/{uuid.uuid4()}"
        os.makedirs(target_dir, exist_ok=True)

        # Generate the 3D model and download the preview.  try/finally
        # guarantees the spinner thread is stopped even if the request
        # fails (previously a failure left it running forever).
        stop_event, spinner_thread = start_spinner_thread("Generating 3D model...")
        try:
            gen_result = self.generate_from_single_image(image_path, params)
            preview_path = os.path.join(target_dir, 'preview.mp4')
            self.download_file(
                f"{self.base_url}{gen_result['preview_url']}",
                preview_path
            )
        finally:
            stop_spinner_thread(stop_event, spinner_thread)

        # Extract and download GLB
        stop_event, spinner_thread = start_spinner_thread("Extracting 3D model...")
        try:
            glb_result = self.extract_glb(gen_result['session_id'], params)
            glb_path = os.path.join(target_dir, 'model.glb')
            self.download_file(
                f"{self.base_url}{glb_result['glb_url']}",
                glb_path
            )
        finally:
            stop_spinner_thread(stop_event, spinner_thread)

        return {
            'preview_path': preview_path,
            'glb_path': glb_path,
            'session_id': gen_result['session_id'],
            'target_dir': target_dir
        }

    def generate_and_download_from_multiple_images(self, image_paths, target_dir=None, params=None):
        """
        Generate 3D model from multiple images and download all artifacts

        Args:
            image_paths (list): List of paths to input images
            target_dir (str): Optional target directory to save files (defaults to /tmp/random_uuid)
            params (dict): Optional generation parameters

        Returns:
            dict: Paths to downloaded files
        """
        if target_dir is None:
            target_dir = f"/tmp/{uuid.uuid4()}"
        os.makedirs(target_dir, exist_ok=True)

        # Generate the 3D model and download the preview (spinner stopped
        # even on failure, see single-image variant).
        stop_event, spinner_thread = start_spinner_thread("Generating 3D model...")
        try:
            gen_result = self.generate_from_multiple_images(image_paths, params)
            preview_path = os.path.join(target_dir, 'preview.mp4')
            self.download_file(
                f"{self.base_url}{gen_result['preview_url']}",
                preview_path
            )
        finally:
            stop_spinner_thread(stop_event, spinner_thread)

        # Extract and download GLB
        stop_event, spinner_thread = start_spinner_thread("Extracting 3D model...")
        try:
            glb_result = self.extract_glb(gen_result['session_id'], params)
            glb_path = os.path.join(target_dir, 'model.glb')
            self.download_file(
                f"{self.base_url}{glb_result['glb_url']}",
                glb_path
            )
        finally:
            stop_spinner_thread(stop_event, spinner_thread)

        return {
            'preview_path': preview_path,
            'glb_path': glb_path,
            'session_id': gen_result['session_id'],
            'target_dir': target_dir
        }
+
def spinner(desc, stop_event):
    """Render an animated braille spinner next to *desc* until stop_event is set.

    Intended to run in a worker thread (see start_spinner_thread); polls the
    stop flag every 100 ms so shutdown is prompt.

    Args:
        desc (str): Static text shown before the spinner glyph.
        stop_event (threading.Event): Set by the caller to end the loop.
    """
    from itertools import cycle

    # Renamed from 'spinner': the original local shadowed this function's
    # own name.  (Also removed commented-out dead code.)
    frames = cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'])
    with tqdm(total=None, desc=desc, bar_format='{desc}') as pbar:
        while not stop_event.is_set():
            pbar.set_description(f"{desc} - {next(frames)}")
            time.sleep(0.1)
+
def start_spinner_thread(desc) -> tuple[threading.Event, threading.Thread]:
    """Start the spinner animation in a background daemon thread.

    Args:
        desc (str): Text displayed next to the spinner.

    Returns:
        tuple[threading.Event, threading.Thread]: The stop flag and the
        running thread; pass both to stop_spinner_thread().
    """
    stop_event = threading.Event()
    spinner_thread = threading.Thread(
        target=spinner,
        args=(desc, stop_event),
        # Daemonize so a spinner that was never stopped cannot block
        # interpreter shutdown.
        daemon=True,
    )
    spinner_thread.start()

    return stop_event, spinner_thread
+
def stop_spinner_thread(stop_event, spinner_thread):
    """Stop a spinner started by start_spinner_thread().

    Signals the spinner loop to finish via *stop_event*, then waits up to
    five seconds for *spinner_thread* to exit.
    """
    stop_event.set()
    spinner_thread.join(timeout=5)
+
+
# Example usage
if __name__ == '__main__':
    client = Trellis3DClient()

    # NOTE(review): example path is machine-specific; adjust before running.
    multi_result = client.generate_and_download_from_multiple_images(
        '/home/charlie/Desktop/Holodeck/hippo/datasets/sacha_kitchen/segments/6/rgb',
        target_dir="./blo",
        params={
            'multiimage_algo': 'stochastic',
            'seed': 123
        }
    )
    # (Removed an exact duplicate of the call above that sat after exit()
    # and was unreachable dead code.)
diff --git a/flaskserver.py b/flaskserver.py
new file mode 100644
index 00000000..5d67d450
--- /dev/null
+++ b/flaskserver.py
@@ -0,0 +1,294 @@
+import os
+import uuid
+import shutil
+from flask import Flask, request, jsonify, send_file
+from werkzeug.utils import secure_filename
+from PIL import Image
+import numpy as np
+import torch
+from easydict import EasyDict as edict
+import imageio
+from trellis.pipelines import TrellisImageTo3DPipeline
+from trellis.representations import Gaussian, MeshExtractResult
+from trellis.utils import render_utils, postprocessing_utils
+
app = Flask(__name__)
# Uploaded images and generated artifacts live in per-session (UUID-named)
# subdirectories under these two folders, created on startup if missing.
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['OUTPUT_FOLDER'] = 'outputs'
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True)

# Initialize pipeline once at import time and move it to the GPU; every
# request handler shares this single instance.
pipeline = TrellisImageTo3DPipeline.from_pretrained("gqk/TRELLIS-image-large-fork")
pipeline.cuda()

MAX_SEED = np.iinfo(np.int32).max  # upper bound for client-supplied seeds
+
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Serialise a Gaussian/mesh pair into a CPU-only, torch.save-able dict.

    The gaussian entry carries the constructor params plus each tensor
    attribute converted to a numpy array; the mesh entry carries vertices
    and faces the same way.  unpack_state() is the inverse.
    """
    gaussian_state = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_state[attr] = getattr(gs, attr).cpu().numpy()

    return {
        'gaussian': gaussian_state,
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
    }
+
+from typing import Tuple
+
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """Rebuild the Gaussian and mesh from a dict produced by pack_state().

    Tensors are restored onto the GPU ('cuda'), matching where the
    pipeline and postprocessing expect them.
    """
    gaussian_state = state['gaussian']
    # NOTE: 'mininum_kernel_size' mirrors the (misspelled) upstream
    # parameter name — do not "fix" the key.
    gs = Gaussian(
        aabb=gaussian_state['aabb'],
        sh_degree=gaussian_state['sh_degree'],
        mininum_kernel_size=gaussian_state['mininum_kernel_size'],
        scaling_bias=gaussian_state['scaling_bias'],
        opacity_bias=gaussian_state['opacity_bias'],
        scaling_activation=gaussian_state['scaling_activation'],
    )
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(gaussian_state[attr], device='cuda'))

    mesh_state = state['mesh']
    mesh = edict(
        vertices=torch.tensor(mesh_state['vertices'], device='cuda'),
        faces=torch.tensor(mesh_state['faces'], device='cuda'),
    )

    return gs, mesh
+
def preprocess_image(image_path: str) -> Image.Image:
    """Load the image at *image_path* and run the pipeline's preprocessing on it."""
    return pipeline.preprocess_image(Image.open(image_path))
+
@app.route('/generate_from_single_image', methods=['POST'])
def generate_from_single_image():
    """Generate a 3D model (gaussian + mesh) from one uploaded image.

    Form fields: seed, ss_guidance_strength, ss_sampling_steps,
    slat_guidance_strength, slat_sampling_steps, preprocess_image.
    File field: 'image'.

    Returns JSON with session_id plus preview/state download URLs, or an
    error payload with 400/500.
    """
    try:
        # Get parameters
        seed = int(request.form.get('seed', 0))
        ss_guidance_strength = float(request.form.get('ss_guidance_strength', 7.5))
        ss_sampling_steps = int(request.form.get('ss_sampling_steps', 12))
        slat_guidance_strength = float(request.form.get('slat_guidance_strength', 3.0))
        slat_sampling_steps = int(request.form.get('slat_sampling_steps', 12))
        # Renamed from 'preprocess_image': the old local shadowed the
        # module-level preprocess_image() helper.
        do_preprocess = request.form.get('preprocess_image', 'True').lower() == 'true'

        # Handle file upload
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400

        file = request.files['image']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400

        # Create session directory
        session_id = str(uuid.uuid4())
        session_dir = os.path.join(app.config['UPLOAD_FOLDER'], session_id)
        os.makedirs(session_dir, exist_ok=True)

        # Save uploaded file
        filename = secure_filename(file.filename)
        image_path = os.path.join(session_dir, filename)
        file.save(image_path)

        # Generate 3D model (the pipeline handles preprocessing internally
        # when preprocess_image=True, so no separate step is needed here).
        outputs = pipeline.run(
            Image.open(image_path),
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=do_preprocess,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )

        # Create output directory
        output_dir = os.path.join(app.config['OUTPUT_FOLDER'], session_id)
        os.makedirs(output_dir, exist_ok=True)

        # Save a side-by-side preview video: gaussian color render | mesh normals.
        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
        video_path = os.path.join(output_dir, 'preview.mp4')
        imageio.mimsave(video_path, video, fps=15)

        # Persist the generation state so /extract_glb can run later
        # without re-generating.
        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
        state_path = os.path.join(output_dir, 'state.pkl')
        torch.save(state, state_path)

        torch.cuda.empty_cache()

        return jsonify({
            'session_id': session_id,
            'preview_url': f'/download/preview/{session_id}',
            'state_url': f'/download/state/{session_id}'
        }), 200

    except Exception as e:
        # API boundary: surface failures as JSON rather than an HTML 500 page.
        return jsonify({'error': str(e)}), 500
+
@app.route('/generate_from_multiple_images', methods=['POST'])
def generate_from_multiple_images():
    """Generate a 3D model (gaussian + mesh) from several uploaded images.

    Form fields: seed, ss_guidance_strength, ss_sampling_steps,
    slat_guidance_strength, slat_sampling_steps, multiimage_algo
    ('stochastic' or 'multidiffusion'), preprocess_image.
    File field: 'images' (repeatable).

    Returns JSON with session_id plus preview/state download URLs, or an
    error payload with 400/500.
    """
    try:
        # Get parameters
        seed = int(request.form.get('seed', 0))
        ss_guidance_strength = float(request.form.get('ss_guidance_strength', 7.5))
        ss_sampling_steps = int(request.form.get('ss_sampling_steps', 12))
        slat_guidance_strength = float(request.form.get('slat_guidance_strength', 3.0))
        slat_sampling_steps = int(request.form.get('slat_sampling_steps', 12))
        multiimage_algo = request.form.get('multiimage_algo', 'stochastic')
        # Renamed from 'preprocess_image': the old local shadowed the
        # module-level preprocess_image() helper.
        do_preprocess = request.form.get('preprocess_image', 'True').lower() == 'true'

        # Handle file uploads
        if 'images' not in request.files:
            return jsonify({'error': 'No images provided'}), 400

        files = request.files.getlist('images')
        if len(files) == 0:
            return jsonify({'error': 'No selected files'}), 400

        # Create session directory
        session_id = str(uuid.uuid4())
        session_dir = os.path.join(app.config['UPLOAD_FOLDER'], session_id)
        os.makedirs(session_dir, exist_ok=True)

        # Save uploaded files
        images = []
        for file in files:
            if file.filename == '':
                continue
            filename = secure_filename(file.filename)
            image_path = os.path.join(session_dir, filename)
            file.save(image_path)
            images.append(Image.open(image_path))

        # Guard: every entry may have been skipped above; fail with a clear
        # 400 instead of letting the pipeline crash on an empty list.
        if not images:
            return jsonify({'error': 'No valid image files provided'}), 400

        # Generate 3D model
        outputs = pipeline.run_multi_image(
            images,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=do_preprocess,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
            mode=multiimage_algo,
        )

        # Create output directory
        output_dir = os.path.join(app.config['OUTPUT_FOLDER'], session_id)
        os.makedirs(output_dir, exist_ok=True)

        # Save a side-by-side preview video: gaussian color render | mesh normals.
        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
        video_path = os.path.join(output_dir, 'preview.mp4')
        imageio.mimsave(video_path, video, fps=15)

        # Persist the generation state so /extract_glb can run later
        # without re-generating.
        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
        state_path = os.path.join(output_dir, 'state.pkl')
        torch.save(state, state_path)

        torch.cuda.empty_cache()

        return jsonify({
            'session_id': session_id,
            'preview_url': f'/download/preview/{session_id}',
            'state_url': f'/download/state/{session_id}'
        }), 200

    except Exception as e:
        # API boundary: surface failures as JSON rather than an HTML 500 page.
        return jsonify({'error': str(e)}), 500
+
@app.route('/extract_glb', methods=['POST'])
def extract_glb():
    """Extract a textured GLB mesh from a previously generated session.

    Form fields: session_id (required), mesh_simplify (default 0.95),
    texture_size (default 1024).  Returns JSON with a glb_url, or an error
    payload with 400/404/500.
    """
    try:
        # Get parameters
        session_id = request.form.get('session_id')
        mesh_simplify = float(request.form.get('mesh_simplify', 0.95))
        texture_size = int(request.form.get('texture_size', 1024))

        if not session_id:
            return jsonify({'error': 'session_id is required'}), 400

        # Security: session_id is attacker-controlled and is joined into
        # filesystem paths below.  Server-issued session ids are UUIDs, so
        # reject anything else to block path traversal (e.g. '../../x').
        try:
            uuid.UUID(session_id)
        except ValueError:
            return jsonify({'error': 'Invalid session_id'}), 400

        # Load state
        state_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'state.pkl')
        if not os.path.exists(state_path):
            return jsonify({'error': 'Invalid session_id'}), 404

        # NOTE(review): torch.load unpickles the file; acceptable only
        # because state.pkl is written by this server itself.
        state = torch.load(state_path)
        gs, mesh = unpack_state(state)

        # Extract GLB
        glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)

        # Save GLB alongside the session's other artifacts
        glb_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'model.glb')
        glb.export(glb_path)

        torch.cuda.empty_cache()

        return jsonify({
            'glb_url': f'/download/glb/{session_id}'
        }), 200

    except Exception as e:
        # API boundary: surface failures as JSON rather than an HTML 500 page.
        return jsonify({'error': str(e)}), 500
+
+@app.route('/download/preview/