prithivMLmods commited on
Commit
b958903
·
verified ·
1 Parent(s): de78ef4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -330
app.py CHANGED
@@ -1,11 +1,8 @@
1
  import os
2
  import shutil
3
- import cv2
4
  import torch
5
  import numpy as np
6
  from PIL import Image
7
- import base64
8
- import io
9
  import tempfile
10
  from typing import *
11
  from datetime import datetime
@@ -26,121 +23,18 @@ import spaces
26
  from diffusers import ZImagePipeline
27
 
28
  # --- TRELLIS Specific Imports ---
29
- from trellis2.modules.sparse import SparseTensor
30
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
31
- from trellis2.renderers import EnvMap
32
- from trellis2.utils import render_utils
33
  import o_voxel
34
 
35
  # ==========================================
36
- # Global Configuration & Assets
37
  # ==========================================
38
 
39
  MAX_SEED = np.iinfo(np.int32).max
40
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
41
 
42
- # TRELLIS Render Modes
43
- MODES = [
44
- {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
45
- {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
46
- {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
47
- {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
48
- {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
49
- {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
50
- ]
51
- STEPS = 8
52
- DEFAULT_MODE = 3
53
- DEFAULT_STEP = 3
54
-
55
  # ==========================================
56
- # CSS & JavaScript (For Custom Previewer)
57
- # ==========================================
58
-
59
- css = """
60
- /* Overwrite Gradio Default Style */
61
- .stepper-wrapper { padding: 0; }
62
- .stepper-container { padding: 0; align-items: center; }
63
- .step-button { flex-direction: row; }
64
- .step-connector { transform: none; }
65
- .step-number { width: 16px; height: 16px; }
66
- .step-label { position: relative; bottom: 0; }
67
- .wrap.center.full { inset: 0; height: 100%; }
68
- .wrap.center.full.translucent { background: var(--block-background-fill); }
69
- .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
70
-
71
- /* Previewer */
72
- .previewer-container {
73
- position: relative; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
74
- width: 100%; height: 722px; margin: 0 auto; padding: 20px;
75
- display: flex; flex-direction: column; align-items: center; justify-content: center;
76
- }
77
- .previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; }
78
- .previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; }
79
- .previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
80
- .tips-icon:hover + .tips-text { display: block; opacity: 100%; }
81
-
82
- /* Row 1: Display Modes */
83
- .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
84
- .previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
85
- .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
86
- .previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
87
-
88
- /* Row 2: Display Image */
89
- .previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
90
- .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
91
- .previewer-container .previewer-main-image.visible { display: block; }
92
-
93
- /* Row 3: Custom HTML Slider */
94
- .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
95
- .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
96
- .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
97
- .previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
98
- .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
99
-
100
- /* Overwrite Previewer Block Style */
101
- .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
102
- .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
103
- """
104
-
105
- head = """
106
- <script>
107
- function refreshView(mode, step) {
108
- const allImgs = document.querySelectorAll('.previewer-main-image');
109
- for (let i = 0; i < allImgs.length; i++) {
110
- const img = allImgs[i];
111
- if (img.classList.contains('visible')) {
112
- const id = img.id;
113
- const [_, m, s] = id.split('-');
114
- if (mode === -1) mode = parseInt(m.slice(1));
115
- if (step === -1) step = parseInt(s.slice(1));
116
- break;
117
- }
118
- }
119
- allImgs.forEach(img => img.classList.remove('visible'));
120
- const targetId = 'view-m' + mode + '-s' + step;
121
- const targetImg = document.getElementById(targetId);
122
- if (targetImg) targetImg.classList.add('visible');
123
-
124
- const allBtns = document.querySelectorAll('.mode-btn');
125
- allBtns.forEach((btn, idx) => {
126
- if (idx === mode) btn.classList.add('active');
127
- else btn.classList.remove('active');
128
- });
129
- }
130
- function selectMode(mode) { refreshView(mode, -1); }
131
- function onSliderChange(val) { refreshView(-1, parseInt(val)); }
132
- </script>
133
- """
134
-
135
- empty_html = f"""
136
- <div class="previewer-container">
137
- <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
138
- xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
139
- </div>
140
- """
141
-
142
- # ==========================================
143
- # Model & Asset Loading
144
  # ==========================================
145
 
146
  print("Initializing models...")
@@ -162,49 +56,23 @@ except Exception as e:
162
 
163
  # 2. TRELLIS.2 (Image to 3D)
164
  print("Loading TRELLIS.2...")
165
- trellis_pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
166
- trellis_pipeline.rembg_model = None
167
- trellis_pipeline.low_vram = False
168
- trellis_pipeline.cuda()
 
 
 
 
 
169
 
170
  # 3. Background Remover
171
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
172
 
173
- # 4. HDRI Maps (FIXED: Robust Fallback)
174
- def load_envmap_safe(path):
175
- """Attempt to load an HDRI map, return a dummy white map if it fails."""
176
- try:
177
- if not os.path.exists(path):
178
- raise FileNotFoundError(f"{path} not found")
179
- im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
180
- if im is None:
181
- raise ValueError(f"OpenCV returned None for {path}")
182
- rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
183
- return EnvMap(torch.tensor(rgb, dtype=torch.float32, device='cuda'))
184
- except Exception as e:
185
- print(f"Warning: Could not load HDRI from {path} ({e}). Using synthetic fallback.")
186
- # Create a simple white environment map (16x32 resolution)
187
- synthetic_env = torch.ones((16, 32, 3), dtype=torch.float32, device='cuda')
188
- return EnvMap(synthetic_env)
189
-
190
- # Always populate these keys, even if files are missing
191
- envmap = {
192
- 'forest': load_envmap_safe('assets/hdri/forest.exr'),
193
- 'sunset': load_envmap_safe('assets/hdri/sunset.exr'),
194
- 'courtyard': load_envmap_safe('assets/hdri/courtyard.exr'),
195
- }
196
-
197
  # ==========================================
198
  # Helper Functions
199
  # ==========================================
200
 
201
- def image_to_base64(image):
202
- buffered = io.BytesIO()
203
- image = image.convert("RGB")
204
- image.save(buffered, format="jpeg", quality=85)
205
- img_str = base64.b64encode(buffered.getvalue()).decode()
206
- return f"data:image/jpeg;base64,{img_str}"
207
-
208
  def start_session(req: gr.Request):
209
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
210
  os.makedirs(user_dir, exist_ok=True)
@@ -257,23 +125,6 @@ def preprocess_image(input: Image.Image) -> Image.Image:
257
  output = Image.fromarray((output * 255).astype(np.uint8))
258
  return output
259
 
260
- def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
261
- shape_slat, tex_slat, res = latents
262
- return {
263
- 'shape_slat_feats': shape_slat.feats.cpu().numpy(),
264
- 'tex_slat_feats': tex_slat.feats.cpu().numpy(),
265
- 'coords': shape_slat.coords.cpu().numpy(),
266
- 'res': res,
267
- }
268
-
269
- def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
270
- shape_slat = SparseTensor(
271
- feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
272
- coords=torch.from_numpy(state['coords']).cuda(),
273
- )
274
- tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
275
- return shape_slat, tex_slat, state['res']
276
-
277
  def get_seed(randomize_seed: bool, seed: int) -> int:
278
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
279
 
@@ -290,7 +141,7 @@ def generate_txt2img(prompt, progress=gr.Progress(track_tqdm=True)):
290
  raise gr.Error("Please enter a prompt.")
291
 
292
  device = "cuda" if torch.cuda.is_available() else "cpu"
293
- generator = torch.Generator(device).manual_seed(42) # Fixed seed for consistency demo
294
 
295
  progress(0.1, desc="Generating Text-to-Image...")
296
  try:
@@ -307,131 +158,85 @@ def generate_txt2img(prompt, progress=gr.Progress(track_tqdm=True)):
307
  except Exception as e:
308
  raise gr.Error(f"Z-Image Generation failed: {str(e)}")
309
 
310
- @spaces.GPU(duration=120)
311
- def image_to_3d(
312
  image: Image.Image,
313
  seed: int,
314
  resolution: str,
 
 
315
  ss_guidance_strength: float,
316
  ss_guidance_rescale: float,
317
  ss_sampling_steps: int,
318
  ss_rescale_t: float,
319
- shape_slat_guidance_strength: float,
320
- shape_slat_guidance_rescale: float,
321
- shape_slat_sampling_steps: int,
322
- shape_slat_rescale_t: float,
323
- tex_slat_guidance_strength: float,
324
- tex_slat_guidance_rescale: float,
325
- tex_slat_sampling_steps: int,
326
- tex_slat_rescale_t: float,
327
  req: gr.Request,
328
  progress=gr.Progress(track_tqdm=True),
329
- ) -> str:
330
 
331
  if image is None:
332
- raise gr.Error("Input image is missing.")
333
-
334
- # --- Sampling ---
335
- outputs, latents = trellis_pipeline.run(
336
- image,
337
- seed=seed,
338
- preprocess_image=False, # We pre-process in the upload handler or assume clean input
339
- sparse_structure_sampler_params={
340
- "steps": ss_sampling_steps,
341
- "guidance_strength": ss_guidance_strength,
342
- "guidance_rescale": ss_guidance_rescale,
343
- "rescale_t": ss_rescale_t,
344
- },
345
- shape_slat_sampler_params={
346
- "steps": shape_slat_sampling_steps,
347
- "guidance_strength": shape_slat_guidance_strength,
348
- "guidance_rescale": shape_slat_guidance_rescale,
349
- "rescale_t": shape_slat_rescale_t,
350
- },
351
- tex_slat_sampler_params={
352
- "steps": tex_slat_sampling_steps,
353
- "guidance_strength": tex_slat_guidance_strength,
354
- "guidance_rescale": tex_slat_guidance_rescale,
355
- "rescale_t": tex_slat_rescale_t,
356
- },
357
- pipeline_type={"512": "512", "1024": "1024_cascade", "1536": "1536_cascade"}[resolution],
358
- return_latent=True,
359
- )
360
- mesh = outputs[0]
361
- mesh.simplify(16777216)
362
-
363
- # Render Preview
364
- images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
365
- state = pack_state(latents)
366
- torch.cuda.empty_cache()
367
 
368
- # --- HTML Construction ---
369
- images_html = ""
370
- for m_idx, mode in enumerate(MODES):
371
- for s_idx in range(STEPS):
372
- unique_id = f"view-m{m_idx}-s{s_idx}"
373
- is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
374
- vis_class = "visible" if is_visible else ""
375
-
376
- render_key = mode['render_key']
377
- if render_key in images:
378
- img_data = images[render_key][s_idx]
379
- img_base64 = image_to_base64(Image.fromarray(img_data))
380
- images_html += f"""<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">"""
381
- else:
382
- images_html += f"""<div id="{unique_id}" class="previewer-main-image {vis_class}" style="width:100%;height:100%;background:#333;color:white;display:flex;align-items:center;justify-content:center;">Render Error</div>"""
383
-
384
- btns_html = ""
385
- for idx, mode in enumerate(MODES):
386
- active_class = "active" if idx == DEFAULT_MODE else ""
387
- btns_html += f"""<img src="{mode['icon_base64']}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode['name']}">"""
388
-
389
- full_html = f"""
390
- <div class="previewer-container">
391
- <div class="tips-wrapper">
392
- <div class="tips-icon">💡Tips</div>
393
- <div class="tips-text"><p>● <b>Render Mode</b> - Click buttons to switch modes.</p><p>● <b>View Angle</b> - Drag slider to rotate.</p></div>
394
- </div>
395
- <div class="display-row">{images_html}</div>
396
- <div class="mode-row" id="btn-group">{btns_html}</div>
397
- <div class="slider-row">
398
- <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
399
- </div>
400
- </div>
401
- """
402
- return state, full_html
403
-
404
- # Increased timeout to 300s (5mins) for robust extraction
405
- @spaces.GPU(duration=300)
406
- def extract_glb(
407
- state: dict,
408
- decimation_target: int,
409
- texture_size: int,
410
- req: gr.Request,
411
- progress=gr.Progress(track_tqdm=True),
412
- ) -> Tuple[str, str]:
413
- if state is None:
414
- raise gr.Error("No 3D model generated yet.")
415
-
416
- # Clear cache before heavy operation
417
- torch.cuda.empty_cache()
418
-
419
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
420
- shape_slat, tex_slat, res = unpack_state(state)
421
 
 
 
422
  try:
423
- # Decode Latents
424
- mesh = trellis_pipeline.decode_latent(shape_slat, tex_slat, res)[0]
425
- mesh.simplify(16777216)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
 
427
- # Convert to GLB
428
  glb = o_voxel.postprocess.to_glb(
429
  vertices=mesh.vertices,
430
  faces=mesh.faces,
431
  attr_volume=mesh.attrs,
432
  coords=mesh.coords,
433
  attr_layout=trellis_pipeline.pbr_attr_layout,
434
- grid_size=res,
435
  aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
436
  decimation_target=decimation_target,
437
  texture_size=texture_size,
@@ -441,19 +246,19 @@ def extract_glb(
441
  use_tqdm=True,
442
  )
443
 
444
- # Export
445
  now = datetime.now()
446
- timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
447
- os.makedirs(user_dir, exist_ok=True)
448
- glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
449
  glb.export(glb_path, extension_webp=True)
450
 
451
- except Exception as e:
452
  torch.cuda.empty_cache()
453
- raise gr.Error(f"Extraction failed: {str(e)}")
454
 
455
- torch.cuda.empty_cache()
456
- return glb_path, glb_path
 
 
457
 
458
  # ==========================================
459
  # Gradio UI Blocks
@@ -461,50 +266,46 @@ def extract_glb(
461
 
462
  if __name__ == "__main__":
463
  os.makedirs(TMP_DIR, exist_ok=True)
464
-
465
- # Pre-process icon base64
466
- for i in range(len(MODES)):
467
- icon_path = MODES[i]['icon']
468
- if os.path.exists(icon_path):
469
- icon = Image.open(icon_path)
470
- MODES[i]['icon_base64'] = image_to_base64(icon)
471
- else:
472
- MODES[i]['icon_base64'] = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
473
 
474
  with gr.Blocks(delete_cache=(600, 600)) as demo:
475
  gr.Markdown("""
476
- # TRELLIS.2-3D
477
- **Unified Text-to-3D Workflow**
478
 
479
- 1. **Text to Image**: Generate a base image using Z-Image-Turbo.
480
- 2. **Image to 3D**: Convert that image into a high-quality 3D asset using TRELLIS.2.
 
481
  """)
482
 
483
  with gr.Row():
484
- # --- Column 1: Inputs & Config ---
485
  with gr.Column(scale=1, min_width=360):
486
 
487
- # --- Step 1: Text to Image ---
488
- with gr.Group():
489
- gr.Markdown("### Step 1: Generate Image")
490
- txt_prompt = gr.Textbox(label="Prompt", placeholder="A sci-fi helmet, high quality, white background")
491
- btn_gen_img = gr.Button("Generate Image", variant="secondary")
 
 
492
 
493
- # --- Step 2: Image to 3D Input ---
494
- gr.Markdown("### Step 2: Configure & Convert")
495
- image_prompt = gr.Image(label="Input Image (Generated or Uploaded)", format="png", image_mode="RGBA", type="pil", height=300)
496
 
497
- with gr.Accordion("3D Generation Settings", open=True):
498
- resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
 
 
499
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
500
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
501
- # LOWERED DEFAULTS TO PREVENT BROWSER FREEZE
502
- decimation_target = gr.Slider(50000, 500000, label="Decimation Target (Faces)", value=100000, step=10000)
 
503
  texture_size = gr.Slider(512, 4096, label="Texture Size", value=1024, step=512)
504
 
505
- btn_gen_3d = gr.Button("Generate 3D", variant="primary")
506
 
507
- with gr.Accordion(label="Advanced Sampling Settings", open=False):
 
508
  gr.Markdown("**Stage 1: Sparse Structure**")
509
  ss_guidance_strength = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
510
  ss_guidance_rescale = gr.Slider(0.0, 1.0, value=0.7, label="Rescale")
@@ -524,28 +325,25 @@ if __name__ == "__main__":
524
  tex_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")
525
 
526
  # --- Column 2: Outputs ---
527
- with gr.Column(scale=10):
528
- with gr.Walkthrough(selected=0) as walkthrough:
529
- with gr.Step("Preview", id=0):
530
- preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
531
- extract_btn = gr.Button("Extract GLB")
532
- with gr.Step("Extract", id=1):
533
- # Lower default height to prevent full screen takeover
534
- glb_output = gr.Model3D(label="Extracted GLB", height=600, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
535
- download_btn = gr.DownloadButton(label="Download GLB")
536
-
537
- gr.Markdown("Note: Extraction may take 1-2 minutes. If browser slows down, reduce 'Decimation Target'.")
538
 
539
  # ==========================================
540
  # Wiring Events
541
  # ==========================================
542
-
543
- output_buf = gr.State()
544
 
545
  demo.load(start_session)
546
  demo.unload(end_session)
547
 
548
- # 1. Text to Image Event
549
  btn_gen_img.click(
550
  generate_txt2img,
551
  inputs=[txt_prompt],
@@ -556,38 +354,28 @@ if __name__ == "__main__":
556
  outputs=[image_prompt]
557
  )
558
 
559
- # 2. Upload Image Event (Preprocess)
560
  image_prompt.upload(
561
  preprocess_image,
562
  inputs=[image_prompt],
563
  outputs=[image_prompt],
564
  )
565
 
566
- # 3. Image to 3D Event
567
  btn_gen_3d.click(
568
  get_seed,
569
  inputs=[randomize_seed, seed],
570
  outputs=[seed],
571
  ).then(
572
- lambda: gr.Walkthrough(selected=0), outputs=walkthrough
573
- ).then(
574
- image_to_3d,
575
  inputs=[
576
- image_prompt, seed, resolution,
 
577
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
578
  shape_guidance, shape_rescale, shape_steps, shape_rescale_t,
579
  tex_guidance, tex_rescale, tex_steps, tex_rescale_t,
580
  ],
581
- outputs=[output_buf, preview_output],
582
- )
583
-
584
- # 4. Extraction Event
585
- extract_btn.click(
586
- lambda: gr.Walkthrough(selected=1), outputs=walkthrough
587
- ).then(
588
- extract_glb,
589
- inputs=[output_buf, decimation_target, texture_size],
590
  outputs=[glb_output, download_btn],
591
  )
592
 
593
- demo.launch(css=css, head=head)
 
1
  import os
2
  import shutil
 
3
  import torch
4
  import numpy as np
5
  from PIL import Image
 
 
6
  import tempfile
7
  from typing import *
8
  from datetime import datetime
 
23
  from diffusers import ZImagePipeline
24
 
25
  # --- TRELLIS Specific Imports ---
 
26
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
 
 
27
  import o_voxel
28
 
29
  # ==========================================
30
+ # Global Configuration
31
  # ==========================================
32
 
33
  MAX_SEED = np.iinfo(np.int32).max
34
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # ==========================================
37
+ # Model Loading
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # ==========================================
39
 
40
  print("Initializing models...")
 
56
 
57
  # 2. TRELLIS.2 (Image to 3D)
58
  print("Loading TRELLIS.2...")
59
+ try:
60
+ trellis_pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
61
+ trellis_pipeline.rembg_model = None
62
+ trellis_pipeline.low_vram = False
63
+ trellis_pipeline.cuda()
64
+ print("TRELLIS.2 loaded.")
65
+ except Exception as e:
66
+ print(f"Failed to load TRELLIS.2: {e}")
67
+ trellis_pipeline = None
68
 
69
  # 3. Background Remover
70
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  # ==========================================
73
  # Helper Functions
74
  # ==========================================
75
 
 
 
 
 
 
 
 
76
  def start_session(req: gr.Request):
77
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
78
  os.makedirs(user_dir, exist_ok=True)
 
125
  output = Image.fromarray((output * 255).astype(np.uint8))
126
  return output
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def get_seed(randomize_seed: bool, seed: int) -> int:
129
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
130
 
 
141
  raise gr.Error("Please enter a prompt.")
142
 
143
  device = "cuda" if torch.cuda.is_available() else "cpu"
144
+ generator = torch.Generator(device).manual_seed(42)
145
 
146
  progress(0.1, desc="Generating Text-to-Image...")
147
  try:
 
158
  except Exception as e:
159
  raise gr.Error(f"Z-Image Generation failed: {str(e)}")
160
 
161
+ @spaces.GPU(duration=300)
162
+ def generate_3d(
163
  image: Image.Image,
164
  seed: int,
165
  resolution: str,
166
+ decimation_target: int,
167
+ texture_size: int,
168
  ss_guidance_strength: float,
169
  ss_guidance_rescale: float,
170
  ss_sampling_steps: int,
171
  ss_rescale_t: float,
172
+ shape_guidance: float,
173
+ shape_rescale: float,
174
+ shape_steps: int,
175
+ shape_rescale_t: float,
176
+ tex_guidance: float,
177
+ tex_rescale: float,
178
+ tex_steps: int,
179
+ tex_rescale_t: float,
180
  req: gr.Request,
181
  progress=gr.Progress(track_tqdm=True),
182
+ ) -> Tuple[str, str]:
183
 
184
  if image is None:
185
+ raise gr.Error("Please provide an input image.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
+ if trellis_pipeline is None:
188
+ raise gr.Error("TRELLIS model is not loaded.")
189
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
191
+ os.makedirs(user_dir, exist_ok=True)
192
 
193
+ # 1. Run Pipeline
194
+ progress(0.1, desc="Generating 3D Geometry...")
195
  try:
196
+ outputs, latents = trellis_pipeline.run(
197
+ image,
198
+ seed=seed,
199
+ preprocess_image=False, # Assumed already preprocessed by UI handler
200
+ sparse_structure_sampler_params={
201
+ "steps": ss_sampling_steps,
202
+ "guidance_strength": ss_guidance_strength,
203
+ "guidance_rescale": ss_guidance_rescale,
204
+ "rescale_t": ss_rescale_t,
205
+ },
206
+ shape_slat_sampler_params={
207
+ "steps": shape_steps,
208
+ "guidance_strength": shape_guidance,
209
+ "guidance_rescale": shape_rescale,
210
+ "rescale_t": shape_rescale_t,
211
+ },
212
+ tex_slat_sampler_params={
213
+ "steps": tex_steps,
214
+ "guidance_strength": tex_guidance,
215
+ "guidance_rescale": tex_rescale,
216
+ "rescale_t": tex_rescale_t,
217
+ },
218
+ pipeline_type={"512": "512", "1024": "1024_cascade", "1536": "1536_cascade"}[resolution],
219
+ return_latent=True,
220
+ )
221
+
222
+ # 2. Process Mesh
223
+ progress(0.7, desc="Processing Mesh...")
224
+ mesh = outputs[0]
225
+ mesh.simplify(16777216) # Simplify for processing limits
226
+
227
+ # 3. Export to GLB
228
+ progress(0.9, desc="Baking Texture & Exporting GLB...")
229
+
230
+ # Note: We use the latent grid resolution from the pipeline output
231
+ grid_size = latents[2]
232
 
 
233
  glb = o_voxel.postprocess.to_glb(
234
  vertices=mesh.vertices,
235
  faces=mesh.faces,
236
  attr_volume=mesh.attrs,
237
  coords=mesh.coords,
238
  attr_layout=trellis_pipeline.pbr_attr_layout,
239
+ grid_size=grid_size,
240
  aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
241
  decimation_target=decimation_target,
242
  texture_size=texture_size,
 
246
  use_tqdm=True,
247
  )
248
 
 
249
  now = datetime.now()
250
+ timestamp = now.strftime("%Y-%m-%dT%H%M%S")
251
+ glb_path = os.path.join(user_dir, f'trellis_output_{timestamp}.glb')
 
252
  glb.export(glb_path, extension_webp=True)
253
 
254
+ # Clean up
255
  torch.cuda.empty_cache()
256
+ return glb_path, glb_path
257
 
258
+ except Exception as e:
259
+ torch.cuda.empty_cache()
260
+ raise gr.Error(f"Generation failed: {str(e)}")
261
+
262
 
263
  # ==========================================
264
  # Gradio UI Blocks
 
266
 
267
  if __name__ == "__main__":
268
  os.makedirs(TMP_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
269
 
270
  with gr.Blocks(delete_cache=(600, 600)) as demo:
271
  gr.Markdown("""
272
+ # TRELLIS.2-3D (Direct Export)
 
273
 
274
+ **Workflow:**
275
+ 1. **Generate Image** (Text-to-Image) or **Upload Image**.
276
+ 2. Click **Generate 3D** to create the GLB asset directly.
277
  """)
278
 
279
  with gr.Row():
280
+ # --- Column 1: Inputs ---
281
  with gr.Column(scale=1, min_width=360):
282
 
283
+ # Input Source
284
+ with gr.Tabs():
285
+ with gr.Tab("Text to Image"):
286
+ txt_prompt = gr.Textbox(label="Prompt", placeholder="A treasure chest, isometric style, white background", lines=3)
287
+ btn_gen_img = gr.Button("Generate Image", variant="secondary")
288
+ with gr.Tab("Upload"):
289
+ gr.Markdown("Upload an image directly if you have one.")
290
 
291
+ # Image Display
292
+ image_prompt = gr.Image(label="Input Image", format="png", image_mode="RGBA", type="pil", height=300)
 
293
 
294
+ # 3D Settings
295
+ gr.Markdown("### 3D Settings")
296
+ with gr.Group():
297
+ resolution = gr.Radio(["512", "1024", "1536"], label="Generation Resolution", value="1024")
298
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
299
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
300
+
301
+ gr.Markdown(" **Export Settings**")
302
+ decimation_target = gr.Slider(50000, 500000, label="Target Faces", value=150000, step=10000)
303
  texture_size = gr.Slider(512, 4096, label="Texture Size", value=1024, step=512)
304
 
305
+ btn_gen_3d = gr.Button("Generate 3D Asset", variant="primary", scale=2)
306
 
307
+ # Advanced Settings
308
+ with gr.Accordion(label="Advanced Sampler Settings", open=False):
309
  gr.Markdown("**Stage 1: Sparse Structure**")
310
  ss_guidance_strength = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
311
  ss_guidance_rescale = gr.Slider(0.0, 1.0, value=0.7, label="Rescale")
 
325
  tex_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")
326
 
327
  # --- Column 2: Outputs ---
328
+ with gr.Column(scale=2):
329
+ gr.Markdown("### 3D Output")
330
+ glb_output = gr.Model3D(
331
+ label="Generated GLB",
332
+ display_mode="solid",
333
+ clear_color=(0.2, 0.2, 0.2, 1.0),
334
+ height=600,
335
+ interactive=True
336
+ )
337
+ download_btn = gr.DownloadButton(label="Download GLB File")
 
338
 
339
  # ==========================================
340
  # Wiring Events
341
  # ==========================================
 
 
342
 
343
  demo.load(start_session)
344
  demo.unload(end_session)
345
 
346
+ # 1. Text to Image
347
  btn_gen_img.click(
348
  generate_txt2img,
349
  inputs=[txt_prompt],
 
354
  outputs=[image_prompt]
355
  )
356
 
357
+ # 2. Auto-preprocess uploaded images
358
  image_prompt.upload(
359
  preprocess_image,
360
  inputs=[image_prompt],
361
  outputs=[image_prompt],
362
  )
363
 
364
+ # 3. Generate 3D
365
  btn_gen_3d.click(
366
  get_seed,
367
  inputs=[randomize_seed, seed],
368
  outputs=[seed],
369
  ).then(
370
+ generate_3d,
 
 
371
  inputs=[
372
+ image_prompt, seed, resolution,
373
+ decimation_target, texture_size,
374
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
375
  shape_guidance, shape_rescale, shape_steps, shape_rescale_t,
376
  tex_guidance, tex_rescale, tex_steps, tex_rescale_t,
377
  ],
 
 
 
 
 
 
 
 
 
378
  outputs=[glb_output, download_btn],
379
  )
380
 
381
+ demo.launch()