# HY-WorldPlay / app.py
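"""Gradio demo for HY-WorldPlay: camera-controllable video generation built on
HunyuanVideo-1.5. The app clones the upstream repository, downloads the base and
action checkpoints from the Hugging Face Hub, converts a camera-path JSON into
view matrices, intrinsics, and discrete action labels, and runs the WorldPlay
pipeline to render an MP4.
"""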
import os
import sys
import shutil
import tempfile
import argparse
import subprocess
import json
import torch
import numpy as np
import spaces
import gradio as gr
from huggingface_hub import snapshot_download
# Clone the repository if not already present
REPO_URL = "https://github.com/Tencent-Hunyuan/HY-WorldPlay.git"
REPO_DIR = "HY-WorldPlay"
if not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
sys.path.append(os.path.abspath(REPO_DIR))
# Now importing specific modules from the cloned repo
try:
    from hyvideo.pipelines.worldplay_video_pipeline import HunyuanVideo_1_5_Pipeline
    from hyvideo.commons.parallel_states import initialize_parallel_state
    from hyvideo.commons.infer_state import initialize_infer_state
except ImportError as e:
    print(f"Error importing hyvideo: {e}")
    print("Dependencies might be missing. Ensure requirements.txt is correct.")
# Maps a 4-bit movement one-hot (forward, backward, right, left) to a single class id.
mapping = {
    (0, 0, 0, 0): 0,  # no movement
    (1, 0, 0, 0): 1,  # forward
    (0, 1, 0, 0): 2,  # backward
    (0, 0, 1, 0): 3,  # right
    (0, 0, 0, 1): 4,  # left
    (1, 0, 1, 0): 5,  # forward + right
    (1, 0, 0, 1): 6,  # forward + left
    (0, 1, 1, 0): 7,  # backward + right
    (0, 1, 0, 1): 8,  # backward + left
}
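# The same 9-way mapping is reused for the rotation one-hots (yaw/pitch-style
# bits in place of forward/backward/right/left), and the two labels are later
# combined as trans * 9 + rot in pose_to_input, giving 81 possible action ids.
# The exact direction names for the rotation bits follow the upstream
# HY-WorldPlay convention and are not spelled out here.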
# --- Utility Functions adapted from generate.py ---
def one_hot_to_one_dimension(one_hot):
    """Collapse each 4-bit one-hot row into its single class id via `mapping`."""
    y = torch.tensor([mapping[tuple(row.tolist())] for row in one_hot])
    return y
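# Example: one_hot_to_one_dimension(torch.tensor([[1, 0, 0, 0], [0, 0, 0, 0]]))
# returns tensor([1, 0]) -- "forward" followed by "no movement".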
def pose_to_input(pose_json_path, latent_chunk_num, tps=False):
    """Convert a camera-path JSON into view matrices, intrinsics, and action labels.

    Adapted from the repo's generate.py to handle the same JSON structure.
    """
    from scipy.spatial.transform import Rotation as R

    with open(pose_json_path, 'r') as f:
        pose_json = json.load(f)
    pose_keys = list(pose_json.keys())
    intrinsic_list = []
    w2c_list = []
    # Sort keys numerically when possible so "10" does not come before "2";
    # fall back to lexicographic order for non-numeric keys (e.g. timestamps).
    pose_keys.sort(key=lambda k: (0, int(k)) if str(k).isdigit() else (1, k))
    # Do not read past the end if the JSON has fewer frames than requested.
    iterations = min(latent_chunk_num, len(pose_keys))
    for i in range(iterations):
        t_key = pose_keys[i]
        # "extrinsic" is a 4x4 camera-to-world matrix; invert it to world-to-camera.
        c2w = np.array(pose_json[t_key]["extrinsic"])
        w2c = np.linalg.inv(c2w)
        w2c_list.append(w2c)
        # Normalize the pixel-space intrinsics: divide the focal lengths by the
        # image width/height (taken as 2*cx and 2*cy) and move the principal
        # point to the image center in normalized coordinates.
        intrinsic = np.array(pose_json[t_key]["K"], dtype=np.float64)
        intrinsic[0, 0] /= intrinsic[0, 2] * 2
        intrinsic[1, 1] /= intrinsic[1, 2] * 2
        intrinsic[0, 2] = 0.5
        intrinsic[1, 2] = 0.5
        intrinsic_list.append(intrinsic)
    # Pad by repeating the last frame if the JSON has fewer frames than requested.
    if len(w2c_list) < latent_chunk_num:
        last_w2c = w2c_list[-1]
        last_intrinsic = intrinsic_list[-1]
        for _ in range(latent_chunk_num - len(w2c_list)):
            w2c_list.append(last_w2c)
            intrinsic_list.append(last_intrinsic)
    w2c_list = np.array(w2c_list)
    intrinsic_list = torch.tensor(np.array(intrinsic_list))
    # Express each pose relative to the previous camera: relative_c2w[i] = inv(c2w[i-1]) @ c2w[i].
    c2ws = np.linalg.inv(w2c_list)
    C_inv = np.linalg.inv(c2ws[:-1])
    relative_c2w = np.zeros_like(c2ws)
    relative_c2w[0, ...] = c2ws[0, ...]
    relative_c2w[1:, ...] = C_inv @ c2ws[1:, ...]
    trans_one_hot = np.zeros((relative_c2w.shape[0], 4), dtype=np.int32)
    rotate_one_hot = np.zeros((relative_c2w.shape[0], 4), dtype=np.int32)
    move_norm_valid = 0.0001  # minimum translation magnitude to count as movement
    for i in range(1, relative_c2w.shape[0]):
        # Angle of the translation direction against each camera axis.
        move_dirs = relative_c2w[i, :3, 3]
        move_norms = np.linalg.norm(move_dirs)
        if move_norms > move_norm_valid:
            move_norm_dirs = move_dirs / move_norms
            angles_rad = np.arccos(move_norm_dirs.clip(-1.0, 1.0))
            trans_angles_deg = angles_rad * (180.0 / np.pi)
        else:
            trans_angles_deg = np.zeros(3)
        # Relative rotation as Euler angles in degrees.
        R_rel = relative_c2w[i, :3, :3]
        r = R.from_matrix(R_rel)
        rot_angles_deg = r.as_euler('xyz', degrees=True)
        if move_norms > move_norm_valid:
            # When tps is set, skip translation labels if the motion also includes a noticeable turn.
            if (not tps) or (abs(rot_angles_deg[1]) < 5e-2 and abs(rot_angles_deg[0]) < 5e-2):
                if trans_angles_deg[2] < 60:
                    trans_one_hot[i, 0] = 1  # forward
                elif trans_angles_deg[2] > 120:
                    trans_one_hot[i, 1] = 1  # backward
                if trans_angles_deg[0] < 60:
                    trans_one_hot[i, 2] = 1  # right
                elif trans_angles_deg[0] > 120:
                    trans_one_hot[i, 3] = 1  # left
        if rot_angles_deg[1] > 5e-2:
            rotate_one_hot[i, 0] = 1
        elif rot_angles_deg[1] < -5e-2:
            rotate_one_hot[i, 1] = 1
        if rot_angles_deg[0] > 5e-2:
            rotate_one_hot[i, 2] = 1
        elif rot_angles_deg[0] < -5e-2:
            rotate_one_hot[i, 3] = 1
    trans_one_hot = torch.tensor(trans_one_hot)
    rotate_one_hot = torch.tensor(rotate_one_hot)
    trans_one_label = one_hot_to_one_dimension(trans_one_hot)
    rotate_one_label = one_hot_to_one_dimension(rotate_one_hot)
    # Combine translation and rotation into a single id in [0, 80].
    action_one_label = trans_one_label * 9 + rotate_one_label
    # intrinsic_list was converted to a tensor above, so it is returned directly.
    return torch.tensor(w2c_list), intrinsic_list, action_one_label
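# pose_to_input returns:
#   viewmats -- (latent_chunk_num, 4, 4) world-to-camera matrices
#   Ks       -- (latent_chunk_num, 3, 3) normalized intrinsics (cx = cy = 0.5)
#   action   -- (latent_chunk_num,) integer action labels in [0, 80]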
# --- Model Loading and Inference ---
MODEL_PATH = "tencent/HunyuanVideo-1.5"
ACTION_CKPT = "tencent/HY-WorldPlay"
# Global pipeline variable
pipe = None
def load_model():
    global pipe
    if pipe is None:
        print("Loading Model...")
        # Download the weights explicitly with snapshot_download: the custom
        # pipeline expects local paths rather than Hub repo ids.
        model_dir = snapshot_download(repo_id=MODEL_PATH, allow_patterns=["*.safetensors", "*.json", "*.txt"])
        # The HY-WorldPlay repo ships several checkpoints (e.g. "ar_model",
        # "ar_distilled_action_model", "bidirectional_model"); use the
        # bidirectional model here.
        action_dir = snapshot_download(repo_id=ACTION_CKPT)
        action_subpath = os.path.join(action_dir, "bidirectional_model")
        transformer_dtype = torch.bfloat16
        # Initialize the (single-GPU) parallel state once.
        if not torch.distributed.is_initialized():
            initialize_parallel_state(sp=1)
        pipe = HunyuanVideo_1_5_Pipeline.create_pipeline(
            pretrained_model_name_or_path=model_dir,
            transformer_version="480p_i2v",  # 480p image-to-video transformer
            enable_offloading=True,
            enable_group_offloading=True,
            create_sr_pipeline=True,  # enable super-resolution by default
            force_sparse_attn=False,
            transformer_dtype=transformer_dtype,
            action_ckpt=action_subpath,
        )
        print("Model Loaded Successfully!")
    return pipe
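# load_model caches the pipeline in the module-level `pipe`, so repeated Gradio
# requests reuse the already-loaded weights. The spaces.GPU decorator below
# requests a ZeroGPU device for up to 300 seconds per call.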
@spaces.GPU(duration=300)
def generate(prompt, image_input, pose_json, seed, num_inference_steps, video_length):
    # Gradio Number/Slider widgets may deliver floats; cast where ints are expected.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    video_length = int(video_length)
    pipeline = load_model()
    # Handle the camera-path JSON.
    pose_path = None
    if pose_json is not None:
        pose_path = pose_json.name
    else:
        # No path provided: fall back to a static identity camera. Repeating the
        # same pose yields action label 0 ("no movement") for every chunk.
        default_pose_content = {
            "0": {
                "extrinsic": [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
                "K": [[500.0, 0.0, 400.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]]
            }
        }
        for i in range(1, 16):
            default_pose_content[str(i)] = default_pose_content["0"]
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as tmp_json:
            json.dump(default_pose_content, tmp_json)
            pose_path = tmp_json.name
    # Prepare inputs
    latent_chunk_num = (video_length - 1) // 4 + 1
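    # Each latent frame covers 4 video frames (the first frame gets its own
    # latent), so video_length = 4n + 1 maps to n + 1 latent chunks, e.g. 65
    # frames -> 17 chunks; this matches the 16k + 1 lengths offered by the UI
    # slider. The 4x temporal factor is inferred from this formula rather than
    # from upstream documentation.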
    viewmats, Ks, action = pose_to_input(pose_path, latent_chunk_num)
    # Handle image input (image-to-video vs. text-to-video).
    extra_kwargs = {}
    if image_input is not None:
        extra_kwargs['reference_image'] = image_input
    # Run inference
    out = pipeline(
        enable_sr=True,
        prompt=prompt,
        aspect_ratio="16:9",
        num_inference_steps=num_inference_steps,
        sr_num_inference_steps=None,
        video_length=video_length,
        negative_prompt="",
        seed=seed,
        output_type="pt",
        prompt_rewrite=False,
        return_pre_sr_video=False,
        viewmats=viewmats.unsqueeze(0),
        Ks=Ks.unsqueeze(0),
        action=action.unsqueeze(0),
        few_step=False,
        chunk_latent_frames=16,
        model_type="bi",
        user_height=480,
        user_width=832,
        **extra_kwargs
    )
    # Save the video to disk.
    output_path = "output.mp4"
    import imageio
    import einops

    def save_vid(video, path):
        # Accept either (b, c, f, h, w) or (c, f, h, w) tensors in [0, 1].
        if video.ndim == 5:
            video = video[0]
        vid = (video * 255).clamp(0, 255).to(torch.uint8)
        vid = einops.rearrange(vid, 'c f h w -> f h w c')
        # imageio expects CPU numpy frames.
        imageio.mimwrite(path, vid.cpu().numpy(), fps=24)

    # Prefer the super-resolved output when the pipeline produced one.
    if getattr(out, 'sr_videos', None) is not None:
        save_vid(out.sr_videos, output_path)
    else:
        save_vid(out.videos, output_path)
    return output_path
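# Expected camera-path JSON layout (one entry per latent chunk; keys are frame
# indices or timestamps):
# {
#   "0": {"extrinsic": <4x4 camera-to-world matrix>,
#         "K": [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]},
#   "1": {...},
#   ...
# }
# pose_to_input inverts "extrinsic" to world-to-camera and rescales "K" by the
# image size, so both should be supplied in ordinary pixel/world units.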
# --- Gradio UI ---
default_pose_json_content = """
{
    "0": {
        "extrinsic": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
        "K": [[500, 0, 400], [0, 500, 240], [0, 0, 1]]
    }
}
"""  # Minimal single-frame example of the expected JSON; not currently wired into the UI.
with gr.Blocks() as app:
    gr.Markdown("# HY-WorldPlay Demo (built on HunyuanVideo-1.5)")
    gr.Markdown("Generate streaming videos with camera control using WorldPlay.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A cinematic shot of a forest.")
            image = gr.Image(label="Input Image", type="filepath")
            pose_file = gr.File(label="Camera Path JSON", file_types=[".json"])
            seed = gr.Number(label="Seed", value=123)
            steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=50, step=1)
            length = gr.Slider(label="Video Length (frames)", minimum=17, maximum=129, value=65, step=16)  # 65 = 16*4 + 1; allowed lengths follow 16k + 1
            submit = gr.Button("Generate")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
    submit.click(
        fn=generate,
        inputs=[prompt, image, pose_file, seed, steps, length],
        outputs=[output_video]
    )
if __name__ == "__main__":
    app.launch()