import os
import sys
import shutil
import tempfile
import argparse
import subprocess
import json
import torch
import numpy as np
import spaces
import gradio as gr
from huggingface_hub import snapshot_download
# Clone the repository if not already present
REPO_URL = "https://github.com/Tencent-Hunyuan/HY-WorldPlay.git"
REPO_DIR = "HY-WorldPlay"

if not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

sys.path.append(os.path.abspath(REPO_DIR))
# Import the pipeline and helpers from the cloned repo
try:
    from hyvideo.pipelines.worldplay_video_pipeline import HunyuanVideo_1_5_Pipeline
    from hyvideo.commons.parallel_states import initialize_parallel_state
    from hyvideo.commons.infer_state import initialize_infer_state
except ImportError as e:
    print(f"Error importing hyvideo: {e}")
    print("Dependencies might be missing. Ensure requirements.txt is correct.")
# Mapping from a 4-way movement one-hot vector to a single pose-action index
mapping = {
    (0, 0, 0, 0): 0,
    (1, 0, 0, 0): 1,  # forward
    (0, 1, 0, 0): 2,  # backward
    (0, 0, 1, 0): 3,  # right
    (0, 0, 0, 1): 4,  # left
    (1, 0, 1, 0): 5,
    (1, 0, 0, 1): 6,
    (0, 1, 1, 0): 7,
    (0, 1, 0, 1): 8,
}
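# The same table is applied to both the translation and the rotation one-hots below;
# the two resulting indices (each 0-8) are later combined as trans * 9 + rot,
# giving a single action id in 0..80.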
# --- Utility functions adapted from generate.py ---
def one_hot_to_one_dimension(one_hot):
    y = torch.tensor([mapping[tuple(row.tolist())] for row in one_hot])
    return y
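# A quick sanity check of the mapping above:
#   one_hot_to_one_dimension(torch.tensor([[0, 0, 0, 0], [1, 0, 0, 0]])) -> tensor([0, 1])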
def pose_to_input(pose_json_path, latent_chunk_num, tps=False):
    # Adapted to handle the camera-path JSON structure used in the repo
    from scipy.spatial.transform import Rotation as R

    with open(pose_json_path, 'r') as f:
        pose_json = json.load(f)
    pose_keys = list(pose_json.keys())
    intrinsic_list = []
    w2c_list = []

    # Sort keys numerically when possible so frames stay in chronological order
    # (a plain string sort would put "10" before "2")
    try:
        pose_keys.sort(key=int)
    except ValueError:
        pose_keys.sort()

    # Don't go out of bounds if the JSON has fewer frames than requested chunks
    iterations = min(latent_chunk_num, len(pose_keys))
    for i in range(iterations):
        t_key = pose_keys[i]
        c2w = np.array(pose_json[t_key]["extrinsic"], dtype=np.float64)
        w2c = np.linalg.inv(c2w)
        w2c_list.append(w2c)

        # Normalize intrinsics: focal lengths by image size, principal point to 0.5
        # (float dtype avoids in-place division errors when K is given as integers)
        intrinsic = np.array(pose_json[t_key]["K"], dtype=np.float64)
        intrinsic[0, 0] /= intrinsic[0, 2] * 2
        intrinsic[1, 1] /= intrinsic[1, 2] * 2
        intrinsic[0, 2] = 0.5
        intrinsic[1, 2] = 0.5
        intrinsic_list.append(intrinsic)
    # Pad by repeating the last frame if we have fewer frames than requested chunks
    if len(w2c_list) < latent_chunk_num:
        last_w2c = w2c_list[-1]
        last_intrinsic = intrinsic_list[-1]
        for _ in range(latent_chunk_num - len(w2c_list)):
            w2c_list.append(last_w2c)
            intrinsic_list.append(last_intrinsic)

    w2c_list = np.array(w2c_list)
    intrinsic_list = torch.tensor(np.array(intrinsic_list))
    c2ws = np.linalg.inv(w2c_list)
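    # Express each camera pose relative to the previous frame so that per-step
    # translation and rotation can be classified into discrete actions.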
    C_inv = np.linalg.inv(c2ws[:-1])
    relative_c2w = np.zeros_like(c2ws)
    relative_c2w[0, ...] = c2ws[0, ...]
    relative_c2w[1:, ...] = C_inv @ c2ws[1:, ...]

    trans_one_hot = np.zeros((relative_c2w.shape[0], 4), dtype=np.int32)
    rotate_one_hot = np.zeros((relative_c2w.shape[0], 4), dtype=np.int32)
    move_norm_valid = 0.0001
    for i in range(1, relative_c2w.shape[0]):
        # Angle of the per-step translation direction against each axis
        move_dirs = relative_c2w[i, :3, 3]
        move_norms = np.linalg.norm(move_dirs)
        if move_norms > move_norm_valid:
            move_norm_dirs = move_dirs / move_norms
            angles_rad = np.arccos(move_norm_dirs.clip(-1.0, 1.0))
            trans_angles_deg = angles_rad * (180.0 / np.pi)
        else:
            trans_angles_deg = np.zeros(3)

        # Per-step rotation as Euler angles
        R_rel = relative_c2w[i, :3, :3]
        r = R.from_matrix(R_rel)
        rot_angles_deg = r.as_euler('xyz', degrees=True)

        # Classify translation into forward/backward and right/left
        if move_norms > move_norm_valid:
            if (not tps) or (abs(rot_angles_deg[1]) < 5e-2 and abs(rot_angles_deg[0]) < 5e-2):
                if trans_angles_deg[2] < 60:
                    trans_one_hot[i, 0] = 1
                elif trans_angles_deg[2] > 120:
                    trans_one_hot[i, 1] = 1
                if trans_angles_deg[0] < 60:
                    trans_one_hot[i, 2] = 1
                elif trans_angles_deg[0] > 120:
                    trans_one_hot[i, 3] = 1

        # Classify rotation by the sign of the yaw and pitch angles
        if rot_angles_deg[1] > 5e-2:
            rotate_one_hot[i, 0] = 1
        elif rot_angles_deg[1] < -5e-2:
            rotate_one_hot[i, 1] = 1
        if rot_angles_deg[0] > 5e-2:
            rotate_one_hot[i, 2] = 1
        elif rot_angles_deg[0] < -5e-2:
            rotate_one_hot[i, 3] = 1
    trans_one_hot = torch.tensor(trans_one_hot)
    rotate_one_hot = torch.tensor(rotate_one_hot)
    trans_one_label = one_hot_to_one_dimension(trans_one_hot)
    rotate_one_label = one_hot_to_one_dimension(rotate_one_hot)
    # Combine translation and rotation labels into one action id
    action_one_label = trans_one_label * 9 + rotate_one_label

    return torch.tensor(w2c_list), intrinsic_list, action_one_label
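# Example usage (hypothetical path):
#   viewmats, Ks, action = pose_to_input("camera_path.json", latent_chunk_num=17)
#   viewmats: [17, 4, 4] world-to-camera matrices, Ks: [17, 3, 3] normalized intrinsics,
#   action: [17] combined action ids in 0..80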
# --- Model loading and inference ---
MODEL_PATH = "tencent/HunyuanVideo-1.5"
ACTION_CKPT = "tencent/HY-WorldPlay"

# Global pipeline variable, created lazily on the first request
pipe = None
def load_model():
    global pipe
    if pipe is None:
        print("Loading model...")

        # Download weights explicitly so the custom pipeline gets local paths
        model_dir = snapshot_download(repo_id=MODEL_PATH, allow_patterns=["*.safetensors", "*.json", "*.txt"])
        action_dir = snapshot_download(repo_id=ACTION_CKPT)  # action checkpoints

        # The HY-WorldPlay repo ships several action checkpoints
        # (e.g. "ar_model", "ar_distilled_action_model", "bidirectional_model");
        # use the bidirectional model here.
        action_subpath = os.path.join(action_dir, "bidirectional_model")

        transformer_dtype = torch.bfloat16

        # Initialize parallel state (world_size=1 on a single GPU)
        if not torch.distributed.is_initialized():
            initialize_parallel_state(sp=1)

        pipe = HunyuanVideo_1_5_Pipeline.create_pipeline(
            pretrained_model_name_or_path=model_dir,
            transformer_version="480p_i2v",  # matches the args used in the repo's examples
            enable_offloading=True,
            enable_group_offloading=True,
            create_sr_pipeline=True,  # enable super-resolution by default
            force_sparse_attn=False,
            transformer_dtype=transformer_dtype,
            action_ckpt=action_subpath,
        )
        print("Model loaded successfully!")
    return pipe
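# `spaces` is imported above but otherwise unused; on a ZeroGPU Space the entry
# point needs the GPU decorator or no accelerator is allocated to the request.
# The duration below is only an estimate for a full generation.
@spaces.GPU(duration=120)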
def generate(prompt, image_input, pose_json, seed, num_inference_steps, video_length):
    pipeline = load_model()

    # Gradio Number/Slider values may arrive as floats; the pipeline expects ints
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    video_length = int(video_length)

    # Handle the pose JSON (gr.File may return a filepath string or a tempfile object)
    pose_path = None
    if pose_json is not None:
        pose_path = pose_json if isinstance(pose_json, str) else pose_json.name
    else:
        # No camera path provided: build a minimal static trajectory as a safe default
        default_pose_content = {
            "0": {
                "extrinsic": [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
                "K": [[500.0, 0.0, 400.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]]
            }
        }
        # Repeat the identity pose for a few frames; pose_to_input pads the rest
        for i in range(1, 16):
            default_pose_content[str(i)] = default_pose_content["0"]
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as tmp_json:
            json.dump(default_pose_content, tmp_json)
        pose_path = tmp_json.name
    # Prepare pose/action inputs: the video VAE groups 4 frames per latent chunk
    latent_chunk_num = (video_length - 1) // 4 + 1
    viewmats, Ks, action = pose_to_input(pose_path, latent_chunk_num)

    # Image-to-video when a reference image is given, otherwise text-to-video
    extra_kwargs = {}
    if image_input is not None:
        extra_kwargs['reference_image'] = image_input
    # Run inference
    out = pipeline(
        enable_sr=True,
        prompt=prompt,
        aspect_ratio="16:9",
        num_inference_steps=num_inference_steps,
        sr_num_inference_steps=None,
        video_length=video_length,
        negative_prompt="",
        seed=seed,
        output_type="pt",
        prompt_rewrite=False,
        return_pre_sr_video=False,
        viewmats=viewmats.unsqueeze(0),
        Ks=Ks.unsqueeze(0),
        action=action.unsqueeze(0),
        few_step=False,
        chunk_latent_frames=16,
        model_type="bi",
        user_height=480,
        user_width=832,
        **extra_kwargs
    )
    # Save the video to disk
    output_path = "output.mp4"
    import imageio
    import einops

    def save_vid(video, path):
        if video.ndim == 5:
            video = video[0]
        vid = (video * 255).clamp(0, 255).to(torch.uint8)
        vid = einops.rearrange(vid, 'c f h w -> f h w c')
        imageio.mimwrite(path, vid.cpu().numpy(), fps=24)

    # Prefer the super-resolved output when the SR pipeline produced one
    sr_videos = getattr(out, 'sr_videos', None)
    if sr_videos is not None:
        save_vid(sr_videos, output_path)
    else:
        save_vid(out.videos, output_path)

    return output_path
# --- Gradio UI ---
default_pose_json_content = """
{
    "0": {
        "extrinsic": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
        "K": [[500, 0, 400], [0, 500, 240], [0, 0, 1]]
    }
}
"""  # Minimal single-frame example; a real trajectory should contain one entry per frame
with gr.Blocks() as app:
    gr.Markdown("# HY-WorldPlay (HunyuanVideo-1.5) Demo")
    gr.Markdown("Generate streaming videos with camera control using WorldPlay.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A cinematic shot of a forest.")
            image = gr.Image(label="Input Image", type="filepath")
            pose_file = gr.File(label="Camera Path JSON", file_types=[".json"])
            seed = gr.Number(label="Seed", value=123)
            steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=50, step=1)
            length = gr.Slider(label="Video Length (frames)", minimum=17, maximum=129, value=65, step=16)  # 4n + 1 frames; default 65 = 16*4 + 1
            submit = gr.Button("Generate")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    submit.click(
        fn=generate,
        inputs=[prompt, image, pose_file, seed, steps, length],
        outputs=[output_video]
    )
if __name__ == "__main__":
    # Queue requests so long-running generations are processed sequentially
    app.queue().launch()