prithivMLmods committed
Commit de9769e · verified · 1 Parent(s): 54a0ae6

Update app.py

Files changed (1): app.py (+400 -277)
app.py CHANGED
@@ -2,91 +2,179 @@ import os
  import io
  import cv2
  import time
  import torch
  import shutil
  import base64
  import tempfile
  import numpy as np
  import gradio as gr
  from PIL import Image
  from typing import *
  from datetime import datetime

- # --- Environment Configuration ---
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
- os.environ["ATTN_BACKEND"] = "flash_attn_3" # Ensure you have flash-attn installed, or set to 'xformers'/'flash_attn'
  os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

- # --- Hugging Face Spaces / GPU Setup ---
- import spaces
- from diffusers import DiffusionPipeline
-
- # --- TRELLIS Imports ---
- # (Assumes running from root of TRELLIS repo)
  from trellis2.modules.sparse import SparseTensor
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
  from trellis2.renderers import EnvMap
  from trellis2.utils import render_utils
  import o_voxel

- # --- Background Removal ---
- # We use rembg locally for stability instead of an API call
- try:
-     from rembg import remove
- except ImportError:
-     print("Please install rembg: pip install rembg")

- # =========================================
- # MODEL LOADING
- # =========================================

- print(">>> Loading Z-Image-Turbo Pipeline...")
- z_image_pipe = DiffusionPipeline.from_pretrained(
      "Tongyi-MAI/Z-Image-Turbo",
      torch_dtype=torch.bfloat16,
      low_cpu_mem_usage=False,
  )
- z_image_pipe.to("cuda")
- print(">>> Z-Image-Turbo Loaded!")

- print(">>> Loading TRELLIS.2 Pipeline...")
- trellis_pipeline = Trellis2ImageTo3DPipeline.from_pretrained(
-     "microsoft/TRELLIS.2-4B",
-     torch_dtype=torch.float16
- )
- trellis_pipeline.cuda()

- # Load EnvMap for rendering previews
- try:
-     envmap = EnvMap.from_file("assets/app/envmap.exr")
- except:
-     print("Warning: envmap.exr not found in assets/app/. Rendering might look flat.")
-     envmap = None
-
- print(">>> TRELLIS.2 Loaded!")

- # =========================================
- # CONSTANTS & UTILS
- # =========================================

- MAX_SEED = np.iinfo(np.int32).max
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
- os.makedirs(TMP_DIR, exist_ok=True)

- # Pre-load Icons for HTML Previewer
- MODES = [
-     {"name": "Normal", "icon_path": "assets/app/normal.png", "render_key": "normal"},
-     {"name": "Clay render", "icon_path": "assets/app/clay.png", "render_key": "clay"},
-     {"name": "Base color", "icon_path": "assets/app/basecolor.png", "render_key": "base_color"},
-     {"name": "HDRI forest", "icon_path": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
-     {"name": "HDRI sunset", "icon_path": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
-     {"name": "HDRI courtyard", "icon_path": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
- ]
- STEPS = 8
- DEFAULT_MODE = 3
- DEFAULT_STEP = 3

  def image_to_base64(image):
      buffered = io.BytesIO()
@@ -95,48 +183,59 @@ def image_to_base64(image):
      img_str = base64.b64encode(buffered.getvalue()).decode()
      return f"data:image/jpeg;base64,{img_str}"

- # Load icons into memory as base64 to avoid path issues in HTML
- for mode in MODES:
-     if os.path.exists(mode['icon_path']):
-         with open(mode['icon_path'], "rb") as f:
-             mode['icon_base64'] = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}"
-     else:
-         # Fallback empty image if asset missing
-         mode['icon_base64'] = ""

  def preprocess_image(input_img: Image.Image) -> Image.Image:
-     """Preprocess: Resize, Remove Background, Center Crop."""
-     # 1. Resize if too large
      max_size = max(input_img.size)
      scale = min(1, 1024 / max_size)
      if scale < 1:
          input_img = input_img.resize((int(input_img.width * scale), int(input_img.height * scale)), Image.Resampling.LANCZOS)

-     # 2. Remove Background (if no alpha)
-     if input_img.mode != 'RGBA':
-         input_img = remove(input_img)
      else:
-         # Check if alpha is fully opaque
-         alpha = np.array(input_img)[:, :, 3]
-         if np.all(alpha == 255):
-             input_img = remove(input_img)
-
-     # 3. Crop to content
-     output_np = np.array(input_img)
      alpha = output_np[:, :, 3]
      bbox = np.argwhere(alpha > 0.8 * 255)
-     if len(bbox) == 0: return input_img # Empty image
-
      bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
      center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
      size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
-     size = int(size * 1.1) # Add some padding
-
      bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
-     output = input_img.crop(bbox)

-     # 4. Composite on white (optional for 3D logic, but TRELLIS likes alpha)
-     # Keeping alpha channel for TRELLIS
      return output

  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
@@ -156,69 +255,74 @@ def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
      tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
      return shape_slat, tex_slat, state['res']

- # =========================================
- # GRADIO LOGIC
- # =========================================

- @spaces.GPU(duration=60)
- def generate_z_image(prompt, height, width, steps, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
-     """Step 1: Text to Image"""
      if randomize_seed:
          seed = torch.randint(0, 2**32 - 1, (1,)).item()

      generator = torch.Generator("cuda").manual_seed(int(seed))
-
-     print(f"Generating image for: {prompt}")
-     image = z_image_pipe(
          prompt=prompt,
          height=int(height),
          width=int(width),
-         num_inference_steps=int(steps),
-         guidance_scale=0.0, # Turbo usually uses 0 or low guidance
          generator=generator,
      ).images[0]

      return image, seed

- @spaces.GPU(duration=180)
- def generate_3d_trellis(
      image: Image.Image,
      seed: int,
-     resolution: str = "1024",
-     # Advanced Params with defaults
-     ss_guidance_strength=7.5, ss_sampling_steps=12,
-     slat_guidance_strength=3.0, slat_sampling_steps=12,
-     req: gr.Request = None,
-     progress=gr.Progress(track_tqdm=True)
- ):
-     """Step 2: Image to 3D"""
-     if image is None:
-         raise gr.Error("Please generate or upload an image first.")
-
-     # Preprocess
-     processed_image = preprocess_image(image)
-
-     # Run Pipeline
-     # Using simplified params for the UI, mapping to full pipeline args
-     outputs, latents = trellis_pipeline.run(
-         processed_image,
          seed=seed,
-         preprocess_image=False, # We did it manually
          sparse_structure_sampler_params={
              "steps": ss_sampling_steps,
              "guidance_strength": ss_guidance_strength,
-             "guidance_rescale": 0.0, "rescale_t": 0.0,
          },
          shape_slat_sampler_params={
-             "steps": slat_sampling_steps,
-             "guidance_strength": slat_guidance_strength,
-             "guidance_rescale": 0.0, "rescale_t": 0.0,
          },
          tex_slat_sampler_params={
-             "steps": slat_sampling_steps,
-             "guidance_strength": slat_guidance_strength,
-             "guidance_rescale": 0.0, "rescale_t": 0.0,
          },
          pipeline_type={
              "512": "512",
@@ -229,34 +333,59 @@ def generate_3d_trellis(
      )

      mesh = outputs[0]
-     # Simplify for visualization
-     mesh.simplify(16777216)

-     # Render Preview (Spinning view)
-     images_render = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
      state = pack_state(latents)
      torch.cuda.empty_cache()
-
-     # --- Build HTML ---
      images_html = ""
      for m_idx, mode in enumerate(MODES):
-         key = mode['render_key']
-         if key not in images_render: continue
-
          for s_idx in range(STEPS):
              unique_id = f"view-m{m_idx}-s{s_idx}"
              is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
              vis_class = "visible" if is_visible else ""
-             img_b64 = image_to_base64(Image.fromarray(images_render[key][s_idx]))
-             images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_b64}" loading="eager">'
-
      btns_html = ""
      for idx, mode in enumerate(MODES):
          active_class = "active" if idx == DEFAULT_MODE else ""
-         btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
-
      full_html = f"""
      <div class="previewer-container">
          <div class="display-row">{images_html}</div>
          <div class="mode-row" id="btn-group">{btns_html}</div>
          <div class="slider-row">
@@ -264,32 +393,31 @@ def generate_3d_trellis(
          </div>
      </div>
      """
-
      return state, full_html

- @spaces.GPU(duration=60)
- def extract_glb(state: dict, mesh_simplify: float, texture_size: int, req: gr.Request):
-     """Step 3: Export GLB"""
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-     os.makedirs(user_dir, exist_ok=True)
-
      shape_slat, tex_slat, res = unpack_state(state)
-     mesh = trellis_pipeline.decode_latent(shape_slat, tex_slat, res)[0]
-
-     # Decimation logic
-     # Approximate face count vs float 0-1
-     target_faces = int(mesh_simplify * 100000) # Simple mapping

      glb = o_voxel.postprocess.to_glb(
          vertices=mesh.vertices,
          faces=mesh.faces,
          attr_volume=mesh.attrs,
          coords=mesh.coords,
-         attr_layout=trellis_pipeline.pbr_attr_layout,
          grid_size=res,
          aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
-         decimation_target=target_faces,
-         texture_size=int(texture_size),
          remesh=True,
          remesh_band=1,
          remesh_project=0,
@@ -297,151 +425,146 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int, req: gr.Re
      )

      now = datetime.now()
-     timestamp = now.strftime("%Y-%m-%dT%H%M%S")
-     glb_path = os.path.join(user_dir, f'trellis_output_{timestamp}.glb')
      glb.export(glb_path, extension_webp=True)
      torch.cuda.empty_cache()

-     return glb_path
-
-
- # =========================================
- # CSS & JS
- # =========================================
-
- css = """
- .previewer-container {
-     width: 100%; height: 600px; display: flex; flex-direction: column; align-items: center; justify-content: center;
-     background: var(--background-fill-secondary); border-radius: 8px; padding: 20px;
- }
- .display-row { flex-grow: 1; width: 100%; display: flex; justify-content: center; align-items: center; overflow: hidden; }
- .previewer-main-image { max-width: 100%; max-height: 100%; object-fit: contain; display: none; }
- .previewer-main-image.visible { display: block; }
- .mode-row { display: flex; gap: 10px; margin: 10px 0; }
- .mode-btn { width: 30px; height: 30px; border-radius: 50%; cursor: pointer; opacity: 0.6; border: 2px solid transparent; }
- .mode-btn:hover { opacity: 1; transform: scale(1.1); }
- .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
- .slider-row { width: 80%; }
- input[type=range] { width: 100%; }
- """
-
- head_js = """
- <script>
- function refreshView(mode, step) {
-     const allImgs = document.querySelectorAll('.previewer-main-image');
-     let currentMode = mode;
-     let currentStep = step;
-
-     // Find current state if args are -1
-     if (currentMode === -1 || currentStep === -1) {
-         for (let img of allImgs) {
-             if (img.classList.contains('visible')) {
-                 const parts = img.id.split('-');
-                 if (currentMode === -1) currentMode = parseInt(parts[1].substring(1));
-                 if (currentStep === -1) currentStep = parseInt(parts[2].substring(1));
-                 break;
-             }
-         }
-     }
-     if (currentMode === -1) currentMode = 3;
-     if (currentStep === -1) currentStep = 3;
-
-     allImgs.forEach(img => img.classList.remove('visible'));
-     const targetId = `view-m${currentMode}-s${currentStep}`;
-     const target = document.getElementById(targetId);
-     if (target) target.classList.add('visible');
-
-     const allBtns = document.querySelectorAll('.mode-btn');
-     allBtns.forEach((btn, idx) => {
-         if(idx === currentMode) btn.classList.add('active');
-         else btn.classList.remove('active');
-     });
- }
-
- function selectMode(mode) { refreshView(mode, -1); }
- function onSliderChange(val) { refreshView(-1, parseInt(val)); }
- </script>
- """
-
- # =========================================
- # APP LAYOUT
- # =========================================
-
- with gr.Blocks(title="Z-Image-Turbo + TRELLIS 2", css=css, head=head_js) as demo:
-     gr.Markdown("# 🧊 Text to 3D with Z-Image-Turbo + TRELLIS.2")
-
-     # Session state
-     trellis_state = gr.State()
-
      with gr.Row():
-         # --- LEFT COLUMN: Text to Image ---
-         with gr.Column(scale=1):
-             gr.Markdown("### 1. Generate Image")
-             prompt_input = gr.Textbox(label="Prompt", placeholder="A detailed 3D render of a futuristic robot helmet...")

-             with gr.Accordion("Image Settings", open=False):
                  with gr.Row():
-                     height_in = gr.Number(label="Height", value=1024)
-                     width_in = gr.Number(label="Width", value=1024)
-                 steps_in = gr.Slider(label="Steps", minimum=1, maximum=50, value=4, step=1)
-                 seed_in = gr.Number(label="Seed", value=0)
-                 random_seed = gr.Checkbox(label="Randomize Seed", value=True)
-
-             gen_img_btn = gr.Button("Generate Image", variant="primary")
-
-             output_image = gr.Image(label="Generated Image", type="pil", interactive=False)
-
-         # --- RIGHT COLUMN: Image to 3D ---
-         with gr.Column(scale=2):
-             gr.Markdown("### 2. Generate 3D")
-
-             with gr.Accordion("TRELLIS Settings", open=False):
-                 seed_3d = gr.Number(label="3D Seed", value=0)
-                 res_3d = gr.Dropdown(label="Resolution", choices=["512", "1024", "1536"], value="1024")
-
-             gen_3d_btn = gr.Button("To 3D 🧊", variant="primary")

-             # HTML Previewer
-             html_output = gr.HTML(label="3D Preview", value="<div style='height:600px; display:flex; align-items:center; justify-content:center; color:gray;'>Generate 3D to view preview</div>")

-             gr.Markdown("### 3. Export")
-             with gr.Row():
-                 simplify_slider = gr.Slider(label="Mesh Density (Face Count)", minimum=0.1, maximum=2.0, value=0.9)
-                 tex_size_drop = gr.Dropdown(label="Texture Size", choices=[1024, 2048, 4096], value=2048)
-             export_btn = gr.Button("Export GLB")
-
-             glb_output = gr.File(label="Download GLB")
-
-     # --- Event Wiring ---

-     # 1. Text to Image
      gen_img_btn.click(
          fn=generate_z_image,
-         inputs=[prompt_input, height_in, width_in, steps_in, seed_in, random_seed],
-         outputs=[output_image, seed_in]
      )
-
-     # 2. Image to 3D
-     gen_3d_btn.click(
-         fn=generate_3d_trellis,
-         inputs=[output_image, seed_3d, res_3d],
-         outputs=[trellis_state, html_output]
      )
-
-     # 3. Export
-     export_btn.click(
-         fn=extract_glb,
-         inputs=[trellis_state, simplify_slider, tex_size_drop],
-         outputs=[glb_output]
      )

-     def on_load(req: gr.Request):
-         # Setup session dir
-         if req:
-             user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-             os.makedirs(user_dir, exist_ok=True)

-     demo.load(on_load)

  if __name__ == "__main__":
-     demo.queue().launch(show_api=False, share=True)
@@ -2,91 +2,179 @@ import os
  import io
  import cv2
  import time
+ import math
  import torch
+ import shlex
  import shutil
  import base64
+ import random
  import tempfile
  import numpy as np
  import gradio as gr
+ import spaces
  from PIL import Image
  from typing import *
  from datetime import datetime
+ from gradio_client import Client, handle_file
+ from diffusers import DiffusionPipeline

+ # --- TRELLIS Imports ---
+ # Ensure these env vars are set before importing trellis2 modules
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+ os.environ["ATTN_BACKEND"] = "flash_attn_3"
+ # Adjust path if needed or keep relative
  os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

  from trellis2.modules.sparse import SparseTensor
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
  from trellis2.renderers import EnvMap
  from trellis2.utils import render_utils
  import o_voxel

+ # ==========================================
+ # 1. HTML/CSS/JS CONFIGURATION
+ # ==========================================
+
+ MAX_SEED = np.iinfo(np.int32).max
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+
+ MODES = [
+     {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
+     {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
+     {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
+     {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
+     {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
+     {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
+ ]
+ STEPS = 8
+ DEFAULT_MODE = 3
+ DEFAULT_STEP = 3
+
+ css = """
+ .stepper-wrapper { padding: 0; }
+ .stepper-container { padding: 0; align-items: center; }
+ .step-button { flex-direction: row; }
+ .step-connector { transform: none; }
+ .step-number { width: 16px; height: 16px; }
+ .step-label { position: relative; bottom: 0; }
+ .previewer-container {
+     position: relative; font-family: sans-serif; width: 100%; height: 722px;
+     margin: 0 auto; padding: 20px; display: flex; flex-direction: column;
+     align-items: center; justify-content: center;
+ }
+ .previewer-container .tips-icon {
+     position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px;
+     color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none;
+ }
+ .previewer-container .tips-text {
+     position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent);
+     border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10;
+     transition: all 0.3s; opacity: 0%; user-select: none;
+ }
+ .tips-icon:hover + .tips-text { display: block; opacity: 100%; }
+ .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
+ .previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
+ .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
+ .previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
+ .previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
+ .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
+ .previewer-container .previewer-main-image.visible { display: block; }
+ .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
+ .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
+ .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
+ .previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
+ .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
+ .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
+ .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
+ """
+
+ head = """
+ <script>
+ function refreshView(mode, step) {
+     const allImgs = document.querySelectorAll('.previewer-main-image');
+     for (let i = 0; i < allImgs.length; i++) {
+         const img = allImgs[i];
+         if (img.classList.contains('visible')) {
+             const id = img.id;
+             const [_, m, s] = id.split('-');
+             if (mode === -1) mode = parseInt(m.slice(1));
+             if (step === -1) step = parseInt(s.slice(1));
+             break;
+         }
+     }
+     allImgs.forEach(img => img.classList.remove('visible'));
+     const targetId = 'view-m' + mode + '-s' + step;
+     const targetImg = document.getElementById(targetId);
+     if (targetImg) targetImg.classList.add('visible');
+     const allBtns = document.querySelectorAll('.mode-btn');
+     allBtns.forEach((btn, idx) => {
+         if (idx === mode) btn.classList.add('active');
+         else btn.classList.remove('active');
+     });
+ }
+ function selectMode(mode) { refreshView(mode, -1); }
+ function onSliderChange(val) { refreshView(-1, parseInt(val)); }
+ </script>
+ """
+
+ empty_html = """
+ <div class="previewer-container">
+     <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
+         xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
+ </div>
+ """

+ # ==========================================
+ # 2. MODEL LOADING
+ # ==========================================

+ print("Loading Z-Image-Turbo pipeline...")
+ # Load Z-Image Pipeline
+ z_pipe = DiffusionPipeline.from_pretrained(
      "Tongyi-MAI/Z-Image-Turbo",
      torch_dtype=torch.bfloat16,
      low_cpu_mem_usage=False,
  )
+ z_pipe.to("cuda")

+ print("Loading TRELLIS.2 pipeline...")
+ # Load TRELLIS Pipeline
+ trellis_pipe = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
+ trellis_pipe.rembg_model = None
+ trellis_pipe.low_vram = False
+ trellis_pipe.cuda()

+ # Load RMBG Client
+ print("Loading RMBG Client...")
+ rmbg_client = Client("briaai/BRIA-RMBG-2.0")

+ # Load HDRI Maps (Ensure assets folder exists)
+ try:
+     envmap = {
+         'forest': EnvMap(torch.tensor(
+             cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
+             dtype=torch.float32, device='cuda'
+         )),
+         'sunset': EnvMap(torch.tensor(
+             cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
+             dtype=torch.float32, device='cuda'
+         )),
+         'courtyard': EnvMap(torch.tensor(
+             cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
+             dtype=torch.float32, device='cuda'
+         )),
+     }
+ except Exception as e:
+     print(f"Warning: Could not load HDRI maps. Check 'assets/hdri' folder. Error: {e}")
+     envmap = {}

+ print("All models loaded!")

+ # ==========================================
+ # 3. HELPER FUNCTIONS
+ # ==========================================

  def image_to_base64(image):
      buffered = io.BytesIO()
@@ -95,48 +183,59 @@ def image_to_base64(image):
      img_str = base64.b64encode(buffered.getvalue()).decode()
      return f"data:image/jpeg;base64,{img_str}"

+ def start_session(req: gr.Request):
+     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     os.makedirs(user_dir, exist_ok=True)
+
+ def end_session(req: gr.Request):
+     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     if os.path.exists(user_dir):
+         shutil.rmtree(user_dir)
+
+ def remove_background(input_img: Image.Image) -> Image.Image:
+     with tempfile.NamedTemporaryFile(suffix='.png') as f:
+         input_img = input_img.convert('RGB')
+         input_img.save(f.name)
+         # Using Gradio Client for Bria RMBG
+         output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
+     output = Image.open(output)
+     return output
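Background removal now goes through the briaai/BRIA-RMBG-2.0 Space via gradio_client, whereas the deleted code ran rembg locally "for stability instead of an API call". If the remote Space is down or rate-limited, a local fallback along the old path may still be worth keeping. The sketch below is illustrative only, not part of this commit, and assumes rembg is installed (pip install rembg):

# Hypothetical fallback, not in this commit: prefer the local rembg model the
# old code used, and only call the remote Space when rembg is unavailable.
def remove_background_with_fallback(input_img: Image.Image) -> Image.Image:
    try:
        from rembg import remove  # pip install rembg
        return remove(input_img.convert('RGB'))  # returns an RGBA PIL image
    except ImportError:
        return remove_background(input_img)  # remote BRIA-RMBG-2.0 Space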

  def preprocess_image(input_img: Image.Image) -> Image.Image:
+     """Preprocess the input image: Resize and Remove Background if needed."""
+     has_alpha = False
+     if input_img.mode == 'RGBA':
+         alpha = np.array(input_img)[:, :, 3]
+         if not np.all(alpha == 255):
+             has_alpha = True
+
      max_size = max(input_img.size)
      scale = min(1, 1024 / max_size)
      if scale < 1:
          input_img = input_img.resize((int(input_img.width * scale), int(input_img.height * scale)), Image.Resampling.LANCZOS)

+     if has_alpha:
+         output = input_img
      else:
+         output = remove_background(input_img)
+
+     output_np = np.array(output)
      alpha = output_np[:, :, 3]
      bbox = np.argwhere(alpha > 0.8 * 255)
+     if bbox.size == 0:
+         return output # Return original if empty
+
      bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
      center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
      size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+     size = int(size * 1) # margin
      bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+     output = output.crop(bbox)

+     # Normalize
+     output = np.array(output).astype(np.float32) / 255
+     output = output[:, :, :3] * output[:, :, 3:4]
+     output = Image.fromarray((output * 255).astype(np.uint8))
      return output
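The new "Normalize" block at the end of preprocess_image premultiplies RGB by alpha, i.e. it composites the cutout onto black and returns an opaque RGB image. A standalone illustration of that arithmetic (the pixel value is chosen only for demonstration):

# Standalone sketch of the premultiply-onto-black step above (illustrative only).
import numpy as np

rgba = np.array([[[200, 100, 50, 128]]], dtype=np.uint8)  # one half-transparent pixel
x = rgba.astype(np.float32) / 255
rgb = x[:, :, :3] * x[:, :, 3:4]      # scale RGB by alpha -> composite onto black
print((rgb * 255).astype(np.uint8))   # [[[100  50  25]]]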

  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
@@ -156,69 +255,74 @@ def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
      tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
      return shape_slat, tex_slat, state['res']
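pack_state's body falls inside a collapsed region of this diff; only the tail of unpack_state is visible above. Judging from the keys unpack_state reads ('shape_slat_feats', 'tex_slat_feats', 'res'), a minimal sketch could look like the following. The 'shape_slat_coords' key and the .coords/.feats attributes are assumptions, not confirmed by these hunks:

# Hypothetical reconstruction; the committed body is collapsed in this diff.
# Assumes SparseTensor exposes .coords and .feats, mirroring unpack_state.
def pack_state(latents):
    shape_slat, tex_slat, res = latents
    return {
        'shape_slat_coords': shape_slat.coords.cpu().numpy(),  # assumed key
        'shape_slat_feats': shape_slat.feats.cpu().numpy(),
        'tex_slat_feats': tex_slat.feats.cpu().numpy(),
        'res': res,
    }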

+ def get_seed(randomize_seed: bool, seed: int) -> int:
+     return np.random.randint(0, MAX_SEED) if randomize_seed else seed

+ # ==========================================
+ # 4. CORE GENERATION FUNCTIONS
+ # ==========================================

+ @spaces.GPU
+ def generate_z_image(prompt, height, width, num_inference_steps, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
+     """Generate image using Z-Image-Turbo"""
      if randomize_seed:
          seed = torch.randint(0, 2**32 - 1, (1,)).item()

      generator = torch.Generator("cuda").manual_seed(int(seed))
+     image = z_pipe(
          prompt=prompt,
          height=int(height),
          width=int(width),
+         num_inference_steps=int(num_inference_steps),
+         guidance_scale=0.0,
          generator=generator,
      ).images[0]

      return image, seed

+ @spaces.GPU(duration=120)
+ def generate_trellis_3d(
      image: Image.Image,
      seed: int,
+     resolution: str,
+     ss_guidance_strength: float,
+     ss_guidance_rescale: float,
+     ss_sampling_steps: int,
+     ss_rescale_t: float,
+     shape_slat_guidance_strength: float,
+     shape_slat_guidance_rescale: float,
+     shape_slat_sampling_steps: int,
+     shape_slat_rescale_t: float,
+     tex_slat_guidance_strength: float,
+     tex_slat_guidance_rescale: float,
+     tex_slat_sampling_steps: int,
+     tex_slat_rescale_t: float,
+     req: gr.Request,
+     progress=gr.Progress(track_tqdm=True),
+ ) -> str:
+
+     # Run pipeline
+     outputs, latents = trellis_pipe.run(
+         image,
          seed=seed,
+         preprocess_image=False, # We handle preprocessing in the UI/before calling
          sparse_structure_sampler_params={
              "steps": ss_sampling_steps,
              "guidance_strength": ss_guidance_strength,
+             "guidance_rescale": ss_guidance_rescale,
+             "rescale_t": ss_rescale_t,
          },
          shape_slat_sampler_params={
+             "steps": shape_slat_sampling_steps,
+             "guidance_strength": shape_slat_guidance_strength,
+             "guidance_rescale": shape_slat_guidance_rescale,
+             "rescale_t": shape_slat_rescale_t,
          },
          tex_slat_sampler_params={
+             "steps": tex_slat_sampling_steps,
+             "guidance_strength": tex_slat_guidance_strength,
+             "guidance_rescale": tex_slat_guidance_rescale,
+             "rescale_t": tex_slat_rescale_t,
          },
          pipeline_type={
              "512": "512",
@@ -229,34 +333,59 @@ def generate_3d_trellis(
      )

      mesh = outputs[0]
+     mesh.simplify(16777216) # nvdiffrast limit

+     # Render Preview Images
+     if not envmap:
+         # Fallback if maps missing
+         print("Envmap missing, rendering basic")
+         images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS)
+     else:
+         images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
+
      state = pack_state(latents)
      torch.cuda.empty_cache()
+
+     # --- HTML Construction ---
      images_html = ""
      for m_idx, mode in enumerate(MODES):
+         # Check if render key exists (in case hdri missing)
+         if mode['render_key'] not in images:
+             continue
+
          for s_idx in range(STEPS):
              unique_id = f"view-m{m_idx}-s{s_idx}"
              is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
              vis_class = "visible" if is_visible else ""
+             img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
+
+             images_html += f"""
+                 <img id="{unique_id}"
+                     class="previewer-main-image {vis_class}"
+                     src="{img_base64}"
+                     loading="eager">
+             """
+
      btns_html = ""
      for idx, mode in enumerate(MODES):
+         if mode['render_key'] not in images: continue
          active_class = "active" if idx == DEFAULT_MODE else ""
+         btns_html += f"""
+             <img src="{mode['icon_base64']}"
+                 class="mode-btn {active_class}"
+                 onclick="selectMode({idx})"
+                 title="{mode['name']}">
+         """
+
      full_html = f"""
      <div class="previewer-container">
+         <div class="tips-wrapper">
+             <div class="tips-icon">💡Tips</div>
+             <div class="tips-text">
+                 <p>● <b>Render Mode</b> - Click buttons to switch render modes.</p>
+                 <p>● <b>View Angle</b> - Drag slider to rotate.</p>
+             </div>
+         </div>
          <div class="display-row">{images_html}</div>
          <div class="mode-row" id="btn-group">{btns_html}</div>
          <div class="slider-row">
@@ -264,32 +393,31 @@ def generate_3d_trellis(
          </div>
      </div>
      """
      return state, full_html
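One thing to watch in this revision: btns_html reads mode['icon_base64'], but the new MODES entries only define an "icon" path, and the base64 preload loop from the deleted code does not reappear in these hunks. Unless that preload happens in a collapsed region, a step like the old loop (with 'icon_path' renamed to the new 'icon' key) is still needed before the buttons can render:

# Preload loop carried over from the deleted code, key renamed to 'icon';
# btns_html above expects mode['icon_base64'] to exist.
for mode in MODES:
    if os.path.exists(mode['icon']):
        with open(mode['icon'], "rb") as f:
            mode['icon_base64'] = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}"
    else:
        mode['icon_base64'] = ""  # fallback when the asset is missing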

+ @spaces.GPU(duration=120)
+ def extract_glb(
+     state: dict,
+     decimation_target: int,
+     texture_size: int,
+     req: gr.Request,
+     progress=gr.Progress(track_tqdm=True),
+ ) -> Tuple[str, str]:
      user_dir = os.path.join(TMP_DIR, str(req.session_hash))
      shape_slat, tex_slat, res = unpack_state(state)
+     mesh = trellis_pipe.decode_latent(shape_slat, tex_slat, res)[0]
+     mesh.simplify(16777216)

      glb = o_voxel.postprocess.to_glb(
          vertices=mesh.vertices,
          faces=mesh.faces,
          attr_volume=mesh.attrs,
          coords=mesh.coords,
+         attr_layout=trellis_pipe.pbr_attr_layout,
          grid_size=res,
          aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+         decimation_target=decimation_target,
+         texture_size=texture_size,
          remesh=True,
          remesh_band=1,
          remesh_project=0,
@@ -297,151 +425,146 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int, req: gr.Re
      )

      now = datetime.now()
+     timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
+     os.makedirs(user_dir, exist_ok=True)
+     glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
      glb.export(glb_path, extension_webp=True)
      torch.cuda.empty_cache()
+     return glb_path, glb_path
+
+ # ==========================================
+ # 5. GRADIO APP INTERFACE
+ # ==========================================
+
+ with gr.Blocks(delete_cache=(600, 600), css=css, head=head) as demo:
+     gr.Markdown("""
+     # Z-Image-Turbo + TRELLIS.2: Text to 3D
+     Step 1: Generate an image from text.
+     Step 2: Convert that image into a 3D Asset.
+     """)

      with gr.Row():
+         # --- LEFT COLUMN: INPUTS ---
+         with gr.Column(scale=1, min_width=360):

+             # --- Z-Image Section ---
+             with gr.Group():
+                 gr.Markdown("### 1. Text to Image (Z-Image)")
+                 prompt = gr.Textbox(label="Prompt", placeholder="A stylized 3d render of a cute robot...", lines=2)
                  with gr.Row():
+                     img_width = gr.Number(label="Width", value=1024, precision=0)
+                     img_height = gr.Number(label="Height", value=1024, precision=0)
+                 img_steps = gr.Slider(1, 10, value=4, step=1, label="Steps")
+                 img_seed = gr.Number(value=42, label="Seed", precision=0)
+                 img_rand_seed = gr.Checkbox(label="Randomize Seed", value=True)
+
+             gen_img_btn = gr.Button("Generate Image", variant="primary")
+
+             # --- Intermediate Image ---
+             image_prompt = gr.Image(label="Generated Image (Input for 3D)", format="png", image_mode="RGBA", type="pil", height=400)
+
+             preprocess_btn = gr.Button("Remove Background (Preprocess)", variant="secondary")
+
+             # --- TRELLIS Section ---
+             with gr.Group():
+                 gr.Markdown("### 2. Image to 3D (TRELLIS)")
+                 resolution = gr.Radio(["512", "1024", "1536"], label="3D Resolution", value="1024")
+                 trellis_seed = gr.Slider(0, MAX_SEED, label="3D Seed", value=0, step=1)
+                 trellis_rand_seed = gr.Checkbox(label="Randomize 3D Seed", value=True)
+
+             gen_3d_btn = gr.Button("Generate 3D Model", variant="primary")

+             # Advanced Settings
+             with gr.Accordion(label="Advanced 3D Settings", open=False):
+                 decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
+                 texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
+
+                 gr.Markdown("Stage 1: Sparse Structure")
+                 with gr.Row():
+                     ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5, step=0.1)
+                     ss_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
+                 gr.Markdown("Stage 2: Shape")
+                 with gr.Row():
+                     shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=7.5, step=0.1)
+                     shape_slat_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
+                 gr.Markdown("Stage 3: Material")
+                 with gr.Row():
+                     tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance", value=1.0, step=0.1)
+                     tex_slat_sampling_steps = gr.Slider(1, 50, label="Steps", value=12, step=1)
+
+             # Hidden params kept for compatibility
+             ss_guidance_rescale = gr.Number(value=0.7, visible=False)
+             ss_rescale_t = gr.Number(value=5.0, visible=False)
+             shape_slat_guidance_rescale = gr.Number(value=0.5, visible=False)
+             shape_slat_rescale_t = gr.Number(value=3.0, visible=False)
+             tex_slat_guidance_rescale = gr.Number(value=0.0, visible=False)
+             tex_slat_rescale_t = gr.Number(value=3.0, visible=False)
+
+         # --- RIGHT COLUMN: OUTPUTS ---
+         with gr.Column(scale=10):
+             with gr.Walkthrough(selected=0) as walkthrough:
+                 with gr.Step("Preview", id=0):
+                     preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
+                     extract_btn = gr.Button("Extract GLB")
+                 with gr.Step("Extract", id=1):
+                     glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
+                     download_btn = gr.DownloadButton(label="Download GLB")
+
+     # State for the 3D generation latent
+     output_buf = gr.State()
+
+     # --- EVENT HANDLERS ---

+     demo.load(start_session)
+     demo.unload(end_session)
+
+     # 1. Generate Image
      gen_img_btn.click(
          fn=generate_z_image,
+         inputs=[prompt, img_height, img_width, img_steps, img_seed, img_rand_seed],
+         outputs=[image_prompt, img_seed] # Update image and show used seed
      )
+
+     # 2. Preprocess Image (Remove BG)
+     preprocess_btn.click(
+         fn=preprocess_image,
+         inputs=[image_prompt],
+         outputs=[image_prompt]
      )
+
+     # Auto-preprocess on upload as well (optional, from original code)
+     image_prompt.upload(
+         preprocess_image,
+         inputs=[image_prompt],
+         outputs=[image_prompt],
      )

+     # 3. Generate 3D
+     gen_3d_btn.click(
+         get_seed,
+         inputs=[trellis_rand_seed, trellis_seed],
+         outputs=[trellis_seed],
+     ).then(
+         lambda: gr.Walkthrough(selected=0), outputs=walkthrough
+     ).then(
+         generate_trellis_3d,
+         inputs=[
+             image_prompt, trellis_seed, resolution,
+             ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
+             shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
+             tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
+         ],
+         outputs=[output_buf, preview_output],
+     )

+     # 4. Extract GLB
+     extract_btn.click(
+         lambda: gr.Walkthrough(selected=1), outputs=walkthrough
+     ).then(
+         extract_glb,
+         inputs=[output_buf, decimation_target, texture_size],
+         outputs=[glb_output, download_btn],
+     )

  if __name__ == "__main__":
+     demo.launch()