prithivMLmods committed on
Commit 78bde42 · verified · 1 Parent(s): a2d96d1

Update app.py

Files changed (1)
  1. app.py +322 -221
app.py CHANGED
@@ -1,31 +1,37 @@
import os
- import shutil
import cv2
import torch
import numpy as np
from PIL import Image
- import base64
- import io
- import tempfile
- from typing import *
from datetime import datetime
- from pathlib import Path

- # --- Environment Setup (Must be before other imports) ---
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
- # Set autotune cache relative to this file
- os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

- # --- Third Party Imports ---
- import gradio as gr
- from gradio_client import Client, handle_file
- import spaces
from diffusers import ZImagePipeline

- # --- TRELLIS Specific Imports ---
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
@@ -33,13 +39,13 @@ from trellis2.utils import render_utils
import o_voxel

# ==========================================
- # Global Configuration & Assets
# ==========================================

MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')

- # TRELLIS Render Modes
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
@@ -52,57 +58,71 @@ STEPS = 8
DEFAULT_MODE = 3
DEFAULT_STEP = 3

- # ==========================================
- # CSS & JavaScript (For Custom Previewer)
- # ==========================================
-
- css = """
- /* Overwrite Gradio Default Style */
.stepper-wrapper { padding: 0; }
.stepper-container { padding: 0; align-items: center; }
.step-button { flex-direction: row; }
- .step-connector { transform: none; }
- .step-number { width: 16px; height: 16px; }
- .step-label { position: relative; bottom: 0; }
- .wrap.center.full { inset: 0; height: 100%; }
- .wrap.center.full.translucent { background: var(--block-background-fill); }
- .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
-
- /* Previewer */
.previewer-container {
- position: relative; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
- width: 100%; height: 722px; margin: 0 auto; padding: 20px;
- display: flex; flex-direction: column; align-items: center; justify-content: center;
}
- .previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; }
- .previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; }
- .previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
.tips-icon:hover + .tips-text { display: block; opacity: 100%; }
-
- /* Row 1: Display Modes */
- .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
- .previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
-
- /* Row 2: Display Image */
- .previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
- .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
.previewer-container .previewer-main-image.visible { display: block; }
-
- /* Row 3: Custom HTML Slider */
- .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
- .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
- .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
- .previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
- .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
-
- /* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
- .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
"""

- head = """
<script>
function refreshView(mode, step) {
    const allImgs = document.querySelectorAll('.previewer-main-image');
@@ -120,7 +140,6 @@ head = """
    const targetId = 'view-m' + mode + '-s' + step;
    const targetImg = document.getElementById(targetId);
    if (targetImg) targetImg.classList.add('visible');
-
    const allBtns = document.querySelectorAll('.mode-btn');
    allBtns.forEach((btn, idx) => {
        if (idx === mode) btn.classList.add('active');
@@ -132,10 +151,11 @@ head = """
</script>
"""

- empty_html = f"""
<div class="previewer-container">
- <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
- xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
</div>
"""
@@ -145,43 +165,54 @@ empty_html = f"""

print("Initializing models...")

- # 1. Z-Image-Turbo (Text to Image)
print("Loading Z-Image-Turbo...")
try:
-     z_pipe = ZImagePipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
-     z_pipe.to(device)
    print("Z-Image-Turbo loaded.")
except Exception as e:
    print(f"Failed to load Z-Image-Turbo: {e}")
-     z_pipe = None

- # 2. TRELLIS.2 (Image to 3D)
print("Loading TRELLIS.2...")
- # Initialize on startup
- trellis_pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
- trellis_pipeline.rembg_model = None
- trellis_pipeline.low_vram = False
- trellis_pipeline.cuda()
-
- # 3. Background Remover
- rmbg_client = Client("briaai/BRIA-RMBG-2.0")

- # 4. HDRI Maps for TRELLIS
- envmap = {}
- # Try to load assets, handle gracefully if running in a basic environment
try:
-     envmap = {
-         'forest': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
-         'sunset': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
-         'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
-     }
except Exception as e:
-     print(f"Warning: Could not load HDRI assets. {e}")

# ==========================================
# Helper Functions
@@ -199,48 +230,90 @@ def start_session(req: gr.Request):
    os.makedirs(user_dir, exist_ok=True)

def end_session(req: gr.Request):
-     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-     if os.path.exists(user_dir):
        shutil.rmtree(user_dir)

def remove_background(input: Image.Image) -> Image.Image:
-     with tempfile.NamedTemporaryFile(suffix='.png') as f:
        input = input.convert('RGB')
        input.save(f.name)
-     output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
-     output = Image.open(output)
    return output

def preprocess_image(input: Image.Image) -> Image.Image:
-     """Preprocess the input image: remove bg, crop, resize."""
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True
    max_size = max(input.size)
    scale = min(1, 1024 / max_size)
    if scale < 1:
        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
    if has_alpha:
        output = input
    else:
        output = remove_background(input)
-
    output_np = np.array(output)
-     alpha = output_np[:, :, 3]
-     bbox = np.argwhere(alpha > 0.8 * 255)
-     if bbox.size == 0:
-         return output # Return original if empty
-     bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
-     center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
-     size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
-     size = int(size * 1)
-     bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
-     output = output.crop(bbox)
-     output = np.array(output).astype(np.float32) / 255
-     output = output[:, :, :3] * output[:, :, 3:4]
-     output = Image.fromarray((output * 255).astype(np.uint8))
    return output

def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
@@ -264,34 +337,30 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed

# ==========================================
- # Inference Logic
# ==========================================

@spaces.GPU()
- def generate_txt2img(prompt, progress=gr.Progress(track_tqdm=True)):
-     """Generate Image using Z-Image Turbo"""
-     if z_pipe is None:
-         raise gr.Error("Z-Image-Turbo model failed to load.")
    if not prompt.strip():
-         raise gr.Error("Please enter a prompt.")
-
    device = "cuda" if torch.cuda.is_available() else "cpu"
-     generator = torch.Generator(device).manual_seed(42) # Or random

-     progress(0.1, desc="Generating Text-to-Image...")
-     try:
-         result = z_pipe(
-             prompt=prompt,
-             negative_prompt=None,
-             height=1024,
-             width=1024,
-             num_inference_steps=9,
-             guidance_scale=0.0,
-             generator=generator,
-         )
-         return result.images[0]
-     except Exception as e:
-         raise gr.Error(f"Z-Image Generation failed: {str(e)}")

@spaces.GPU(duration=120)
def image_to_3d(
@@ -312,16 +381,24 @@ def image_to_3d(
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
- ) -> str:

    if image is None:
        raise gr.Error("Input image is missing.")

    # --- Sampling ---
-     outputs, latents = trellis_pipeline.run(
        image,
        seed=seed,
-         preprocess_image=False, # We pre-process in the upload handler or assume clean input
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
@@ -340,45 +417,71 @@ def image_to_3d(
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
-         pipeline_type={"512": "512", "1024": "1024_cascade", "1536": "1536_cascade"}[resolution],
        return_latent=True,
    )
    mesh = outputs[0]
    mesh.simplify(16777216)

    # Render Preview
-     images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
    state = pack_state(latents)
    torch.cuda.empty_cache()

    # --- HTML Construction ---
    images_html = ""
    for m_idx, mode in enumerate(MODES):
        for s_idx in range(STEPS):
            unique_id = f"view-m{m_idx}-s{s_idx}"
            is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
            vis_class = "visible" if is_visible else ""
-             img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
-             images_html += f"""<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">"""

    btns_html = ""
-     for idx, mode in enumerate(MODES):
        active_class = "active" if idx == DEFAULT_MODE else ""
-         btns_html += f"""<img src="{mode['icon_base64']}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode['name']}">"""

    full_html = f"""
    <div class="previewer-container">
        <div class="tips-wrapper">
            <div class="tips-icon">💡Tips</div>
-             <div class="tips-text"><p>● <b>Render Mode</b> - Click buttons to switch modes.</p><p>● <b>View Angle</b> - Drag slider to rotate.</p></div>
        </div>
-         <div class="display-row">{images_html}</div>
-         <div class="mode-row" id="btn-group">{btns_html}</div>
        <div class="slider-row">
            <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
        </div>
    </div>
    """
    return state, full_html

@spaces.GPU(duration=120)
@@ -389,12 +492,10 @@ def extract_glb(
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
-     if state is None:
-         raise gr.Error("No 3D model generated yet.")
-
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shape_slat, tex_slat, res = unpack_state(state)
-     mesh = trellis_pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    mesh.simplify(16777216)

    glb = o_voxel.postprocess.to_glb(
@@ -402,7 +503,7 @@ def extract_glb(
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
-         attr_layout=trellis_pipeline.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
@@ -412,6 +513,7 @@ def extract_glb(
        remesh_project=0,
        use_tqdm=True,
    )
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    os.makedirs(user_dir, exist_ok=True)
@@ -421,113 +523,109 @@ def extract_glb(
    return glb_path, glb_path

# ==========================================
- # Gradio UI Blocks
# ==========================================

if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

-     # Pre-process icon base64
    for i in range(len(MODES)):
        if os.path.exists(MODES[i]['icon']):
            icon = Image.open(MODES[i]['icon'])
            MODES[i]['icon_base64'] = image_to_base64(icon)
        else:
-             MODES[i]['icon_base64'] = "" # Fallback

-     with gr.Blocks(delete_cache=(600, 600)) as demo:
        gr.Markdown("""
        # TRELLIS.2-3D
-         **Unified Text-to-3D Workflow**
-
-         1. **Text to Image**: Generate a base image using Z-Image-Turbo.
-         2. **Image to 3D**: Convert that image into a high-quality 3D asset using TRELLIS.2.
        """)

        with gr.Row():
-             # --- Column 1: Inputs & Config ---
-             with gr.Column(scale=1, min_width=360):
-
-                 # --- Step 1: Text to Image ---
-                 with gr.Group():
-                     gr.Markdown("### Step 1: Generate Image")
-                     txt_prompt = gr.Textbox(label="Prompt", placeholder="A sci-fi helmet, high quality, white background")
-                     btn_gen_img = gr.Button("Generate Image", variant="secondary")

-                 # --- Step 2: Image to 3D Input ---
-                 gr.Markdown("### Step 2: Configure & Convert")
-                 image_prompt = gr.Image(label="Input Image (Generated or Uploaded)", format="png", image_mode="RGBA", type="pil", height=300)

-                 with gr.Accordion("3D Generation Settings", open=True):
                    resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
                    seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                     decimation_target = gr.Slider(100000, 500000, label="Decimation Target (For GLB)", value=300000, step=10000)
-                     texture_size = gr.Slider(1024, 4096, label="Texture Size (For GLB)", value=2048, step=1024)
-
-                 btn_gen_3d = gr.Button("Generate 3D", variant="primary")
-
-                 with gr.Accordion(label="Advanced Sampling Settings", open=False):
-                     gr.Markdown("**Stage 1: Sparse Structure**")
-                     ss_guidance_strength = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
-                     ss_guidance_rescale = gr.Slider(0.0, 1.0, value=0.7, label="Rescale")
-                     ss_sampling_steps = gr.Slider(1, 50, value=12, label="Steps")
-                     ss_rescale_t = gr.Slider(1.0, 6.0, value=5.0, label="Rescale T")

-                     gr.Markdown("**Stage 2: Shape**")
-                     shape_guidance = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
-                     shape_rescale = gr.Slider(0.0, 1.0, value=0.5, label="Rescale")
-                     shape_steps = gr.Slider(1, 50, value=12, label="Steps")
-                     shape_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")
-
-                     gr.Markdown("**Stage 3: Material**")
-                     tex_guidance = gr.Slider(1.0, 10.0, value=1.0, label="Guidance")
-                     tex_rescale = gr.Slider(0.0, 1.0, value=0.0, label="Rescale")
-                     tex_steps = gr.Slider(1, 50, value=12, label="Steps")
-                     tex_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")
-
-             # --- Column 2: Outputs ---
-             with gr.Column(scale=10):
                with gr.Walkthrough(selected=0) as walkthrough:
                    with gr.Step("Preview", id=0):
-                         preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
-                         extract_btn = gr.Button("Extract GLB")
-                     with gr.Step("Extract", id=1):
-                         glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
-                         download_btn = gr.DownloadButton(label="Download GLB")
-
-                 gr.Markdown("Note: GLB extraction might take ~30s.")
-
-         # ==========================================
-         # Wiring Events
-         # ==========================================

-         # State to hold the latent 3D representation
-         output_buf = gr.State()
-
        demo.load(start_session)
        demo.unload(end_session)

-         # 1. Text to Image Event
-         btn_gen_img.click(
-             generate_txt2img,
-             inputs=[txt_prompt],
-             outputs=[image_prompt]
-         ).then(
-             preprocess_image, # Auto preprocess the generated image (rmbg)
-             inputs=[image_prompt],
-             outputs=[image_prompt]
        )

-         # 2. Upload Image Event (Preprocess)
-         image_prompt.upload(
            preprocess_image,
-             inputs=[image_prompt],
-             outputs=[image_prompt],
        )

-         # 3. Image to 3D Event
-         btn_gen_3d.click(
            get_seed,
            inputs=[randomize_seed, seed],
            outputs=[seed],
@@ -536,21 +634,24 @@ if __name__ == "__main__":
        ).then(
            image_to_3d,
            inputs=[
-                 image_prompt, seed, resolution,
                ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
-                 shape_guidance, shape_rescale, shape_steps, shape_rescale_t,
-                 tex_guidance, tex_rescale, tex_steps, tex_rescale_t,
            ],
-             outputs=[output_buf, preview_output],
        )

-         # 4. Extraction Event
        extract_btn.click(
            lambda: gr.Walkthrough(selected=1), outputs=walkthrough
        ).then(
            extract_glb,
-             inputs=[output_buf, decimation_target, texture_size],
            outputs=[glb_output, download_btn],
        )

-     demo.launch(css=css, head=head, mcp_server=True, ssr_mode=False, show_error=True)
import os
+ import io
import cv2
+ import time
+ import base64
+ import shutil
+ import tempfile
import torch
import numpy as np
+ import gradio as gr
+ from gradio_client import Client, handle_file
+ from pathlib import Path
+ from typing import Tuple, List, Optional
from PIL import Image
from datetime import datetime

+ # --- Environment Configuration (Must be set before importing trellis2) ---
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
+ # Adjust path if necessary or ensure autotune_cache.json exists in trellis2 dir
+ try:
+     from trellis2 import modules
+     os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(modules.__file__)), 'autotune_cache.json')
+ except Exception:
+     pass
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

+ import spaces  # For GPU management in Hugging Face Spaces
+
+ # --- Imports for Z-Image-Turbo ---
from diffusers import ZImagePipeline

+ # --- Imports for TRELLIS.2 ---
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap

import o_voxel

# ==========================================
+ # Global Constants & CSS/JS
# ==========================================

MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')

+ # Asset definitions
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},

DEFAULT_MODE = 3
DEFAULT_STEP = 3

+ CSS = """
+ /* TRELLIS Custom Styles */
.stepper-wrapper { padding: 0; }
.stepper-container { padding: 0; align-items: center; }
.step-button { flex-direction: row; }
.previewer-container {
+     position: relative;
+     font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
+     width: 100%;
+     height: 722px;
+     margin: 0 auto;
+     padding: 20px;
+     display: flex;
+     flex-direction: column;
+     align-items: center;
+     justify-content: center;
+ }
+ .previewer-container .tips-icon {
+     position: absolute; right: 10px; top: 10px; z-index: 10;
+     border-radius: 10px; color: #fff; background-color: var(--color-accent);
+     padding: 3px 6px; user-select: none;
+ }
+ .previewer-container .tips-text {
+     position: absolute; right: 10px; top: 50px; color: #fff;
+     background-color: var(--color-accent); border-radius: 10px;
+     padding: 6px; text-align: left; max-width: 300px; z-index: 10;
+     transition: all 0.3s; opacity: 0%; user-select: none;
+ }
.tips-icon:hover + .tips-text { display: block; opacity: 100%; }
+ .previewer-container .mode-row {
+     width: 100%; display: flex; gap: 8px; justify-content: center;
+     margin-bottom: 20px; flex-wrap: wrap;
+ }
+ .previewer-container .mode-btn {
+     width: 24px; height: 24px; border-radius: 50%; cursor: pointer;
+     opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover;
+ }
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
+ .previewer-container .display-row {
+     margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1;
+     display: flex; justify-content: center; align-items: center;
+ }
+ .previewer-container .previewer-main-image {
+     max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none;
+ }
.previewer-container .previewer-main-image.visible { display: block; }
+ .previewer-container .slider-row {
+     width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px;
+ }
+ .previewer-container input[type=range] {
+     -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent;
+ }
+ .previewer-container input[type=range]::-webkit-slider-runnable-track {
+     width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px;
+ }
+ .previewer-container input[type=range]::-webkit-slider-thumb {
+     height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent);
+     cursor: pointer; -webkit-appearance: none; margin-top: -6px;
+     box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s;
+ }
.gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
"""
+ HEAD = """
<script>
function refreshView(mode, step) {
    const allImgs = document.querySelectorAll('.previewer-main-image');

    const targetId = 'view-m' + mode + '-s' + step;
    const targetImg = document.getElementById(targetId);
    if (targetImg) targetImg.classList.add('visible');

    const allBtns = document.querySelectorAll('.mode-btn');
    allBtns.forEach((btn, idx) => {
        if (idx === mode) btn.classList.add('active');

</script>
"""

+ EMPTY_HTML = f"""
<div class="previewer-container">
+     <div style="opacity: 0.5; text-align: center;">
+         <p>3D Asset Preview will appear here.</p>
+     </div>
</div>
"""

print("Initializing models...")

+ # 1. Load Z-Image-Turbo (Text to Image)
print("Loading Z-Image-Turbo...")
try:
+     t2i_pipe = ZImagePipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
+     t2i_pipe.to(device)
    print("Z-Image-Turbo loaded.")
except Exception as e:
    print(f"Failed to load Z-Image-Turbo: {e}")
+     t2i_pipe = None

+ # 2. Load TRELLIS.2 (Image to 3D)
print("Loading TRELLIS.2...")
+ try:
+     pipeline_trellis = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
+     pipeline_trellis.rembg_model = None  # We use external Bria RMBG
+     pipeline_trellis.low_vram = False
+     pipeline_trellis.cuda()
+     print("TRELLIS.2 loaded.")
+ except Exception as e:
+     print(f"Failed to load TRELLIS.2: {e}")
+     pipeline_trellis = None

+ # 3. Load RMBG Client
+ print("Loading RMBG Client...")
try:
+     rmbg_client = Client("briaai/BRIA-RMBG-2.0")
except Exception as e:
+     print(f"Failed to connect to RMBG client: {e}")
+     rmbg_client = None
+
+ # 4. Load EnvMaps (Assuming assets folder exists)
+ envmap = {}
+ if os.path.exists('assets/hdri'):
+     try:
+         envmap = {
+             'forest': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
+             'sunset': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
+             'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
+         }
+     except Exception as e:
+         print(f"Warning: Could not load HDRIs: {e}")
+ else:
+     print("Warning: 'assets/hdri' folder not found. Preview modes may fail.")

# ==========================================
# Helper Functions
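The diff skips from line 218 to 230 here, so only the tail of `start_session` (the `os.makedirs` call below, whose `def` line appears in the earlier hunk header) is visible, and the `image_to_base64` helper used throughout the previewer never appears in any shown hunk. Since the generated `<img>` tags feed its return value straight into `src="..."`, it presumably produces a PNG data URI. A minimal sketch under that assumption (hypothetical, not the commit's actual code):

def image_to_base64(image: Image.Image) -> str:
    # Hypothetical helper: encode a PIL image as a PNG data URI for <img src="...">
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"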
 
    os.makedirs(user_dir, exist_ok=True)

def end_session(req: gr.Request):
+     try:
+         user_dir = os.path.join(TMP_DIR, str(req.session_hash))
        shutil.rmtree(user_dir)
+     except Exception:
+         pass

def remove_background(input: Image.Image) -> Image.Image:
+     """Removes background using Bria RMBG API via Gradio Client."""
+     if rmbg_client is None:
+         return input  # Fallback: return the image unchanged
+
+     with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
        input = input.convert('RGB')
        input.save(f.name)
+         f_path = f.name
+
+     try:
+         output_path = rmbg_client.predict(handle_file(f_path), api_name="/image")[0][0]
+         output = Image.open(output_path)
        return output
+     finally:
+         if os.path.exists(f_path):
+             os.remove(f_path)

def preprocess_image(input: Image.Image) -> Image.Image:
+     """Preprocesses image (resizing, centering, BG removal) for TRELLIS."""
+     if input is None:
+         return None
+
+     # Check for a meaningful alpha channel
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True
+
+     # Resize if too large
    max_size = max(input.size)
    scale = min(1, 1024 / max_size)
    if scale < 1:
        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+
+     # Remove BG if needed
    if has_alpha:
        output = input
    else:
        output = remove_background(input)
+
+     # Centering and cropping logic
    output_np = np.array(output)
+     # Ensure it has alpha now
+     if output_np.shape[2] == 4:
+         alpha = output_np[:, :, 3]
+         rows = np.any(alpha > 200, axis=1)
+         cols = np.any(alpha > 200, axis=0)
+         if np.any(rows) and np.any(cols):  # Check if image is not empty
+             ymin, ymax = np.where(rows)[0][[0, -1]]
+             xmin, xmax = np.where(cols)[0][[0, -1]]
+
+             w = xmax - xmin
+             h = ymax - ymin
+             size = max(w, h)
+             center_x, center_y = (xmin + xmax) / 2, (ymin + ymax) / 2
+
+             # Add some padding
+             size = int(size * 1.1)
+
+             # Crop
+             left = max(0, int(center_x - size // 2))
+             top = max(0, int(center_y - size // 2))
+             right = min(output.width, int(center_x + size // 2))
+             bottom = min(output.height, int(center_y + size // 2))
+
+             output = output.crop((left, top, right, bottom))
+
+     # Standardize for the pipeline: premultiply alpha onto a black background
+     output_np = np.array(output).astype(np.float32) / 255
+     if output_np.shape[2] == 4:
+         output_np = output_np[:, :, :3] * output_np[:, :, 3:4]
+
+     output = Image.fromarray((output_np * 255).astype(np.uint8))
    return output
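To make the premultiply step above concrete: scaling each RGB channel by its alpha composites the image onto black, so fully transparent pixels zero out and the array can be saved as plain RGB. A tiny illustration (not part of the app):

import numpy as np

rgba = np.array([[[1.0, 0.5, 0.2, 0.0],     # fully transparent pixel
                  [1.0, 0.5, 0.2, 1.0]]])   # fully opaque pixel
rgb = rgba[:, :, :3] * rgba[:, :, 3:4]      # premultiply by alpha
print(rgb)  # [[[0. 0. 0.] [1. 0.5 0.2]]], the transparent pixel went black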

def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
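The bodies of `pack_state` and `unpack_state` fall in the elided lines 320 to 336. Given the signature above, and that the result must survive a round-trip through `gr.State` into `extract_glb`, they presumably move the two sparse latents off-GPU into a plain dict and rebuild them on demand. A hypothetical sketch; the `.feats`/`.coords` attributes and the `SparseTensor(feats=..., coords=...)` construction are assumptions not confirmed by this diff:

def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
    shape_slat, tex_slat, res = latents
    return {
        'shape_feats': shape_slat.feats.cpu(),    # assumed attribute layout
        'shape_coords': shape_slat.coords.cpu(),
        'tex_feats': tex_slat.feats.cpu(),
        'tex_coords': tex_slat.coords.cpu(),
        'res': res,
    }

def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
    shape_slat = SparseTensor(feats=state['shape_feats'].cuda(), coords=state['shape_coords'].cuda())
    tex_slat = SparseTensor(feats=state['tex_feats'].cuda(), coords=state['tex_coords'].cuda())
    return shape_slat, tex_slat, state['res']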
 
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed

# ==========================================
+ # Main Processing Functions
# ==========================================

@spaces.GPU()
+ def generate_text_to_image(prompt, progress=gr.Progress(track_tqdm=True)):
+     """Generates an image from text using Z-Image-Turbo."""
+     if t2i_pipe is None:
+         raise gr.Error("Text-to-Image Model not loaded.")
    if not prompt.strip():
+         raise gr.Error("Prompt is empty.")
+
    device = "cuda" if torch.cuda.is_available() else "cpu"
+     generator = torch.Generator(device).manual_seed(42)  # Fixed seed for demo consistency; could be made a parameter

+     result = t2i_pipe(
+         prompt=prompt,
+         negative_prompt=None,
+         height=1024,
+         width=1024,
+         num_inference_steps=9,
+         guidance_scale=0.0,
+         generator=generator,
+     )
+     return result.images[0]

@spaces.GPU(duration=120)
def image_to_3d(
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
+ ) -> Tuple[dict, str]:

+     if pipeline_trellis is None:
+         raise gr.Error("TRELLIS Model not loaded.")
+
    if image is None:
        raise gr.Error("Input image is missing.")

+     # The input image is assumed to be the output of the preprocess step:
+     # the Gradio event chain runs preprocess_image on both generated images
+     # (which arrive with a background) and manually uploaded RGBA images.
+
    # --- Sampling ---
+     outputs, latents = pipeline_trellis.run(
        image,
        seed=seed,
+         preprocess_image=False,  # We assume input is already preprocessed
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
 
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
+         pipeline_type={
+             "512": "512",
+             "1024": "1024_cascade",
+             "1536": "1536_cascade",
+         }[resolution],
        return_latent=True,
    )
+
    mesh = outputs[0]
    mesh.simplify(16777216)

    # Render Preview
+     images_rendered = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
    state = pack_state(latents)
    torch.cuda.empty_cache()

    # --- HTML Construction ---
    images_html = ""
    for m_idx, mode in enumerate(MODES):
+         if mode['render_key'] not in images_rendered: continue  # skip if missing hdri
        for s_idx in range(STEPS):
            unique_id = f"view-m{m_idx}-s{s_idx}"
            is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
            vis_class = "visible" if is_visible else ""
+             img_base64 = image_to_base64(Image.fromarray(images_rendered[mode['render_key']][s_idx]))
+
+             images_html += f"""
+                 <img id="{unique_id}"
+                      class="previewer-main-image {vis_class}"
+                      src="{img_base64}"
+                      loading="eager">
+             """

    btns_html = ""
+     for idx, mode in enumerate(MODES):
+         if mode['render_key'] not in images_rendered: continue
        active_class = "active" if idx == DEFAULT_MODE else ""
+         btns_html += f"""
+             <img src="{mode['icon_base64']}"
+                  class="mode-btn {active_class}"
+                  onclick="selectMode({idx})"
+                  title="{mode['name']}">
+         """

    full_html = f"""
    <div class="previewer-container">
        <div class="tips-wrapper">
            <div class="tips-icon">💡Tips</div>
+             <div class="tips-text">
+                 <p>● <b>Render Mode</b> - Click buttons to switch render modes.</p>
+                 <p>● <b>View Angle</b> - Drag slider to rotate.</p>
+             </div>
+         </div>
+         <div class="display-row">
+             {images_html}
+         </div>
+         <div class="mode-row" id="btn-group">
+             {btns_html}
        </div>
        <div class="slider-row">
            <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
        </div>
    </div>
    """
+
    return state, full_html
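The generated markup calls `selectMode(idx)` and `onSliderChange(value)`, whose definitions sit in the elided middle of `HEAD` (new lines 129 to 150, between the shown start of `refreshView` and the closing `</script>`). Presumably they track the current mode/step pair and delegate to `refreshView`; a hypothetical sketch of that elided script, in the same Python-string style the app itself uses:

# Hypothetical reconstruction of the elided part of HEAD (not the commit's actual code)
HEAD_ELIDED_SKETCH = """
<script>
let currentMode = 3;  // mirrors DEFAULT_MODE
let currentStep = 3;  // mirrors DEFAULT_STEP
function selectMode(mode) {
    currentMode = mode;
    refreshView(currentMode, currentStep);
}
function onSliderChange(step) {
    currentStep = parseInt(step);
    refreshView(currentMode, currentStep);
}
</script>
"""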

@spaces.GPU(duration=120)

    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
+
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shape_slat, tex_slat, res = unpack_state(state)
+     mesh = pipeline_trellis.decode_latent(shape_slat, tex_slat, res)[0]
    mesh.simplify(16777216)

    glb = o_voxel.postprocess.to_glb(

        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
+         attr_layout=pipeline_trellis.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,

        remesh_project=0,
        use_tqdm=True,
    )
+
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    os.makedirs(user_dir, exist_ok=True)

    return glb_path, glb_path
 
525
  # ==========================================
526
+ # Gradio Interface
527
  # ==========================================
528
 
529
  if __name__ == "__main__":
530
  os.makedirs(TMP_DIR, exist_ok=True)
531
 
532
+ # Pre-calculate base64 for icons to avoid FS lag
533
  for i in range(len(MODES)):
534
  if os.path.exists(MODES[i]['icon']):
535
  icon = Image.open(MODES[i]['icon'])
536
  MODES[i]['icon_base64'] = image_to_base64(icon)
537
  else:
538
+ MODES[i]['icon_base64'] = "" # Handle missing assets
539
 
540
+ with gr.Blocks(css=CSS, head=HEAD, delete_cache=(600, 600)) as demo:
541
  gr.Markdown("""
542
  # TRELLIS.2-3D
543
+ ### Text-to-Image (Z-Image-Turbo) + Image-to-3D (TRELLIS.2) Pipeline
 
 
 
544
  """)
545
 
546
  with gr.Row():
547
+ # Left Column: Inputs & Text-to-Image
548
+ with gr.Column(scale=1):
549
+ gr.Markdown("### 1. Generate Image")
550
+ text_prompt = gr.Textbox(label="Text Prompt", placeholder="A 3D rendering of a cute isometric house...")
551
+ gen_image_btn = gr.Button("Generate Image from Text", variant="secondary")
 
 
 
552
 
553
+ gr.Markdown("### 2. Prepare for 3D")
554
+ # This Image component acts as the bridge.
555
+ # It accepts output from T2I OR user upload.
556
+ image_input = gr.Image(label="Input Image (Auto-Preprocessed)", type="pil", image_mode="RGBA", height=300)
557
 
558
+ with gr.Accordion("Image-to-3D Settings", open=True):
559
  resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
560
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
561
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
 
 
 
 
 
 
 
 
 
 
562
 
563
+ with gr.Accordion("Advanced Parameters", open=False):
564
+ gr.Markdown("**Stage 1: Sparse Structure**")
565
+ ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5, step=0.1)
566
+ ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.7, step=0.01)
567
+ ss_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
568
+ ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
569
+
570
+ gr.Markdown("**Stage 2: Shape**")
571
+ shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5, step=0.1)
572
+ shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.5, step=0.01)
573
+ shape_slat_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
574
+ shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
575
+
576
+ gr.Markdown("**Stage 3: Texture**")
577
+ tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=1.0, step=0.1)
578
+ tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.0, step=0.01)
579
+ tex_slat_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
580
+ tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
581
+
582
+ gen_3d_btn = gr.Button("Generate 3D Model", variant="primary")
583
+
584
+ # Right Column: 3D Preview & Export
585
+ with gr.Column(scale=2):
586
+ gr.Markdown("### 3. 3D Preview")
587
+ # We use a Walkthrough to switch between Preview HTML and GLB Viewer
588
  with gr.Walkthrough(selected=0) as walkthrough:
589
  with gr.Step("Preview", id=0):
590
+ preview_output = gr.HTML(EMPTY_HTML, label="3D Preview", container=True)
591
+ gr.Markdown("*(If the preview is black, verify 'assets/hdri' files exist)*")
592
+
593
+ with gr.Row():
594
+ decimation_target = gr.Slider(100000, 500000, label="Mesh Decimation", value=300000, step=10000)
595
+ texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
596
+ extract_btn = gr.Button("Extract & Download GLB")
597
+
598
+ with gr.Step("Result", id=1):
599
+ glb_output = gr.Model3D(label="Extracted GLB", height=600, display_mode="solid", clear_color=(0.2, 0.2, 0.2, 1.0))
600
+ download_btn = gr.DownloadButton(label="Download .glb File")
601
+ back_btn = gr.Button("Back to Preview")
602
+
603
+ # Hidden State to store TRELLIS latent representation
604
+ output_state = gr.State()
605
+
606
+ # ====================
607
+ # Event Handling
608
+ # ====================
609
 
 
 
 
610
  demo.load(start_session)
611
  demo.unload(end_session)
612
 
613
+ # 1. Text to Image
614
+ gen_image_btn.click(
615
+ generate_text_to_image,
616
+ inputs=[text_prompt],
617
+ outputs=[image_input]
 
 
 
 
618
  )
619
 
620
+ # 2. Image Preprocessing (Auto-trigger when image changes)
621
+ image_input.change(
622
  preprocess_image,
623
+ inputs=[image_input],
624
+ outputs=[image_input]
625
  )
626
 
627
+ # 3. Image to 3D Generation
628
+ gen_3d_btn.click(
629
  get_seed,
630
  inputs=[randomize_seed, seed],
631
  outputs=[seed],
 
634
  ).then(
635
  image_to_3d,
636
  inputs=[
637
+ image_input, seed, resolution,
638
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
639
+ shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
640
+ tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
641
  ],
642
+ outputs=[output_state, preview_output],
643
  )
644
 
645
+ # 4. Extract GLB
646
  extract_btn.click(
647
  lambda: gr.Walkthrough(selected=1), outputs=walkthrough
648
  ).then(
649
  extract_glb,
650
+ inputs=[output_state, decimation_target, texture_size],
651
  outputs=[glb_output, download_btn],
652
  )
653
+
654
+ # 5. Back button
655
+ back_btn.click(lambda: gr.Walkthrough(selected=0), outputs=walkthrough)
656
 
657
+ demo.launch(server_name="0.0.0.0", server_port=7860)