prithivMLmods committed on
Commit
6cdc2f9
·
verified ·
1 Parent(s): 9166768

update app [.]

Browse files
Files changed (1) hide show
  1. app.py +311 -294
app.py CHANGED
@@ -1,44 +1,45 @@
1
  import os
2
- import io
3
- import cv2
4
- import time
5
  import shutil
6
- import base64
7
  import torch
8
  import numpy as np
9
- import tempfile
10
- import gradio as gr
11
- import spaces
12
- from pathlib import Path
13
- from typing import Tuple, List, Optional
14
  from PIL import Image
 
 
 
 
15
  from datetime import datetime
16
- from gradio_client import Client, handle_file
17
 
18
- # --- 1. Environment Configuration ---
19
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
20
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
21
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
22
- # Note: Ensure autotune_cache.json exists or remove this line if not needed
23
  os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
24
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
25
 
26
- # --- 2. Imports from specific libraries ---
27
- # Ensure Z-Image-Turbo dependencies are met
 
 
28
  from diffusers import ZImagePipeline
29
 
30
- # Ensure TRELLIS dependencies are met
31
  from trellis2.modules.sparse import SparseTensor
32
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
33
  from trellis2.renderers import EnvMap
34
  from trellis2.utils import render_utils
35
  import o_voxel
36
 
37
- # --- 3. Constants & Settings ---
 
 
 
38
  MAX_SEED = np.iinfo(np.int32).max
39
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
40
 
