prithivMLmods committed
Commit f40f8d6 · verified · 1 Parent(s): 78bde42

Update app.py

Files changed (1):
  1. app.py +287 -383
app.py CHANGED
@@ -1,68 +1,61 @@
  import os
- import io
- import cv2
- import time
- import base64
  import shutil
- import tempfile
  import torch
  import numpy as np
  import gradio as gr
  from gradio_client import Client, handle_file
  from pathlib import Path
- from typing import Tuple, List, Optional
- from PIL import Image
  from datetime import datetime

- # --- Environment Configuration (Must be set before importing trellis2) ---
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
- os.environ["ATTN_BACKEND"] = "flash_attn_3"
- # Adjust path if necessary or ensure autotune_cache.json exists in trellis2 dir
- try:
-     from trellis2 import modules
-     os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(modules.__file__)), 'autotune_cache.json')
- except:
-     pass
- os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
-
- import spaces  # For GPU management in Hugging Face Spaces
-
- # --- Imports for Z-Image-Turbo ---
- from diffusers import ZImagePipeline
-
- # --- Imports for TRELLIS.2 ---
  from trellis2.modules.sparse import SparseTensor
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
  from trellis2.renderers import EnvMap
  from trellis2.utils import render_utils
  import o_voxel

  # ==========================================
- # Global Constants & CSS/JS
  # ==========================================

  MAX_SEED = np.iinfo(np.int32).max
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')

- # Asset definitions
- MODES = [
-     {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
-     {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
-     {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
-     {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
-     {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
-     {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
- ]
- STEPS = 8
- DEFAULT_MODE = 3
- DEFAULT_STEP = 3

- CSS = """
- /* TRELLIS Custom Styles */
  .stepper-wrapper { padding: 0; }
  .stepper-container { padding: 0; align-items: center; }
  .step-button { flex-direction: row; }
  .previewer-container {
      position: relative;
      font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
@@ -76,53 +69,41 @@ CSS = """
      justify-content: center;
  }
  .previewer-container .tips-icon {
-     position: absolute; right: 10px; top: 10px; z-index: 10;
-     border-radius: 10px; color: #fff; background-color: var(--color-accent);
-     padding: 3px 6px; user-select: none;
  }
  .previewer-container .tips-text {
-     position: absolute; right: 10px; top: 50px; color: #fff;
-     background-color: var(--color-accent); border-radius: 10px;
-     padding: 6px; text-align: left; max-width: 300px; z-index: 10;
      transition: all 0.3s; opacity: 0%; user-select: none;
  }
  .tips-icon:hover + .tips-text { display: block; opacity: 100%; }
- .previewer-container .mode-row {
-     width: 100%; display: flex; gap: 8px; justify-content: center;
-     margin-bottom: 20px; flex-wrap: wrap;
- }
- .previewer-container .mode-btn {
-     width: 24px; height: 24px; border-radius: 50%; cursor: pointer;
-     opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover;
- }
  .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
  .previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
- .previewer-container .display-row {
-     margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1;
-     display: flex; justify-content: center; align-items: center;
- }
- .previewer-container .previewer-main-image {
-     max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none;
- }
  .previewer-container .previewer-main-image.visible { display: block; }
- .previewer-container .slider-row {
-     width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px;
- }
- .previewer-container input[type=range] {
-     -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent;
- }
- .previewer-container input[type=range]::-webkit-slider-runnable-track {
-     width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px;
- }
- .previewer-container input[type=range]::-webkit-slider-thumb {
-     height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent);
-     cursor: pointer; -webkit-appearance: none; margin-top: -6px;
-     box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s;
- }
  .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
  """

- HEAD = """
  <script>
  function refreshView(mode, step) {
      const allImgs = document.querySelectorAll('.previewer-main-image');
@@ -140,6 +121,7 @@ HEAD = """
      const targetId = 'view-m' + mode + '-s' + step;
      const targetImg = document.getElementById(targetId);
      if (targetImg) targetImg.classList.add('visible');

      const allBtns = document.querySelectorAll('.mode-btn');
      allBtns.forEach((btn, idx) => {
          if (idx === mode) btn.classList.add('active');
@@ -151,68 +133,45 @@ HEAD = """
  </script>
  """

- EMPTY_HTML = f"""
  <div class="previewer-container">
-     <div style="opacity: 0.5; text-align: center;">
-         <p>3D Asset Preview will appear here.</p>
-     </div>
  </div>
  """

  # ==========================================
- # Model Loading
  # ==========================================

- print("Initializing models...")
-
- # 1. Load Z-Image-Turbo (Text to Image)
- print("Loading Z-Image-Turbo...")
- try:
-     t2i_pipe = ZImagePipeline.from_pretrained(
-         "Tongyi-MAI/Z-Image-Turbo",
-         torch_dtype=torch.bfloat16,
-         low_cpu_mem_usage=False,
-     )
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     t2i_pipe.to(device)
-     print("Z-Image-Turbo loaded.")
- except Exception as e:
-     print(f"Failed to load Z-Image-Turbo: {e}")
-     t2i_pipe = None
-
- # 2. Load TRELLIS.2 (Image to 3D)
- print("Loading TRELLIS.2...")
- try:
-     pipeline_trellis = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
-     pipeline_trellis.rembg_model = None  # We use external Bria RMBG
-     pipeline_trellis.low_vram = False
-     pipeline_trellis.cuda()
-     print("TRELLIS.2 loaded.")
- except Exception as e:
-     print(f"Failed to load TRELLIS.2: {e}")
-     pipeline_trellis = None
-
- # 3. Load RMBG Client
- print("Loading RMBG Client...")
- try:
-     rmbg_client = Client("briaai/BRIA-RMBG-2.0")
- except Exception as e:
-     print(f"Failed to connect to RMBG client: {e}")
-     rmbg_client = None
-
- # 4. Load EnvMaps (Assuming assets folder exists)
  envmap = {}
- if os.path.exists('assets/hdri'):
-     try:
-         envmap = {
-             'forest': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
-             'sunset': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
-             'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
-         }
-     except Exception as e:
-         print(f"Warning: Could not load HDRIs: {e}")
- else:
-     print("Warning: 'assets/hdri' folder not found. Preview modes may fail.")

  # ==========================================
  # Helper Functions
@@ -228,92 +187,51 @@ def image_to_base64(image):
  def start_session(req: gr.Request):
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
      os.makedirs(user_dir, exist_ok=True)
-
  def end_session(req: gr.Request):
-     try:
-         user_dir = os.path.join(TMP_DIR, str(req.session_hash))
          shutil.rmtree(user_dir)
-     except:
-         pass

  def remove_background(input: Image.Image) -> Image.Image:
-     """Removes background using Bria RMBG API via Gradio Client."""
-     if rmbg_client is None:
-         return input  # Fallback
-
-     with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
          input = input.convert('RGB')
          input.save(f.name)
-         f_path = f.name
-
-     try:
-         output_path = rmbg_client.predict(handle_file(f_path), api_name="/image")[0][0]
-         output = Image.open(output_path)
          return output
-     finally:
-         if os.path.exists(f_path):
-             os.remove(f_path)

  def preprocess_image(input: Image.Image) -> Image.Image:
-     """Preprocesses image (resizing, centering, BG removal) for TRELLIS."""
      if input is None:
          return None
-
-     # Check Alpha
      has_alpha = False
      if input.mode == 'RGBA':
          alpha = np.array(input)[:, :, 3]
          if not np.all(alpha == 255):
              has_alpha = True
-
-     # Resize if too large
      max_size = max(input.size)
      scale = min(1, 1024 / max_size)
      if scale < 1:
          input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
-
-     # Remove BG if needed
      if has_alpha:
          output = input
      else:
          output = remove_background(input)
-
-     # Centering and Cropping logic
      output_np = np.array(output)
-     # Ensure it has alpha now
-     if output_np.shape[2] == 4:
-         alpha = output_np[:, :, 3]
-         rows = np.any(alpha > 200, axis=1)
-         cols = np.any(alpha > 200, axis=0)
-         if np.any(rows) and np.any(cols):  # Check if image is not empty
-             ymin, ymax = np.where(rows)[0][[0, -1]]
-             xmin, xmax = np.where(cols)[0][[0, -1]]
-
-             w = xmax - xmin
-             h = ymax - ymin
-             size = max(w, h)
-             center_x, center_y = (xmin + xmax) / 2, (ymin + ymax) / 2
-
-             # Add some padding
-             size = int(size * 1.1)
-
-             # Crop
-             left = max(0, int(center_x - size // 2))
-             top = max(0, int(center_y - size // 2))
-             right = min(output.width, int(center_x + size // 2))
-             bottom = min(output.height, int(center_y + size // 2))
-
-             output = output.crop((left, top, right, bottom))
-
-     # Premultiply alpha on black background for clean tensor conversion later?
-     # Actually TRELLIS pipeline usually handles RGBA.
-     # But let's standardize:
-     output_np = np.array(output).astype(np.float32) / 255
-     if output_np.shape[2] == 4:
-         # Premultiply
-         output_np[:, :, :3] * output_np[:, :, 3:4]
-
-     output = Image.fromarray((output_np * 255).astype(np.uint8))
      return output

  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
@@ -324,7 +242,7 @@ def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
          'coords': shape_slat.coords.cpu().numpy(),
          'res': res,
      }
-
  def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
      shape_slat = SparseTensor(
          feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
@@ -337,30 +255,37 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
      return np.random.randint(0, MAX_SEED) if randomize_seed else seed

  # ==========================================
- # Main Processing Functions
  # ==========================================

  @spaces.GPU()
- def generate_text_to_image(prompt, progress=gr.Progress(track_tqdm=True)):
-     """Generates an image from text using Z-Image-Turbo."""
-     if t2i_pipe is None:
-         raise gr.Error("Text-to-Image Model not loaded.")
      if not prompt.strip():
-         raise gr.Error("Prompt is empty.")
-
      device = "cuda" if torch.cuda.is_available() else "cpu"
-     generator = torch.Generator(device).manual_seed(42)  # Fixed seed for demo consistency; could be made a parameter

-     result = t2i_pipe(
-         prompt=prompt,
-         negative_prompt=None,
-         height=1024,
-         width=1024,
-         num_inference_steps=9,
-         guidance_scale=0.0,
-         generator=generator,
-     )
-     return result.images[0]

  @spaces.GPU(duration=120)
  def image_to_3d(
@@ -381,24 +306,15 @@ def image_to_3d(
      tex_slat_rescale_t: float,
      req: gr.Request,
      progress=gr.Progress(track_tqdm=True),
- ) -> Tuple[dict, str]:
-
-     if pipeline_trellis is None:
-         raise gr.Error("TRELLIS Model not loaded.")
-
      if image is None:
-         raise gr.Error("Input image is missing.")
-
-     # The input image is assumed to be the output of the preprocessing step:
-     # an image coming straight from T2I still has a background, and a manually
-     # uploaded RGBA is cleaned by preprocess_image in the Gradio event chain.

      # --- Sampling ---
-     outputs, latents = pipeline_trellis.run(
          image,
          seed=seed,
-         preprocess_image=False,  # We assume the input is already preprocessed
          sparse_structure_sampler_params={
              "steps": ss_sampling_steps,
              "guidance_strength": ss_guidance_strength,
@@ -424,64 +340,40 @@ def image_to_3d(
          }[resolution],
          return_latent=True,
      )
-
      mesh = outputs[0]
-     mesh.simplify(16777216)
-
-     # Render Preview
-     images_rendered = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
      state = pack_state(latents)
      torch.cuda.empty_cache()

      # --- HTML Construction ---
      images_html = ""
      for m_idx, mode in enumerate(MODES):
-         if mode['render_key'] not in images_rendered: continue  # skip if missing HDRI
          for s_idx in range(STEPS):
              unique_id = f"view-m{m_idx}-s{s_idx}"
              is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
              vis_class = "visible" if is_visible else ""
-             img_base64 = image_to_base64(Image.fromarray(images_rendered[mode['render_key']][s_idx]))
-
-             images_html += f"""
-                 <img id="{unique_id}"
-                      class="previewer-main-image {vis_class}"
-                      src="{img_base64}"
-                      loading="eager">
-             """

      btns_html = ""
-     for idx, mode in enumerate(MODES):
-         if mode['render_key'] not in images_rendered: continue
          active_class = "active" if idx == DEFAULT_MODE else ""
-         btns_html += f"""
-             <img src="{mode['icon_base64']}"
-                  class="mode-btn {active_class}"
-                  onclick="selectMode({idx})"
-                  title="{mode['name']}">
-         """

      full_html = f"""
      <div class="previewer-container">
          <div class="tips-wrapper">
              <div class="tips-icon">💡Tips</div>
-             <div class="tips-text">
-                 <p>● <b>Render Mode</b> - Click buttons to switch render modes.</p>
-                 <p>● <b>View Angle</b> - Drag slider to rotate.</p>
-             </div>
-         </div>
-         <div class="display-row">
-             {images_html}
-         </div>
-         <div class="mode-row" id="btn-group">
-             {btns_html}
          </div>
          <div class="slider-row">
              <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
          </div>
      </div>
      """
-
      return state, full_html

  @spaces.GPU(duration=120)
@@ -492,18 +384,19 @@ def extract_glb(
      req: gr.Request,
      progress=gr.Progress(track_tqdm=True),
  ) -> Tuple[str, str]:
-
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
      shape_slat, tex_slat, res = unpack_state(state)
-     mesh = pipeline_trellis.decode_latent(shape_slat, tex_slat, res)[0]
      mesh.simplify(16777216)
-
      glb = o_voxel.postprocess.to_glb(
          vertices=mesh.vertices,
          faces=mesh.faces,
          attr_volume=mesh.attrs,
          coords=mesh.coords,
-         attr_layout=pipeline_trellis.pbr_attr_layout,
          grid_size=res,
          aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
          decimation_target=decimation_target,
@@ -513,7 +406,6 @@ def extract_glb(
          remesh_project=0,
          use_tqdm=True,
      )
-
      now = datetime.now()
      timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
      os.makedirs(user_dir, exist_ok=True)
@@ -523,135 +415,147 @@ def extract_glb(
      return glb_path, glb_path

  # ==========================================
- # Gradio Interface
  # ==========================================

- if __name__ == "__main__":
-     os.makedirs(TMP_DIR, exist_ok=True)

-     # Pre-calculate base64 for icons to avoid FS lag
-     for i in range(len(MODES)):
-         if os.path.exists(MODES[i]['icon']):
-             icon = Image.open(MODES[i]['icon'])
-             MODES[i]['icon_base64'] = image_to_base64(icon)
-         else:
-             MODES[i]['icon_base64'] = ""  # Handle missing assets
-
-     with gr.Blocks(css=CSS, head=HEAD, delete_cache=(600, 600)) as demo:
-         gr.Markdown("""
-         # TRELLIS.2-3D
-         ### Text-to-Image (Z-Image-Turbo) + Image-to-3D (TRELLIS.2) Pipeline
-         """)
-
-         with gr.Row():
-             # Left Column: Inputs & Text-to-Image
-             with gr.Column(scale=1):
-                 gr.Markdown("### 1. Generate Image")
-                 text_prompt = gr.Textbox(label="Text Prompt", placeholder="A 3D rendering of a cute isometric house...")
-                 gen_image_btn = gr.Button("Generate Image from Text", variant="secondary")
-
-                 gr.Markdown("### 2. Prepare for 3D")
-                 # This Image component acts as the bridge.
-                 # It accepts output from T2I OR user upload.
-                 image_input = gr.Image(label="Input Image (Auto-Preprocessed)", type="pil", image_mode="RGBA", height=300)
-
-                 with gr.Accordion("Image-to-3D Settings", open=True):
-                     resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
-                     seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
-                     randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-
-                 with gr.Accordion("Advanced Parameters", open=False):
-                     gr.Markdown("**Stage 1: Sparse Structure**")
-                     ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5, step=0.1)
-                     ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.7, step=0.01)
-                     ss_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
-                     ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
-
-                     gr.Markdown("**Stage 2: Shape**")
-                     shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5, step=0.1)
-                     shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.5, step=0.01)
-                     shape_slat_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
-                     shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
-
-                     gr.Markdown("**Stage 3: Texture**")
-                     tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=1.0, step=0.1)
-                     tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.0, step=0.01)
-                     tex_slat_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
-                     tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
-
-                 gen_3d_btn = gr.Button("Generate 3D Model", variant="primary")
-
-             # Right Column: 3D Preview & Export
-             with gr.Column(scale=2):
-                 gr.Markdown("### 3. 3D Preview")
-                 # We use a Walkthrough to switch between Preview HTML and GLB Viewer
-                 with gr.Walkthrough(selected=0) as walkthrough:
-                     with gr.Step("Preview", id=0):
-                         preview_output = gr.HTML(EMPTY_HTML, label="3D Preview", container=True)
-                         gr.Markdown("*(If the preview is black, verify 'assets/hdri' files exist)*")
-
-                         with gr.Row():
-                             decimation_target = gr.Slider(100000, 500000, label="Mesh Decimation", value=300000, step=10000)
-                             texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
-                         extract_btn = gr.Button("Extract & Download GLB")
-
-                     with gr.Step("Result", id=1):
-                         glb_output = gr.Model3D(label="Extracted GLB", height=600, display_mode="solid", clear_color=(0.2, 0.2, 0.2, 1.0))
-                         download_btn = gr.DownloadButton(label="Download .glb File")
-                         back_btn = gr.Button("Back to Preview")
-
-         # Hidden State to store TRELLIS latent representation
-         output_state = gr.State()
-
-         # ====================
-         # Event Handling
-         # ====================
-
-         demo.load(start_session)
-         demo.unload(end_session)
-
-         # 1. Text to Image
-         gen_image_btn.click(
-             generate_text_to_image,
-             inputs=[text_prompt],
-             outputs=[image_input]
-         )

-         # 2. Image Preprocessing (auto-triggered when the image changes)
-         image_input.change(
-             preprocess_image,
-             inputs=[image_input],
-             outputs=[image_input]
-         )

-         # 3. Image to 3D Generation
-         gen_3d_btn.click(
-             get_seed,
-             inputs=[randomize_seed, seed],
-             outputs=[seed],
-         ).then(
-             lambda: gr.Walkthrough(selected=0), outputs=walkthrough
-         ).then(
-             image_to_3d,
-             inputs=[
-                 image_input, seed, resolution,
-                 ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
-                 shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
-                 tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
-             ],
-             outputs=[output_state, preview_output],
-         )

-         # 4. Extract GLB
-         extract_btn.click(
-             lambda: gr.Walkthrough(selected=1), outputs=walkthrough
-         ).then(
-             extract_glb,
-             inputs=[output_state, decimation_target, texture_size],
-             outputs=[glb_output, download_btn],
-         )
-
-         # 5. Back button
-         back_btn.click(lambda: gr.Walkthrough(selected=0), outputs=walkthrough)

-     demo.launch(server_name="0.0.0.0", server_port=7860)
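
A note on the preprocessing change between the two revisions shown here: the old `preprocess_image` computed the alpha premultiply as a bare expression (`output_np[:, :, :3] * output_np[:, :, 3:4]`), discarding the result, so the composite was never applied; the rewritten version below assigns the product before converting back to an image. A minimal standalone sketch of the intended premultiply, assuming a float32 RGBA array scaled to [0, 1]:

    import numpy as np

    def premultiply_alpha(rgba: np.ndarray) -> np.ndarray:
        """Multiply RGB by alpha so fully transparent pixels go to black."""
        # Broadcasting an (H, W, 1) alpha over the (H, W, 3) color channels.
        return rgba[:, :, :3] * rgba[:, :, 3:4]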
  import os
  import shutil
+ import cv2
  import torch
  import numpy as np
+ from PIL import Image
+ import base64
+ import io
+ import tempfile
  import gradio as gr
  from gradio_client import Client, handle_file
+ import spaces
  from pathlib import Path
+ from typing import *
  from datetime import datetime

+ # --- TRELLIS Specific Imports ---
+ # Ensure these libraries are installed in your environment
  from trellis2.modules.sparse import SparseTensor
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
  from trellis2.renderers import EnvMap
  from trellis2.utils import render_utils
  import o_voxel

+ # --- Z-Image Specific Imports ---
+ from diffusers import ZImagePipeline
+
  # ==========================================
+ # Configuration & Environment Setup
  # ==========================================

+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+ os.environ["ATTN_BACKEND"] = "flash_attn_3"
+ # Adjust path if necessary for your environment
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
+ os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
+
  MAX_SEED = np.iinfo(np.int32).max
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')

+ # ==========================================
+ # CSS & JS Resources (From TRELLIS)
+ # ==========================================

+ css = """
+ /* Overwrite Gradio Default Style */
  .stepper-wrapper { padding: 0; }
  .stepper-container { padding: 0; align-items: center; }
  .step-button { flex-direction: row; }
+ .step-connector { transform: none; }
+ .step-number { width: 16px; height: 16px; }
+ .step-label { position: relative; bottom: 0; }
+ .wrap.center.full { inset: 0; height: 100%; }
+ .wrap.center.full.translucent { background: var(--block-background-fill); }
+ .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
+
+ /* Previewer */
  .previewer-container {
      position: relative;
      font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
  ...
      justify-content: center;
  }
  .previewer-container .tips-icon {
+     position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px;
+     color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none;
  }
  .previewer-container .tips-text {
+     position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent);
+     border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10;
      transition: all 0.3s; opacity: 0%; user-select: none;
  }
+ .previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
  .tips-icon:hover + .tips-text { display: block; opacity: 100%; }
+
+ /* Row 1: Display Modes */
+ .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
+ .previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
  .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
  .previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
+
+ /* Row 2: Display Image */
+ .previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
+ .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
  .previewer-container .previewer-main-image.visible { display: block; }
+
+ /* Row 3: Custom HTML Slider */
+ .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
+ .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
+ .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
+ .previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
+ .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
+
+ /* Overwrite Previewer Block Style */
  .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
+ .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
  """

+ head = """
  <script>
  function refreshView(mode, step) {
      const allImgs = document.querySelectorAll('.previewer-main-image');
  ...
      const targetId = 'view-m' + mode + '-s' + step;
      const targetImg = document.getElementById(targetId);
      if (targetImg) targetImg.classList.add('visible');
+
      const allBtns = document.querySelectorAll('.mode-btn');
      allBtns.forEach((btn, idx) => {
          if (idx === mode) btn.classList.add('active');
  ...
  </script>
  """

+ empty_html = f"""
  <div class="previewer-container">
+     <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
+          xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
  </div>
  """

  # ==========================================
+ # Model Loading & Global Variables
  # ==========================================

+ print("Loading Z-Image Turbo model...")
+ # Text-to-Image Pipeline
+ z_pipe = ZImagePipeline.from_pretrained(
+     "Tongyi-MAI/Z-Image-Turbo",
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=False,
+ )
+ if torch.cuda.is_available():
+     z_pipe.to("cuda")
+ print("Z-Image Turbo loaded.")
+
+ # Image-to-3D Pipeline Placeholders
+ pipeline = None
+ rmbg_client = None
  envmap = {}
+
+ # TRELLIS Settings
+ MODES = [
+     {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
+     {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
+     {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
+     {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
+     {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
+     {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
+ ]
+ STEPS = 8
+ DEFAULT_MODE = 3
+ DEFAULT_STEP = 3

  # ==========================================
  # Helper Functions
  ...
  def start_session(req: gr.Request):
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
      os.makedirs(user_dir, exist_ok=True)
+
  def end_session(req: gr.Request):
+     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     if os.path.exists(user_dir):
          shutil.rmtree(user_dir)

  def remove_background(input: Image.Image) -> Image.Image:
+     with tempfile.NamedTemporaryFile(suffix='.png') as f:
          input = input.convert('RGB')
          input.save(f.name)
+         output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
+         output = Image.open(output)
          return output
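
Note that the rewritten `remove_background` keeps the temporary file open inside the `with` block while the remote RMBG client reads it by path; that is fine on POSIX systems, but an open `NamedTemporaryFile` generally cannot be reopened by name on Windows, which is why the previous revision used `delete=False` with explicit cleanup. A sketch of that more portable pattern (the `process_by_path` callable is hypothetical, standing in for the RMBG call):

    import os
    import tempfile

    def with_temp_png(image, process_by_path):
        # delete=False lets the file be closed here, then reopened by other readers.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
            path = f.name
        image.convert('RGB').save(path)
        try:
            return process_by_path(path)
        finally:
            # Explicit cleanup, since delete=False leaves the file behind.
            if os.path.exists(path):
                os.remove(path)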
  def preprocess_image(input: Image.Image) -> Image.Image:
      if input is None:
          return None
+     # If it has an alpha channel, use it directly; otherwise, remove the background
      has_alpha = False
      if input.mode == 'RGBA':
          alpha = np.array(input)[:, :, 3]
          if not np.all(alpha == 255):
              has_alpha = True
      max_size = max(input.size)
      scale = min(1, 1024 / max_size)
      if scale < 1:
          input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
      if has_alpha:
          output = input
      else:
          output = remove_background(input)
      output_np = np.array(output)
+     alpha = output_np[:, :, 3]
+     bbox = np.argwhere(alpha > 0.8 * 255)
+     if bbox.size == 0:
+         return output  # Return original if no object found
+     bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+     center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+     size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+     size = int(size * 1.1)  # Add slight padding
+     bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+     output = output.crop(bbox)  # type: ignore
+     output = np.array(output).astype(np.float32) / 255
+     output = output[:, :, :3] * output[:, :, 3:4]
+     output = Image.fromarray((output * 255).astype(np.uint8))
      return output
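
One subtlety in the bounding-box math above: `np.argwhere` returns `(row, col)` pairs, so column indices (axis 1) supply the x coordinates and row indices (axis 0) the y coordinates, which is why the crop tuple is assembled as `(min col, min row, max col, max row)` to match PIL's `(left, upper, right, lower)` convention. A small self-contained check of that ordering:

    import numpy as np
    from PIL import Image

    # 10x10 transparent canvas with an opaque patch at rows 2-4, cols 5-8.
    alpha = np.zeros((10, 10), dtype=np.uint8)
    alpha[2:5, 5:9] = 255

    pts = np.argwhere(alpha > 0.8 * 255)            # (row, col) pairs
    left, upper = pts[:, 1].min(), pts[:, 0].min()  # x from cols, y from rows
    right, lower = pts[:, 1].max(), pts[:, 0].max()
    assert (left, upper, right, lower) == (5, 2, 8, 4)

    # PIL's crop takes (left, upper, right, lower) in the same x/y order.
    img = Image.fromarray(np.dstack([alpha] * 4))   # stack to an RGBA array
    patch = img.crop((left, upper, right + 1, lower + 1))
    assert patch.size == (4, 3)                     # width x height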

  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
  ...
          'coords': shape_slat.coords.cpu().numpy(),
          'res': res,
      }
+
  def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
      shape_slat = SparseTensor(
          feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
  ...
      return np.random.randint(0, MAX_SEED) if randomize_seed else seed

  # ==========================================
+ # Core Generation Functions (GPU)
  # ==========================================

  @spaces.GPU()
+ def text_to_image(prompt, progress=gr.Progress(track_tqdm=True)):
+     """Generates an image using Z-Image Turbo."""
+     if z_pipe is None:
+         raise gr.Error("Text-to-Image model failed to load.")
      if not prompt.strip():
+         raise gr.Error("Please enter a text prompt.")
+
      device = "cuda" if torch.cuda.is_available() else "cpu"
+     generator = torch.Generator(device).manual_seed(42)  # Can be randomized if needed

+     progress(0.1, desc="Generating Text-to-Image...")
+     try:
+         result = z_pipe(
+             prompt=prompt,
+             negative_prompt=None,
+             height=1024,
+             width=1024,
+             num_inference_steps=9,
+             guidance_scale=0.0,
+             generator=generator,
+         )
+         image = result.images[0]
+         progress(1.0, desc="Image Generation Complete!")
+         # Automatically preprocess for 3D (remove bg if needed)
+         return preprocess_image(image)
+     except Exception as e:
+         raise gr.Error(f"Generation failed: {str(e)}")
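
`text_to_image` pins its generator to seed 42, so a given prompt always produces the same image; the inline comment notes this can be randomized if needed. A hedged sketch of what that parameterization might look like, reusing the app's existing `MAX_SEED` bound (the `make_generator` helper is hypothetical):

    def make_generator(seed: int, randomize: bool = False) -> torch.Generator:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if randomize:
            seed = int(np.random.randint(0, MAX_SEED))  # mirrors get_seed()
        return torch.Generator(device).manual_seed(seed)

    # Inside text_to_image (sketch): generator = make_generator(seed, randomize_seed)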

  @spaces.GPU(duration=120)
  def image_to_3d(
  ...
      tex_slat_rescale_t: float,
      req: gr.Request,
      progress=gr.Progress(track_tqdm=True),
+ ) -> str:
      if image is None:
+         raise gr.Error("Please upload or generate an image first.")

      # --- Sampling ---
+     outputs, latents = pipeline.run(
          image,
          seed=seed,
+         preprocess_image=False,  # We preprocessed before this step
          sparse_structure_sampler_params={
              "steps": ss_sampling_steps,
              "guidance_strength": ss_guidance_strength,
  ...
          }[resolution],
          return_latent=True,
      )
      mesh = outputs[0]
+     mesh.simplify(16777216)  # nvdiffrast limit
+     images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
      state = pack_state(latents)
      torch.cuda.empty_cache()

      # --- HTML Construction ---
      images_html = ""
      for m_idx, mode in enumerate(MODES):
          for s_idx in range(STEPS):
              unique_id = f"view-m{m_idx}-s{s_idx}"
              is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
              vis_class = "visible" if is_visible else ""
+             img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
+             images_html += f"""<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">"""

      btns_html = ""
+     for idx, mode in enumerate(MODES):
          active_class = "active" if idx == DEFAULT_MODE else ""
+         btns_html += f"""<img src="{mode['icon_base64']}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode['name']}">"""

      full_html = f"""
      <div class="previewer-container">
          <div class="tips-wrapper">
              <div class="tips-icon">💡Tips</div>
+             <div class="tips-text"><p>● <b>Render Mode</b> - Switch render modes.</p><p>● <b>View Angle</b> - Drag slider.</p></div>
          </div>
+         <div class="display-row">{images_html}</div>
+         <div class="mode-row" id="btn-group">{btns_html}</div>
          <div class="slider-row">
              <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
          </div>
      </div>
      """
      return state, full_html

  @spaces.GPU(duration=120)
  ...
      req: gr.Request,
      progress=gr.Progress(track_tqdm=True),
  ) -> Tuple[str, str]:
+     if state is None:
+         raise gr.Error("No 3D model generated yet.")
+
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
      shape_slat, tex_slat, res = unpack_state(state)
+     mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
      mesh.simplify(16777216)
      glb = o_voxel.postprocess.to_glb(
          vertices=mesh.vertices,
          faces=mesh.faces,
          attr_volume=mesh.attrs,
          coords=mesh.coords,
+         attr_layout=pipeline.pbr_attr_layout,
          grid_size=res,
          aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
          decimation_target=decimation_target,
  ...
          remesh_project=0,
          use_tqdm=True,
      )
      now = datetime.now()
      timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
      os.makedirs(user_dir, exist_ok=True)
  ...
      return glb_path, glb_path

  # ==========================================
+ # Gradio UI Blocks
  # ==========================================

+ with gr.Blocks(delete_cache=(600, 600), title="TRELLIS.2-3D") as demo:
+     gr.Markdown("""
+     # TRELLIS.2-3D: Text to 3D Generation
+     **Model 1:** [Z-Image-Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) (Text to Image)
+     **Model 2:** [TRELLIS.2](https://microsoft.github.io/TRELLIS.2) (Image to 3D)

+     *Generate an image from text, then convert it to a 3D Asset.*
+     """)
+
+     with gr.Row():
+         # --- Column 1: Controls ---
+         with gr.Column(scale=1, min_width=360):
+             gr.Markdown("### 1. Generate Image (Z-Image)")
+             text_prompt = gr.Textbox(label="Text Prompt", placeholder="A cute isometric 3D render of a robot...")
+             gen_image_btn = gr.Button("Generate Image", variant="primary")
+
+             gr.Markdown("### 2. Configure 3D (TRELLIS)")
+             image_input = gr.Image(label="Input Image (Auto-filled or Upload)", format="png", image_mode="RGBA", type="pil", height=300)
+
+             with gr.Accordion("3D Generation Settings", open=True):
+                 resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
+                 seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+
+             with gr.Accordion(label="Advanced 3D Settings", open=False):
+                 gr.Markdown("Stage 1: Sparse Structure")
+                 ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5)
+                 ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Rescale", value=0.7)
+                 ss_sampling_steps = gr.Slider(1, 50, label="Steps", value=12)
+                 ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0)

+                 gr.Markdown("Stage 2: Shape")
+                 shape_slat_guidance_strength = gr.Slider(1.0, 10.0, value=7.5)
+                 shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, value=0.5)
+                 shape_slat_sampling_steps = gr.Slider(1, 50, value=12)
+                 shape_slat_rescale_t = gr.Slider(1.0, 6.0, value=3.0)

+                 gr.Markdown("Stage 3: Material")
+                 tex_slat_guidance_strength = gr.Slider(1.0, 10.0, value=1.0)
+                 tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, value=0.0)
+                 tex_slat_sampling_steps = gr.Slider(1, 50, value=12)
+                 tex_slat_rescale_t = gr.Slider(1.0, 6.0, value=3.0)
+
+             generate_3d_btn = gr.Button("Generate 3D Asset", variant="primary")
+
+         # --- Column 2: Output ---
+         with gr.Column(scale=2):
+             with gr.Walkthrough(selected=0) as walkthrough:
+                 with gr.Step("Preview", id=0):
+                     preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)

+                     gr.Markdown("### 3. Extract & Download")
+                     with gr.Row():
+                         decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
+                         texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
+                     extract_btn = gr.Button("Extract GLB")
+
+                 with gr.Step("Extract", id=1):
+                     glb_output = gr.Model3D(label="Extracted GLB", height=600, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
+                     download_btn = gr.DownloadButton(label="Download GLB")
+
+     output_buf = gr.State()
+
+     # --- Event Handlers ---
+
+     demo.load(start_session)
+     demo.unload(end_session)
+
+     # 1. Text to Image
+     gen_image_btn.click(
+         text_to_image,
+         inputs=[text_prompt],
+         outputs=[image_input]
+     )
+
+     # Handle manual upload preprocessing
+     image_input.upload(
+         preprocess_image,
+         inputs=[image_input],
+         outputs=[image_input],
+     )
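
The preprocessing trigger above uses `image_input.upload` rather than the previous revision's `image_input.change`. In Gradio, `.change` fires on any value update, including when `text_to_image` programmatically fills the component with an already-preprocessed result, whereas `.upload` fires only on manual user uploads, so the image is not preprocessed twice. A minimal sketch of the distinction:

    img = gr.Image(type="pil")
    img.upload(preprocess_image, inputs=[img], outputs=[img])   # user uploads only
    # img.change(preprocess_image, inputs=[img], outputs=[img]) # would also re-fire
    #                                                           # on programmatic updates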

+     # 2. Image to 3D
+     generate_3d_btn.click(
+         get_seed,
+         inputs=[randomize_seed, seed],
+         outputs=[seed],
+     ).then(
+         lambda: gr.Walkthrough(selected=0), outputs=walkthrough
+     ).then(
+         image_to_3d,
+         inputs=[
+             image_input, seed, resolution,
+             ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
+             shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
+             tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
+         ],
+         outputs=[output_buf, preview_output],
+     )
+
+     # 3. Extract GLB
+     extract_btn.click(
+         lambda: gr.Walkthrough(selected=1), outputs=walkthrough
+     ).then(
+         extract_glb,
+         inputs=[output_buf, decimation_target, texture_size],
+         outputs=[glb_output, download_btn],
+     )

+ # ==========================================
+ # Main Execution
+ # ==========================================

+ if __name__ == "__main__":
+     os.makedirs(TMP_DIR, exist_ok=True)

+     # Initialize TRELLIS dependencies and assets
+     try:
+         # Load Icons
+         for i in range(len(MODES)):
+             if os.path.exists(MODES[i]['icon']):
+                 icon = Image.open(MODES[i]['icon'])
+                 MODES[i]['icon_base64'] = image_to_base64(icon)
+             else:
+                 print(f"Warning: Icon not found at {MODES[i]['icon']}")
+                 MODES[i]['icon_base64'] = ""  # Fallback or placeholder
+
+         # Initialize RMBG and TRELLIS
+         print("Loading RMBG and TRELLIS models...")
+         rmbg_client = Client("briaai/BRIA-RMBG-2.0")
+         pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
+         pipeline.rembg_model = None
+         pipeline.low_vram = False
+         pipeline.cuda()
+
+         # Load Env Maps
+         envmap_files = {
+             'forest': 'assets/hdri/forest.exr',
+             'sunset': 'assets/hdri/sunset.exr',
+             'courtyard': 'assets/hdri/courtyard.exr'
+         }