41
- # Rendering Modes configuration
42
  MODES = [
43
  {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
44
  {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
@@ -51,8 +52,12 @@ STEPS = 8
51
  DEFAULT_MODE = 3
52
  DEFAULT_STEP = 3
53
 
54
- # --- 4. CSS & JavaScript (From TRELLIS) ---
 
 
 
55
  css = """
 
56
  .stepper-wrapper { padding: 0; }
57
  .stepper-container { padding: 0; align-items: center; }
58
  .step-button { flex-direction: row; }
@@ -61,23 +66,38 @@ css = """
61
  .step-label { position: relative; bottom: 0; }
62
  .wrap.center.full { inset: 0; height: 100%; }
63
  .wrap.center.full.translucent { background: var(--block-background-fill); }
64
- .meta-text-center { display: block !important; position: absolute !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
65
- .previewer-container { position: relative; font-family: sans-serif; width: 100%; height: 722px; margin: 0 auto; padding: 20px; display: flex; flex-direction: column; align-items: center; justify-content: center; }
 
 
 
 
 
 
66
  .previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; }
67
  .previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; }
 
68
  .tips-icon:hover + .tips-text { display: block; opacity: 100%; }
 
 
69
  .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
70
  .previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
71
  .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
72
  .previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
 
 
73
  .previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
74
  .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
75
  .previewer-container .previewer-main-image.visible { display: block; }
 
 
76
  .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
77
  .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
78
  .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
79
  .previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
80
  .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
 
 
81
  .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
82
  .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
83
  """
@@ -100,6 +120,7 @@ head = """
100
  const targetId = 'view-m' + mode + '-s' + step;
101
  const targetImg = document.getElementById(targetId);
102
  if (targetImg) targetImg.classList.add('visible');
 
103
  const allBtns = document.querySelectorAll('.mode-btn');
104
  allBtns.forEach((btn, idx) => {
105
  if (idx === mode) btn.classList.add('active');
@@ -111,58 +132,60 @@ head = """
111
  </script>
112
  """
113
 
114
- empty_html = """
115
  <div class="previewer-container">
116
- <div style="text-align: center; color: var(--body-text-color); opacity: 0.5;">
117
- <p>Generated 3D Preview will appear here</p>
118
- </div>
119
  </div>
120
  """
121
 
122
- # --- 5. Global Model Initialization ---
123
- print("Initializing Models...")
 
 
 
124
 
125
- # A. Load Z-Image-Turbo
126
  print("Loading Z-Image-Turbo...")
127
- z_pipe = ZImagePipeline.from_pretrained(
128
- "Tongyi-MAI/Z-Image-Turbo",
129
- torch_dtype=torch.bfloat16,
130
- low_cpu_mem_usage=False,
131
- )
132
- if torch.cuda.is_available():
133
- z_pipe.to("cuda")
134
-
135
- # B. Load TRELLIS.2
136
- print("Loading TRELLIS.2-4B...")
 
 
 
 
 
 
137
  trellis_pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
138
- trellis_pipeline.rembg_model = None # We use external API/Client for robustness
139
  trellis_pipeline.low_vram = False
140
- if torch.cuda.is_available():
141
- trellis_pipeline.cuda()
142
 
143
- # C. Load RMBG Client (Background Removal)
144
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
145
 
146
- # D. Load Environment Maps for Rendering
147
- # Helper to safely load local assets
148
- def load_env_map(name, filename):
149
- path = os.path.join('assets/hdri', filename)
150
- if os.path.exists(path):
151
- return EnvMap(torch.tensor(
152
- cv2.cvtColor(cv2.imread(path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
153
- dtype=torch.float32, device='cuda'
154
- ))
155
- else:
156
- print(f"Warning: HDRI {filename} not found. Rendering might fail or be black.")
157
- return None # Should handle fallback
158
-
159
- envmap = {
160
- 'forest': load_env_map('forest', 'forest.exr'),
161
- 'sunset': load_env_map('sunset', 'sunset.exr'),
162
- 'courtyard': load_env_map('courtyard', 'courtyard.exr'),
163
- }
164
 
165
- # --- 6. Helper Functions ---
 
 
166
 
167
  def image_to_base64(image):
168
  buffered = io.BytesIO()
@@ -180,92 +203,89 @@ def end_session(req: gr.Request):
180
  if os.path.exists(user_dir):
181
  shutil.rmtree(user_dir)
182
 
183
- def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
184
- shape_slat, tex_slat, res = latents
185
- return {
186
- 'shape_slat_feats': shape_slat.feats.cpu().numpy(),
187
- 'tex_slat_feats': tex_slat.feats.cpu().numpy(),
188
- 'coords': shape_slat.coords.cpu().numpy(),
189
- 'res': res,
190
- }
191
-
192
- def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
193
- shape_slat = SparseTensor(
194
- feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
195
- coords=torch.from_numpy(state['coords']).cuda(),
196
- )
197
- tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
198
- return shape_slat, tex_slat, state['res']
199
-
200
- def get_seed(randomize_seed: bool, seed: int) -> int:
201
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
202
-
203
- def remove_background(input_img: Image.Image) -> Image.Image:
204
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
205
- input_img = input_img.convert('RGB')
206
- input_img.save(f.name)
207
- # Use gradio client to call remote or local rmbg
208
  output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
209
- output_img = Image.open(output)
210
- return output_img
211
 
212
- def preprocess_image(input_img: Image.Image) -> Image.Image:
213
- if input_img is None: return None
214
- # Check alpha
215
  has_alpha = False
216
- if input_img.mode == 'RGBA':
217
- alpha = np.array(input_img)[:, :, 3]
218
  if not np.all(alpha == 255):
219
  has_alpha = True
220
-
221
- # Resize logic
222
- max_size = max(input_img.size)
223
  scale = min(1, 1024 / max_size)
224
  if scale < 1:
225
- input_img = input_img.resize((int(input_img.width * scale), int(input_img.height * scale)), Image.Resampling.LANCZOS)
226
-
227
  if has_alpha:
228
- output = input_img
229
  else:
230
- output = remove_background(input_img)
231
-
232
- # Recenter and crop
233
  output_np = np.array(output)
234
  alpha = output_np[:, :, 3]
235
  bbox = np.argwhere(alpha > 0.8 * 255)
236
- if bbox.size == 0: return output # Fallback if empty
237
-
238
  bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
239
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
240
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
241
- size = int(size * 1) # margin factor
242
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
243
  output = output.crop(bbox)
244
-
245
- # Premultiply alpha
246
  output = np.array(output).astype(np.float32) / 255
247
  output = output[:, :, :3] * output[:, :, 3:4]
248
  output = Image.fromarray((output * 255).astype(np.uint8))
249
  return output
250
 
251
- # --- 7. Core GPU Generators ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  @spaces.GPU()
254
- def generate_text_to_image(prompt):
255
- """Generates image from text using Z-Image-Turbo."""
 
 
256
  if not prompt.strip():
257
  raise gr.Error("Please enter a prompt.")
258
 
259
  device = "cuda" if torch.cuda.is_available() else "cpu"
260
- generator = torch.Generator(device).manual_seed(42) # Fixed seed for T2I stability, or make random
261
 
262
- print(f"Generating image for: {prompt}")
263
  try:
264
  result = z_pipe(
265
  prompt=prompt,
 
266
  height=1024,
267
  width=1024,
268
- num_inference_steps=9, # Z-Image-Turbo specific
269
  guidance_scale=0.0,
270
  generator=generator,
271
  )
@@ -273,70 +293,65 @@ def generate_text_to_image(prompt):
273
  except Exception as e:
274
  raise gr.Error(f"Z-Image Generation failed: {str(e)}")
275
 
276
- @spaces.GPU(duration=180)
277
- def generate_3d_from_image(
278
  image: Image.Image,
279
  seed: int,
280
  resolution: str,
281
- ss_params: list,
282
- shape_params: list,
283
- tex_params: list,
284
- ):
285
- """
286
- Main pipeline: Preprocess Image -> Generate TRELLIS Latents -> Render Preview
287
- """
288
- if image is None:
289
- raise gr.Error("No image provided for 3D generation.")
290
-
291
- # 1. Preprocess (Remove BG if needed, center)
292
- processed_image = preprocess_image(image)
 
 
 
293
 
294
- # 2. Extract params
295
- (ss_steps, ss_guidance, ss_rescale, ss_t) = ss_params
296
- (sh_steps, sh_guidance, sh_rescale, sh_t) = shape_params
297
- (tex_steps, tex_guidance, tex_rescale, tex_t) = tex_params
298
 
299
- # 3. Run TRELLIS Pipeline
300
  outputs, latents = trellis_pipeline.run(
301
- processed_image,
302
  seed=seed,
303
- preprocess_image=False, # We did it manually
304
  sparse_structure_sampler_params={
305
- "steps": int(ss_steps),
306
- "guidance_strength": ss_guidance,
307
- "guidance_rescale": ss_rescale,
308
- "rescale_t": ss_t,
309
  },
310
  shape_slat_sampler_params={
311
- "steps": int(sh_steps),
312
- "guidance_strength": sh_guidance,
313
- "guidance_rescale": sh_rescale,
314
- "rescale_t": sh_t,
315
  },
316
  tex_slat_sampler_params={
317
- "steps": int(tex_steps),
318
- "guidance_strength": tex_guidance,
319
- "guidance_rescale": tex_rescale,
320
- "rescale_t": tex_t,
321
  },
322
- pipeline_type={
323
- "512": "512",
324
- "1024": "1024_cascade",
325
- "1536": "1536_cascade",
326
- }[resolution],
327
  return_latent=True,
328
  )
329
-
330
- # 4. Render Preview
331
  mesh = outputs[0]
332
- mesh.simplify(16777216)
333
 
334
- # Use environment map for rendering if available
335
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
336
  state = pack_state(latents)
337
  torch.cuda.empty_cache()
338
-
339
- # 5. Build HTML
340
  images_html = ""
341
  for m_idx, mode in enumerate(MODES):
342
  for s_idx in range(STEPS):
@@ -344,18 +359,18 @@ def generate_3d_from_image(
344
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
345
  vis_class = "visible" if is_visible else ""
346
  img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
347
- images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'
348
-
349
  btns_html = ""
350
  for idx, mode in enumerate(MODES):
351
  active_class = "active" if idx == DEFAULT_MODE else ""
352
- btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
353
-
354
  full_html = f"""
355
  <div class="previewer-container">
356
  <div class="tips-wrapper">
357
  <div class="tips-icon">💡Tips</div>
358
- <div class="tips-text"><p>● <b>Render Mode</b> - Switch render modes.</p><p>● <b>View Angle</b> - Drag slider.</p></div>
359
  </div>
360
  <div class="display-row">{images_html}</div>
361
  <div class="mode-row" id="btn-group">{btns_html}</div>
@@ -364,14 +379,21 @@ def generate_3d_from_image(
364
  </div>
365
  </div>
366
  """
367
-
368
- return state, full_html, processed_image
369
 
370
  @spaces.GPU(duration=120)
371
- def extract_glb(state: dict, decimation_target: int, texture_size: int, req: gr.Request):
 
 
 
 
 
 
 
 
 
372
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
373
  shape_slat, tex_slat, res = unpack_state(state)
374
-
375
  mesh = trellis_pipeline.decode_latent(shape_slat, tex_slat, res)[0]
376
  mesh.simplify(16777216)
377
 
@@ -390,150 +412,145 @@ def extract_glb(state: dict, decimation_target: int, texture_size: int, req: gr.
390
  remesh_project=0,
391
  use_tqdm=True,
392
  )
393
-
394
  now = datetime.now()
395
  timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
396
  os.makedirs(user_dir, exist_ok=True)
397
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
398
  glb.export(glb_path, extension_webp=True)
399
  torch.cuda.empty_cache()
400
-
401
  return glb_path, glb_path
402
 
403
- # --- 8. UI Construction ---
 
 
404
 
405
- # Load icons for UI (Assumes assets folder exists)
406
- for i in range(len(MODES)):
407
- icon_path = MODES[i]['icon']
408
- if os.path.exists(icon_path):
409
- MODES[i]['icon_base64'] = image_to_base64(Image.open(icon_path))
410
- else:
411
- # Fallback empty image if asset missing
412
- MODES[i]['icon_base64'] = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
413
-
414
- with gr.Blocks(css=css, title="TRELLIS.2-3D Turbo", delete_cache=(600, 600)) as demo:
415
 
416
- # Session Management
417
- demo.load(start_session)
418
- demo.unload(end_session)
419
-
420
- gr.Markdown("""
421
- # ⚡ Text to 3D Turbo
422
- **Z-Image-Turbo** (Text-to-Image) + **TRELLIS.2** (Image-to-3D)
423
- """)
424
-
425
- with gr.Row():
426
- # LEFT COLUMN: INPUTS
427
- with gr.Column(scale=1, min_width=350):
428
- text_prompt = gr.Textbox(label="Text Prompt", placeholder="A stylized cute dragon, high quality, 3d render...")
429
- generate_txt2img_btn = gr.Button("1. Generate Image", variant="primary")
430
-
431
- # Intermediate Image
432
- generated_image = gr.Image(label="Generated Reference Image", type="pil", height=300)
433
-
434
- gr.Markdown("---")
435
- generate_3d_btn = gr.Button("2. Generate 3D Model", variant="secondary")
436
-
437
- with gr.Accordion("3D Settings", open=False):
438
- resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
439
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
440
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
441
 
442
- # Grouped params for passing to function
443
- # SS
444
- ss_steps = gr.Slider(1, 50, value=12, label="Structure Steps")
445
- ss_str = gr.Slider(1.0, 10.0, value=7.5, label="Structure Guidance")
446
- ss_rescale = gr.Slider(0.0, 1.0, value=0.7, label="Structure Rescale")
447
- ss_t = gr.Slider(1.0, 6.0, value=5.0, label="Structure T")
448
 
449
- # Shape
450
- sh_steps = gr.Slider(1, 50, value=12, label="Shape Steps")
451
- sh_str = gr.Slider(1.0, 10.0, value=7.5, label="Shape Guidance")
452
- sh_rescale = gr.Slider(0.0, 1.0, value=0.5, label="Shape Rescale")
453
- sh_t = gr.Slider(1.0, 6.0, value=3.0, label="Shape T")
454
 
455
- # Texture
456
- tex_steps = gr.Slider(1, 50, value=12, label="Texture Steps")
457
- tex_str = gr.Slider(1.0, 10.0, value=1.0, label="Texture Guidance")
458
- tex_rescale = gr.Slider(0.0, 1.0, value=0.0, label="Texture Rescale")
459
- tex_t = gr.Slider(1.0, 6.0, value=3.0, label="Texture T")
460
-
461
- # RIGHT COLUMN: OUTPUTS
462
- with gr.Column(scale=2):
463
- with gr.Tabs():
464
- with gr.Tab("3D Preview"):
465
- preview_html = gr.HTML(empty_html, label="3D Preview")
466
- # Hidden state to store latent representation
467
- trellis_state = gr.State()
468
 
469
- with gr.Tab("Export GLB"):
470
- gr.Markdown("Extract GLB from the generated 3D data.")
471
- decimation = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
472
- tex_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
473
- extract_btn = gr.Button("Extract & Download GLB")
474
- glb_model = gr.Model3D(label="Extracted GLB", display_mode="solid", clear_color=(0.2, 0.2, 0.2, 1.0))
475
- dl_button = gr.DownloadButton("Download .glb")
476
-
477
- # --- Wiring ---
478
-
479
- # 1. Text to Image
480
- generate_txt2img_btn.click(
481
- generate_text_to_image,
482
- inputs=[text_prompt],
483
- outputs=[generated_image]
484
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
- # 2. Image to 3D
487
- generate_3d_btn.click(
488
- get_seed, inputs=[randomize_seed, seed], outputs=[seed]
489
- ).then(
490
- generate_3d_from_image,
491
- inputs=[
492
- generated_image, seed, resolution,
493
- # Pack params into lists to keep signature clean
494
- gr.State([12, 7.5, 0.7, 5.0]), # Default SS placeholder, normally would bind sliders directly
495
- gr.State([12, 7.5, 0.5, 3.0]), # Default Shape placeholder
496
- gr.State([12, 1.0, 0.0, 3.0]) # Default Tex placeholder
497
- ],
498
- # Actually bind the real sliders
499
- inputs_kwargs={
500
- "ss_params": [ss_steps, ss_str, ss_rescale, ss_t],
501
- "shape_params": [sh_steps, sh_str, sh_rescale, sh_t],
502
- "tex_params": [tex_steps, tex_str, tex_rescale, tex_t]
503
- },
504
- outputs=[trellis_state, preview_html, generated_image] # Update image with processed version
505
- )
506
 
507
- # Fix: Gradio .then chaining with lists of inputs needs careful handling.
508
- # Let's redefine the click to map inputs directly without the inputs_kwargs workaround for clarity.
509
- generate_3d_btn.click(
510
- get_seed, inputs=[randomize_seed, seed], outputs=[seed]
511
- ).then(
512
- lambda *args: generate_3d_from_image(
513
- args[0], args[1], args[2],
514
- [args[3], args[4], args[5], args[6]], # SS params
515
- [args[7], args[8], args[9], args[10]], # Shape params
516
- [args[11], args[12], args[13], args[14]] # Tex params
517
- ),
518
- inputs=[
519
- generated_image, seed, resolution,
520
- ss_steps, ss_str, ss_rescale, ss_t,
521
- sh_steps, sh_str, sh_rescale, sh_t,
522
- tex_steps, tex_str, tex_rescale, tex_t
523
- ],
524
- outputs=[trellis_state, preview_html, generated_image]
525
- )
526
 
527
- # 3. Extract GLB
528
- extract_btn.click(
529
- extract_glb,
530
- inputs=[trellis_state, decimation, tex_size],
531
- outputs=[glb_model, dl_button]
532
- )
 
 
 
 
 
 
 
 
 
 
 
533
 
534
- if __name__ == "__main__":
535
- # Create temp directory
536
- os.makedirs(TMP_DIR, exist_ok=True)
537
-
538
- # Launch with custom scripts
539
- demo.launch(head=head)
 
 
 
 
 
1
  import os
 
 
 
2
  import shutil
3
+ import cv2
4
  import torch
5
  import numpy as np
 
 
 
 
 
6
  from PIL import Image
7
+ import base64
8
+ import io
9
+ import tempfile
10
+ from typing import *
11
  from datetime import datetime
12
+ from pathlib import Path
13
 
14
+ # --- Environment Setup (Must be before other imports) ---
15
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
16
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
17
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
18
+ # Set autotune cache relative to this file
19
  os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
20
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
21
 
22
+ # --- Third Party Imports ---
23
+ import gradio as gr
24
+ from gradio_client import Client, handle_file
25
+ import spaces
26
  from diffusers import ZImagePipeline
27
 
28
+ # --- TRELLIS Specific Imports ---
29
  from trellis2.modules.sparse import SparseTensor
30
  from trellis2.pipelines import Trellis2ImageTo3DPipeline
31
  from trellis2.renderers import EnvMap
32
  from trellis2.utils import render_utils
33
  import o_voxel
34
 
35
+ # ==========================================
36
+ # Global Configuration & Assets
37
+ # ==========================================
38
+
39
  MAX_SEED = np.iinfo(np.int32).max
40
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
41
 
42
+ # TRELLIS Render Modes
43
  MODES = [
44
  {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
45
  {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
 
52
  DEFAULT_MODE = 3
53
  DEFAULT_STEP = 3
54
 
55
+ # ==========================================
56
+ # CSS & JavaScript (For Custom Previewer)
57
+ # ==========================================
58
+
59
  css = """
60
+ /* Overwrite Gradio Default Style */
61
  .stepper-wrapper { padding: 0; }
62
  .stepper-container { padding: 0; align-items: center; }
63
  .step-button { flex-direction: row; }
 
66
  .step-label { position: relative; bottom: 0; }
67
  .wrap.center.full { inset: 0; height: 100%; }
68
  .wrap.center.full.translucent { background: var(--block-background-fill); }
69
+ .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
70
+
71
+ /* Previewer */
72
+ .previewer-container {
73
+ position: relative; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
74
+ width: 100%; height: 722px; margin: 0 auto; padding: 20px;
75
+ display: flex; flex-direction: column; align-items: center; justify-content: center;
76
+ }
77
  .previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; }
78
  .previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; }
79
+ .previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
80
  .tips-icon:hover + .tips-text { display: block; opacity: 100%; }
81
+
82
+ /* Row 1: Display Modes */
83
  .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
84
  .previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
85
  .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
86
  .previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
87
+
88
+ /* Row 2: Display Image */
89
  .previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
90
  .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
91
  .previewer-container .previewer-main-image.visible { display: block; }
92
+
93
+ /* Row 3: Custom HTML Slider */
94
  .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
95
  .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
96
  .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
97
  .previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
98
  .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
99
+
100
+ /* Overwrite Previewer Block Style */
101
  .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
102
  .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
103
  """
 
120
  const targetId = 'view-m' + mode + '-s' + step;
121
  const targetImg = document.getElementById(targetId);
122
  if (targetImg) targetImg.classList.add('visible');
123
+
124
  const allBtns = document.querySelectorAll('.mode-btn');
125
  allBtns.forEach((btn, idx) => {
126
  if (idx === mode) btn.classList.add('active');
 
132
  </script>
133
  """
134
 
135
+ empty_html = f"""
136
  <div class="previewer-container">
137
+ <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
138
+ xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
 
139
  </div>
140
  """
141
 
142
+ # ==========================================
143
+ # Model Loading
144
+ # ==========================================
145
+
146
+ print("Initializing models...")
147
 
148
+ # 1. Z-Image-Turbo (Text to Image)
149
  print("Loading Z-Image-Turbo...")
150
+ try:
151
+ z_pipe = ZImagePipeline.from_pretrained(
152
+ "Tongyi-MAI/Z-Image-Turbo",
153
+ torch_dtype=torch.bfloat16,
154
+ low_cpu_mem_usage=False,
155
+ )
156
+ device = "cuda" if torch.cuda.is_available() else "cpu"
157
+ z_pipe.to(device)
158
+ print("Z-Image-Turbo loaded.")
159
+ except Exception as e:
160
+ print(f"Failed to load Z-Image-Turbo: {e}")
161
+ z_pipe = None
162
+
163
+ # 2. TRELLIS.2 (Image to 3D)
164
+ print("Loading TRELLIS.2...")
165
+ # Initialize on startup
166
  trellis_pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
167
+ trellis_pipeline.rembg_model = None
168
  trellis_pipeline.low_vram = False
169
+ trellis_pipeline.cuda()
 
170
 
171
+ # 3. Background Remover
172
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
173
 
174
+ # 4. HDRI Maps for TRELLIS
175
+ envmap = {}
176
+ # Try to load assets, handle gracefully if running in a basic environment
177
+ try:
178
+ envmap = {
179
+ 'forest': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
180
+ 'sunset': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
181
+ 'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
182
+ }
183
+ except Exception as e:
184
+ print(f"Warning: Could not load HDRI assets. {e}")
 
 
 
 
 
 
 
185
 
186
+ # ==========================================
187
+ # Helper Functions
188
+ # ==========================================
189
 
190
  def image_to_base64(image):
191
  buffered = io.BytesIO()
 
203
  if os.path.exists(user_dir):
204
  shutil.rmtree(user_dir)
205
 
206
+ def remove_background(input: Image.Image) -> Image.Image:
207
+ with tempfile.NamedTemporaryFile(suffix='.png') as f:
208
+ input = input.convert('RGB')
209
+ input.save(f.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
211
+ output = Image.open(output)
212
+ return output
213
 
214
+ def preprocess_image(input: Image.Image) -> Image.Image:
215
+ """Preprocess the input image: remove bg, crop, resize."""
 
216
  has_alpha = False
217
+ if input.mode == 'RGBA':
218
+ alpha = np.array(input)[:, :, 3]
219
  if not np.all(alpha == 255):
220
  has_alpha = True
221
+ max_size = max(input.size)
 
 
222
  scale = min(1, 1024 / max_size)
223
  if scale < 1:
224
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
 
225
  if has_alpha:
226
+ output = input
227
  else:
228
+ output = remove_background(input)
229
+
 
230
  output_np = np.array(output)
231
  alpha = output_np[:, :, 3]
232
  bbox = np.argwhere(alpha > 0.8 * 255)
233
+ if bbox.size == 0:
234
+ return output # Return original if empty
235
  bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
236
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
237
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
238
+ size = int(size * 1)
239
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
240
  output = output.crop(bbox)
 
 
241
  output = np.array(output).astype(np.float32) / 255
242
  output = output[:, :, :3] * output[:, :, 3:4]
243
  output = Image.fromarray((output * 255).astype(np.uint8))
244
  return output
245
 
246
+ def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
247
+ shape_slat, tex_slat, res = latents
248
+ return {
249
+ 'shape_slat_feats': shape_slat.feats.cpu().numpy(),
250
+ 'tex_slat_feats': tex_slat.feats.cpu().numpy(),
251
+ 'coords': shape_slat.coords.cpu().numpy(),
252
+ 'res': res,
253
+ }
254
+
255
def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
    """Rebuild the latents on the GPU from a dict produced by ``pack_state``.

    Args:
        state: Dict with ``shape_slat_feats``, ``tex_slat_feats``, ``coords``
            (NumPy arrays) and ``res``.

    Returns:
        ``(shape_slat, tex_slat, res)``; the texture latent is derived from
        the shape latent via ``replace`` with the texture features.
    """
    def on_gpu(key):
        return torch.from_numpy(state[key]).cuda()

    shape_feats = on_gpu('shape_slat_feats')
    coords = on_gpu('coords')
    shape_slat = SparseTensor(feats=shape_feats, coords=coords)
    tex_slat = shape_slat.replace(on_gpu('tex_slat_feats'))
    return shape_slat, tex_slat, state['res']
262
+
263
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return ``seed`` unchanged, or a random one in ``[0, MAX_SEED)`` if requested."""
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
265
+
266
+ # ==========================================
267
+ # Inference Logic
268
+ # ==========================================
269
 
270
  @spaces.GPU()
271
+ def generate_txt2img(prompt, progress=gr.Progress(track_tqdm=True)):
272
+ """Generate Image using Z-Image Turbo"""
273
+ if z_pipe is None:
274
+ raise gr.Error("Z-Image-Turbo model failed to load.")
275
  if not prompt.strip():
276
  raise gr.Error("Please enter a prompt.")
277
 
278
  device = "cuda" if torch.cuda.is_available() else "cpu"
279
+ generator = torch.Generator(device).manual_seed(42) # Or random
280
 
281
+ progress(0.1, desc="Generating Text-to-Image...")
282
  try:
283
  result = z_pipe(
284
  prompt=prompt,
285
+ negative_prompt=None,
286
  height=1024,
287
  width=1024,
288
+ num_inference_steps=9,
289
  guidance_scale=0.0,
290
  generator=generator,
291
  )
 
293
  except Exception as e:
294
  raise gr.Error(f"Z-Image Generation failed: {str(e)}")
295
 
296
+ @spaces.GPU(duration=120)
297
+ def image_to_3d(
298
  image: Image.Image,
299
  seed: int,
300
  resolution: str,
301
+ ss_guidance_strength: float,
302
+ ss_guidance_rescale: float,
303
+ ss_sampling_steps: int,
304
+ ss_rescale_t: float,
305
+ shape_slat_guidance_strength: float,
306
+ shape_slat_guidance_rescale: float,
307
+ shape_slat_sampling_steps: int,
308
+ shape_slat_rescale_t: float,
309
+ tex_slat_guidance_strength: float,
310
+ tex_slat_guidance_rescale: float,
311
+ tex_slat_sampling_steps: int,
312
+ tex_slat_rescale_t: float,
313
+ req: gr.Request,
314
+ progress=gr.Progress(track_tqdm=True),
315
+ ) -> str:
316
 
317
+ if image is None:
318
+ raise gr.Error("Input image is missing.")
 
 
319
 
320
+ # --- Sampling ---
321
  outputs, latents = trellis_pipeline.run(
322
+ image,
323
  seed=seed,
324
+ preprocess_image=False, # We pre-process in the upload handler or assume clean input
325
  sparse_structure_sampler_params={
326
+ "steps": ss_sampling_steps,
327
+ "guidance_strength": ss_guidance_strength,
328
+ "guidance_rescale": ss_guidance_rescale,
329
+ "rescale_t": ss_rescale_t,
330
  },
331
  shape_slat_sampler_params={
332
+ "steps": shape_slat_sampling_steps,
333
+ "guidance_strength": shape_slat_guidance_strength,
334
+ "guidance_rescale": shape_slat_guidance_rescale,
335
+ "rescale_t": shape_slat_rescale_t,
336
  },
337
  tex_slat_sampler_params={
338
+ "steps": tex_slat_sampling_steps,
339
+ "guidance_strength": tex_slat_guidance_strength,
340
+ "guidance_rescale": tex_slat_guidance_rescale,
341
+ "rescale_t": tex_slat_rescale_t,
342
  },
343
+ pipeline_type={"512": "512", "1024": "1024_cascade", "1536": "1536_cascade"}[resolution],
 
 
 
 
344
  return_latent=True,
345
  )
 
 
346
  mesh = outputs[0]
347
+ mesh.simplify(16777216)
348
 
349
+ # Render Preview
350
  images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
351
  state = pack_state(latents)
352
  torch.cuda.empty_cache()
353
+
354
+ # --- HTML Construction ---
355
  images_html = ""
356
  for m_idx, mode in enumerate(MODES):
357
  for s_idx in range(STEPS):
 
359
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
360
  vis_class = "visible" if is_visible else ""
361
  img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
362
+ images_html += f"""<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">"""
363
+
364
  btns_html = ""
365
  for idx, mode in enumerate(MODES):
366
  active_class = "active" if idx == DEFAULT_MODE else ""
367
+ btns_html += f"""<img src="{mode['icon_base64']}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode['name']}">"""
368
+
369
  full_html = f"""
370
  <div class="previewer-container">
371
  <div class="tips-wrapper">
372
  <div class="tips-icon">💡Tips</div>
373
+ <div class="tips-text"><p>● <b>Render Mode</b> - Click buttons to switch modes.</p><p>● <b>View Angle</b> - Drag slider to rotate.</p></div>
374
  </div>
375
  <div class="display-row">{images_html}</div>
376
  <div class="mode-row" id="btn-group">{btns_html}</div>
 
379
  </div>
380
  </div>
381
  """
382
+ return state, full_html
 
383
 
384
  @spaces.GPU(duration=120)
385
+ def extract_glb(
386
+ state: dict,
387
+ decimation_target: int,
388
+ texture_size: int,
389
+ req: gr.Request,
390
+ progress=gr.Progress(track_tqdm=True),
391
+ ) -> Tuple[str, str]:
392
+ if state is None:
393
+ raise gr.Error("No 3D model generated yet.")
394
+
395
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
396
  shape_slat, tex_slat, res = unpack_state(state)
 
397
  mesh = trellis_pipeline.decode_latent(shape_slat, tex_slat, res)[0]
398
  mesh.simplify(16777216)
399
 
 
412
  remesh_project=0,
413
  use_tqdm=True,
414
  )
 
415
  now = datetime.now()
416
  timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
417
  os.makedirs(user_dir, exist_ok=True)
418
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
419
  glb.export(glb_path, extension_webp=True)
420
  torch.cuda.empty_cache()
 
421
  return glb_path, glb_path
422
 
423
+ # ==========================================
424
+ # Gradio UI Blocks
425
+ # ==========================================
426
 
427
if __name__ == "__main__":
    # Builds and launches the Gradio UI: text prompt -> image -> 3D preview -> GLB.
    # NOTE(review): container nesting below was reconstructed from a flattened
    # diff view; confirm the layout against the deployed app.
    os.makedirs(TMP_DIR, exist_ok=True)

    # Pre-process icon base64
    for mode in MODES:
        if os.path.exists(mode['icon']):
            # Context manager closes the file handle once the icon is encoded.
            with Image.open(mode['icon']) as icon:
                mode['icon_base64'] = image_to_base64(icon)
        else:
            mode['icon_base64'] = ""  # Fallback

    with gr.Blocks(css=css, head=head, delete_cache=(600, 600)) as demo:
        gr.Markdown("""
        # TRELLIS.2-3D
        **Unified Text-to-3D Workflow**

        1. **Text to Image**: Generate a base image using Z-Image-Turbo.
        2. **Image to 3D**: Convert that image into a high-quality 3D asset using TRELLIS.2.
        """)

        with gr.Row():
            # --- Column 1: Inputs & Config ---
            with gr.Column(scale=1, min_width=360):

                # --- Step 1: Text to Image ---
                with gr.Group():
                    gr.Markdown("### Step 1: Generate Image")
                    txt_prompt = gr.Textbox(label="Prompt", placeholder="A sci-fi helmet, high quality, white background")
                    btn_gen_img = gr.Button("Generate Image", variant="secondary")

                # --- Step 2: Image to 3D Input ---
                gr.Markdown("### Step 2: Configure & Convert")
                image_prompt = gr.Image(label="Input Image (Generated or Uploaded)", format="png", image_mode="RGBA", type="pil", height=300)

                with gr.Accordion("3D Generation Settings", open=True):
                    resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
                    seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    decimation_target = gr.Slider(100000, 500000, label="Decimation Target (For GLB)", value=300000, step=10000)
                    texture_size = gr.Slider(1024, 4096, label="Texture Size (For GLB)", value=2048, step=1024)

                btn_gen_3d = gr.Button("Generate 3D", variant="primary")

                with gr.Accordion(label="Advanced Sampling Settings", open=False):
                    # The "Steps" sliders set step=1 explicitly: without it
                    # Gradio derives a fractional step from the range, and the
                    # samplers would receive non-integer step counts.
                    gr.Markdown("**Stage 1: Sparse Structure**")
                    ss_guidance_strength = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, value=0.7, label="Rescale")
                    ss_sampling_steps = gr.Slider(1, 50, value=12, step=1, label="Steps")
                    ss_rescale_t = gr.Slider(1.0, 6.0, value=5.0, label="Rescale T")

                    gr.Markdown("**Stage 2: Shape**")
                    shape_guidance = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
                    shape_rescale = gr.Slider(0.0, 1.0, value=0.5, label="Rescale")
                    shape_steps = gr.Slider(1, 50, value=12, step=1, label="Steps")
                    shape_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")

                    gr.Markdown("**Stage 3: Material**")
                    tex_guidance = gr.Slider(1.0, 10.0, value=1.0, label="Guidance")
                    tex_rescale = gr.Slider(0.0, 1.0, value=0.0, label="Rescale")
                    tex_steps = gr.Slider(1, 50, value=12, step=1, label="Steps")
                    tex_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")

            # --- Column 2: Outputs ---
            with gr.Column(scale=10):
                with gr.Walkthrough(selected=0) as walkthrough:
                    with gr.Step("Preview", id=0):
                        preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
                        extract_btn = gr.Button("Extract GLB")
                    with gr.Step("Extract", id=1):
                        glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
                        download_btn = gr.DownloadButton(label="Download GLB")

                gr.Markdown("Note: GLB extraction might take ~30s.")

        # ==========================================
        # Wiring Events
        # ==========================================

        # State to hold the latent 3D representation
        output_buf = gr.State()

        demo.load(start_session)
        demo.unload(end_session)

        # 1. Text to Image Event
        # `.success` (not `.then`) so a failed generation does not re-run
        # preprocessing on whatever image, if any, is already in the component.
        btn_gen_img.click(
            generate_txt2img,
            inputs=[txt_prompt],
            outputs=[image_prompt],
        ).success(
            preprocess_image,  # Auto preprocess the generated image (rmbg)
            inputs=[image_prompt],
            outputs=[image_prompt],
        )

        # 2. Upload Image Event (Preprocess)
        image_prompt.upload(
            preprocess_image,
            inputs=[image_prompt],
            outputs=[image_prompt],
        )

        # 3. Image to 3D Event
        btn_gen_3d.click(
            get_seed,
            inputs=[randomize_seed, seed],
            outputs=[seed],
        ).then(
            lambda: gr.Walkthrough(selected=0), outputs=walkthrough
        ).then(
            image_to_3d,
            inputs=[
                image_prompt, seed, resolution,
                ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
                shape_guidance, shape_rescale, shape_steps, shape_rescale_t,
                tex_guidance, tex_rescale, tex_steps, tex_rescale_t,
            ],
            outputs=[output_buf, preview_output],
        )

        # 4. Extraction Event
        extract_btn.click(
            lambda: gr.Walkthrough(selected=1), outputs=walkthrough
        ).then(
            extract_glb,
            inputs=[output_buf, decimation_target, texture_size],
            outputs=[glb_output, download_btn],
        )

    demo.launch()