""" MoMask: Text-to-Motion Generation https://github.com/EricGuo5513/momask-codes Compact HuggingFace Space implementation using ONNX INT8 models. """ import os import sys import tempfile import numpy as np import torch import onnxruntime as ort import clip import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation, FFMpegWriter from pathlib import Path # ============ Quaternion Operations ============ def qinv(q): """Invert quaternion""" assert q.shape[-1] == 4 mask = torch.ones_like(q) mask[..., 1:] = -mask[..., 1:] return q * mask def qrot(q, v): """Rotate vector(s) v by quaternion(s) q""" assert q.shape[-1] == 4 assert v.shape[-1] == 3 assert q.shape[:-1] == v.shape[:-1] original_shape = list(v.shape) q = q.contiguous().view(-1, 4) v = v.contiguous().view(-1, 3) qvec = q[:, 1:] uv = torch.cross(qvec, v, dim=1) uuv = torch.cross(qvec, uv, dim=1) return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) # ============ Configuration ============ ONNX_DIR = Path(__file__).parent / "onnx_models" DEVICE = "cpu" JOINTS_NUM = 22 TIMESTEPS = 18 MASK_COND_SCALE = 4.0 # CFG scale for mask transformer RES_COND_SCALE = 5.0 # CFG scale for residual transformer TEMPERATURE = 1.0 TOPK_FILTER = 0.9 # Kinematic chain for visualization T2M_KINEMATIC_CHAIN = [ [0, 2, 5, 8, 11], # Right leg [0, 1, 4, 7, 10], # Left leg [0, 3, 6, 9, 12, 15], # Spine [9, 14, 17, 19, 21], # Right arm [9, 13, 16, 18, 20] # Left arm ] # ============ ONNX Sessions ============ sessions = {} def get_session(name): if name not in sessions: path = ONNX_DIR / f"{name}.onnx" if not path.exists(): raise FileNotFoundError(f"Model not found: {path}") sessions[name] = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"]) return sessions[name] # ============ Motion Recovery ============ def recover_root_rot_pos(data): """Recover root rotation and position from motion data""" rot_vel = data[..., 0] r_rot_ang = torch.zeros_like(rot_vel) r_rot_ang[..., 1:] = rot_vel[..., :-1] r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) r_rot_quat = torch.zeros(data.shape[:-1] + (4,)) r_rot_quat[..., 0] = torch.cos(r_rot_ang) r_rot_quat[..., 2] = torch.sin(r_rot_ang) r_pos = torch.zeros(data.shape[:-1] + (3,)) r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] r_pos = qrot(qinv(r_rot_quat), r_pos) r_pos = torch.cumsum(r_pos, dim=-2) r_pos[..., 1] = data[..., 3] return r_rot_quat, r_pos def recover_from_ric(data, joints_num=22): """Convert 263-dim motion representation to 3D joint positions""" r_rot_quat, r_pos = recover_root_rot_pos(data) positions = data[..., 4:(joints_num - 1) * 3 + 4] positions = positions.view(positions.shape[:-1] + (-1, 3)) positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) positions[..., 0] += r_pos[..., 0:1] positions[..., 2] += r_pos[..., 2:3] positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) return positions # ============ Foot Skating Fix ============ def fix_foot_skating(joints, vel_threshold=0.01, height_threshold=0.08): """ Simple foot skating fix by pinning feet during contact frames. 

# ============ Foot Skating Fix ============

def fix_foot_skating(joints, vel_threshold=0.01, height_threshold=0.08):
    """
    Simple foot skating fix by pinning feet during contact frames.

    Args:
        joints: numpy array (frames, 22, 3)
        vel_threshold: velocity threshold for contact detection
        height_threshold: height threshold for contact detection
    Returns:
        Fixed joints array
    """
    joints = joints.copy()
    n_frames = len(joints)

    # Foot indices: 7=LeftFoot, 8=RightFoot, 10=LeftToe, 11=RightToe
    foot_idx = {'left': [7, 10], 'right': [8, 11]}

    for side, indices in foot_idx.items():
        # Get foot positions (use toe for contact); copy so that later writes
        # to `joints` do not alias the positions used for the height adjustment
        toe_idx = indices[1]
        foot_pos = joints[:, toe_idx, :].copy()  # (frames, 3)

        # Compute velocity (frame-to-frame displacement)
        vel = np.zeros(n_frames)
        vel[1:] = np.linalg.norm(foot_pos[1:] - foot_pos[:-1], axis=1)

        # Detect contact: low velocity AND low height
        height = foot_pos[:, 1]  # Y is up
        floor_height = np.percentile(height, 10)  # Estimate floor
        contact = (vel < vel_threshold) & (height < floor_height + height_threshold)

        # Find contact regions and pin feet
        in_contact = False
        contact_start = 0
        contact_pos = None
        for f in range(n_frames):
            if contact[f] and not in_contact:
                # Start of contact
                in_contact = True
                contact_start = f
                contact_pos = foot_pos[f].copy()
                contact_pos[1] = floor_height  # Pin to floor
            elif contact[f] and in_contact:
                # Continue contact - pin both foot joints to contact position
                for idx in indices:
                    offset = joints[contact_start, idx] - joints[contact_start, toe_idx]
                    joints[f, idx] = contact_pos + offset
                    joints[f, idx, 1] = max(floor_height, joints[f, idx, 1] - (foot_pos[f, 1] - floor_height))
            elif not contact[f] and in_contact:
                # End of contact
                in_contact = False

    return joints


# ============ Visualization ============

def plot_3d_motion(save_path, joints, title, fps=20):
    """Create MP4 video of 3D skeleton motion (dark mode)"""
    fig = plt.figure(figsize=(8, 8), facecolor='#1a1a2e')
    ax = fig.add_subplot(111, projection="3d", facecolor='#1a1a2e')

    # Bright colors for dark background
    COLORS = ["#ff6b6b", "#4ecdc4", "#ffe66d", "#95e1d3", "#dda0dd"]

    def init():
        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-1.5, 1.5)
        ax.set_zlim(0, 2)
        ax.set_xlabel("X", color='white')
        ax.set_ylabel("Z", color='white')
        ax.set_zlabel("Y (up)", color='white')
        ax.set_title(title, color='white', fontsize=12, pad=10)
        # Dark mode styling
        ax.tick_params(colors='white')
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False
        ax.xaxis.pane.set_edgecolor('#333355')
        ax.yaxis.pane.set_edgecolor('#333355')
        ax.zaxis.pane.set_edgecolor('#333355')
        ax.grid(True, alpha=0.3, color='#555577')
        return []

    lines = []
    for i, chain in enumerate(T2M_KINEMATIC_CHAIN):
        line, = ax.plot([], [], [], color=COLORS[i], linewidth=2.5, marker="o", markersize=4)
        lines.append(line)

    def update(frame):
        data = joints[frame]
        for i, chain in enumerate(T2M_KINEMATIC_CHAIN):
            x = [data[j, 0] for j in chain]
            y = [data[j, 2] for j in chain]
            z = [data[j, 1] for j in chain]
            lines[i].set_data(x, y)
            lines[i].set_3d_properties(z)
        ax.view_init(elev=20, azim=45 + frame * 0.5)
        return lines

    ani = FuncAnimation(fig, update, frames=len(joints), init_func=init,
                        blit=False, interval=1000 // fps)
    writer = FFMpegWriter(fps=fps, bitrate=2000)
    ani.save(save_path, writer=writer)
    plt.close()


# ============ BVH Export with Simple IK ============

def _quat_mul(q1, q2):
    """Quaternion multiplication (w,x,y,z format)."""
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    return np.array([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2
    ])


def _quat_inv(q):
    """Quaternion inverse (conjugate; equals the inverse for unit quaternions)."""
    return np.array([q[0], -q[1], -q[2], -q[3]])


def _quat_from_two_vectors(v1, v2):
    """Quaternion rotating v1 to v2."""
    v1 = v1 / (np.linalg.norm(v1) + 1e-10)
    v2 = v2 / (np.linalg.norm(v2) + 1e-10)
    cross = np.cross(v1, v2)
    dot = np.dot(v1, v2)
    if dot < -0.999999:  # Opposite vectors: fall back to a 180-degree turn about x
        return np.array([0, 1, 0, 0])
    w = 1.0 + dot
    q = np.array([w, cross[0], cross[1], cross[2]])
    return q / (np.linalg.norm(q) + 1e-10)

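
# Quick self-check of the quaternion helpers above. Both functions below are
# added for illustration and are not used by the pipeline: rotating a vector v
# by a unit quaternion q is q * (0, v) * q^-1, keeping the vector part.
def _quat_rotate(q, v):
    qv = np.array([0.0, v[0], v[1], v[2]])
    return _quat_mul(_quat_mul(q, qv), _quat_inv(q))[1:]


def _quat_self_check():
    v1 = np.array([1.0, 0.0, 0.0])
    v2 = np.array([0.0, 0.0, 1.0])
    q = _quat_from_two_vectors(v1, v2)
    # q should map v1 exactly onto v2
    assert np.allclose(_quat_rotate(q, v1), v2, atol=1e-6)
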
v1 to v2.""" v1 = v1 / (np.linalg.norm(v1) + 1e-10) v2 = v2 / (np.linalg.norm(v2) + 1e-10) cross = np.cross(v1, v2) dot = np.dot(v1, v2) if dot < -0.999999: # Opposite vectors return np.array([0, 1, 0, 0]) w = 1.0 + dot q = np.array([w, cross[0], cross[1], cross[2]]) return q / (np.linalg.norm(q) + 1e-10) def _quat_to_euler_zyx(q): """Quaternion to Euler ZYX order (degrees) - standard BVH format.""" w, x, y, z = q # Z rotation (yaw) siny = 2 * (w * z + x * y) cosy = 1 - 2 * (y * y + z * z) rz = np.arctan2(siny, cosy) # Y rotation (pitch) sinp = 2 * (w * y - z * x) ry = np.arcsin(np.clip(sinp, -1, 1)) # X rotation (roll) sinr = 2 * (w * x + y * z) cosr = 1 - 2 * (x * x + y * y) rx = np.arctan2(sinr, cosr) return np.degrees([rz, ry, rx]) # ZYX order for BVH def joints_to_bvh(joints, output_path, fps=20): """Convert joint positions to BVH with proper local rotations.""" n_frames, n_joints, _ = joints.shape joint_names = ["Hips", "LeftUpLeg", "RightUpLeg", "Spine", "LeftLeg", "RightLeg", "Spine1", "LeftFoot", "RightFoot", "Spine2", "LeftToe", "RightToe", "Neck", "LeftShoulder", "RightShoulder", "Head", "LeftArm", "RightArm", "LeftForeArm", "RightForeArm", "LeftHand", "RightHand"] parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19] children = [[] for _ in range(n_joints)] for i, p in enumerate(parents): if p >= 0: children[p].append(i) # Rest pose offsets offsets = np.zeros((n_joints, 3)) for i in range(n_joints): if parents[i] >= 0: offsets[i] = joints[0, i] - joints[0, parents[i]] scale = 100.0 all_rotations = np.zeros((n_frames, n_joints, 3)) # Compute local rotations per frame using simple IK for frame in range(n_frames): global_quats = [np.array([1, 0, 0, 0])] * n_joints for j in range(n_joints): if not children[j]: continue child = children[j][0] # Rest pose direction (local) rest_dir = offsets[child] if np.linalg.norm(rest_dir) < 1e-6: continue # Current direction (global) curr_dir = joints[frame, child] - joints[frame, j] if np.linalg.norm(curr_dir) < 1e-6: continue # Global rotation to align rest to current global_rot = _quat_from_two_vectors(rest_dir, curr_dir) # Convert to local (relative to parent) if parents[j] >= 0: parent_inv = _quat_inv(global_quats[parents[j]]) local_rot = _quat_mul(parent_inv, global_rot) else: local_rot = global_rot global_quats[j] = global_rot all_rotations[frame, j] = _quat_to_euler_zyx(local_rot) # Write BVH (ZYX rotation order - standard for BVH) with open(output_path, "w") as f: f.write("HIERARCHY\n") def write_joint(idx, indent): name, off = joint_names[idx], offsets[idx] * scale pre = " " * indent f.write(f"{'ROOT' if idx==0 else pre+'JOINT'} {name}\n{pre}{{\n") f.write(f"{pre} OFFSET {off[0]:.6f} {off[1]:.6f} {off[2]:.6f}\n") f.write(f"{pre} CHANNELS {'6 Xposition Yposition Zposition ' if idx==0 else '3 '}Zrotation Yrotation Xrotation\n") if children[idx]: for c in children[idx]: write_joint(c, indent + 1) else: f.write(f"{pre} End Site\n{pre} {{\n{pre} OFFSET 0.0 0.0 0.0\n{pre} }}\n") f.write(f"{pre}}}\n") write_joint(0, 0) f.write(f"MOTION\nFrames: {n_frames}\nFrame Time: {1.0/fps:.6f}\n") for frame in range(n_frames): vals = list(joints[frame, 0] * scale) # Root position for j in range(n_joints): vals.extend(all_rotations[frame, j]) f.write(" ".join(f"{v:.6f}" for v in vals) + "\n") return output_path # ============ Sampling Utilities ============ def cosine_schedule(t): """Cosine noise schedule""" return torch.cos(t * np.pi * 0.5) def top_k_filter(logits, k=0.9): """Apply top-k filtering""" k = int((1 - k) * 

def top_k_filter(logits, k=0.9):
    """Apply top-k filtering: keep the top (1 - k) fraction of the vocabulary"""
    k = max(1, int((1 - k) * logits.shape[-1]))  # Guard against k == 0
    val, ind = torch.topk(logits, k, dim=-1)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(-1, ind, val)
    return probs


def gumbel_sample(logits, temperature=1.0):
    """Gumbel softmax sampling"""
    gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
    return ((logits / max(temperature, 1e-10)) + gumbels).argmax(dim=-1)


# ============ Main Generation Pipeline ============

def generate_motion(text, motion_length=0, seed=None):
    """Generate motion from text prompt with CFG"""
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    mean = np.load(ONNX_DIR / "mean.npy")
    std = np.load(ONNX_DIR / "std.npy")

    tokens = clip.tokenize([text], truncate=True)
    clip_sess = get_session("clip_text")
    text_emb = clip_sess.run(None, {"text_tokens": tokens.numpy()})[0]
    zero_emb = np.zeros_like(text_emb)  # For CFG unconditional path

    if motion_length <= 0:
        len_sess = get_session("length_estimator")
        len_logits = len_sess.run(None, {"text_embedding": text_emb})[0]
        probs = torch.softmax(torch.from_numpy(len_logits), dim=-1)
        token_len = torch.multinomial(probs, 1).item()
    else:
        token_len = int(motion_length * 20 / 4)
    token_len = max(2, min(token_len, 49))
    m_length = token_len * 4
    max_len = 49

    print(f"Generating motion: '{text}' ({m_length} frames, {m_length/20:.1f}s)")

    mask_id = 512
    pad_id = 513
    ids = torch.full((1, max_len), pad_id, dtype=torch.long)
    ids[:, :token_len] = mask_id
    scores = torch.zeros(1, max_len)
    scores[:, token_len:] = 1e5

    padding_mask = np.zeros((1, max_len), dtype=bool)
    padding_mask[:, token_len:] = True

    mask_sess = get_session("mask_transformer")
    for step in range(TIMESTEPS):
        t = step / TIMESTEPS
        rand_mask_prob = cosine_schedule(torch.tensor(t)).item()
        num_masked = max(1, int(rand_mask_prob * token_len))

        valid_scores = scores[:, :token_len].clone()
        _, sorted_idx = valid_scores.sort(dim=1)
        mask_pos = sorted_idx[:, :num_masked]
        is_mask = torch.zeros(1, token_len, dtype=torch.bool)
        is_mask.scatter_(1, mask_pos, True)
        ids[:, :token_len] = torch.where(is_mask, mask_id, ids[:, :token_len])

        # CFG: conditional and unconditional logits
        cond_logits = mask_sess.run(None, {
            "motion_ids": ids.numpy(),
            "cond_vector": text_emb,
            "padding_mask": padding_mask
        })[0]
        uncond_logits = mask_sess.run(None, {
            "motion_ids": ids.numpy(),
            "cond_vector": zero_emb,
            "padding_mask": padding_mask
        })[0]
        logits = uncond_logits + (cond_logits - uncond_logits) * MASK_COND_SCALE

        logits = torch.from_numpy(logits)
        logits = logits[:, :512, :token_len]
        logits = logits.permute(0, 2, 1)

        # Temperature is applied once here; gumbel_sample then draws at 1.0
        # (previously temperature was divided out twice, a no-op only because
        # TEMPERATURE == 1.0)
        filtered_logits = top_k_filter(logits / TEMPERATURE, TOPK_FILTER)
        new_ids = gumbel_sample(filtered_logits, 1.0)
        probs = torch.softmax(filtered_logits, dim=-1)
        new_scores = probs.gather(-1, new_ids.unsqueeze(-1)).squeeze(-1)

        ids[:, :token_len] = torch.where(is_mask, new_ids, ids[:, :token_len])
        scores[:, :token_len] = torch.where(is_mask, new_scores, scores[:, :token_len])
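    # Worked example of the guidance arithmetic used in both stages
    # (illustrative numbers, not computed by the pipeline): with an
    # unconditional logit of 1.0, a conditional logit of 2.0 and
    # MASK_COND_SCALE = 4.0, the guided logit is 1.0 + 4.0 * (2.0 - 1.0) = 5.0.
    # A scale of 1.0 recovers the conditional logits and 0.0 the unconditional
    # ones; larger scales trade diversity for text adherence.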
    res_sess = get_session("residual_transformer")
    num_quantizers = 6
    res_token_embed = np.load(ONNX_DIR / "res_token_embed.npy")

    all_codes = torch.zeros(1, max_len, num_quantizers, dtype=torch.long)
    all_codes[:, :, 0] = ids

    history_sum = np.zeros((1, max_len, 512), dtype=np.float32)
    motion_ids = ids.clone()

    for q in range(1, num_quantizers):
        token_embed = res_token_embed[q - 1]
        clamped_ids = np.clip(motion_ids[0].numpy(), 0, 512)
        gathered = token_embed[clamped_ids]
        history_sum += gathered[np.newaxis, :, :]

        q_id = np.array([q], dtype=np.int64)

        # CFG for residual transformer
        cond_logits = res_sess.run(None, {
            "motion_codes": history_sum.astype(np.float32),
            "q_id": q_id,
            "cond_vector": text_emb,
            "padding_mask": padding_mask
        })[0]
        uncond_logits = res_sess.run(None, {
            "motion_codes": history_sum.astype(np.float32),
            "q_id": q_id,
            "cond_vector": zero_emb,
            "padding_mask": padding_mask
        })[0]
        logits = uncond_logits + (cond_logits - uncond_logits) * RES_COND_SCALE

        logits = torch.from_numpy(logits)[:, :512, :token_len].permute(0, 2, 1)
        new_ids_q = gumbel_sample(logits, 1.0)
        all_codes[:, :token_len, q] = new_ids_q
        motion_ids[:, :token_len] = new_ids_q

    decoder_sess = get_session("vqvae_decoder")
    valid_codes = all_codes[:, :token_len, :].numpy()
    motion = decoder_sess.run(None, {"code_indices": valid_codes})[0]

    # Decoder already upsamples 4x internally, just slice to exact length
    motion = motion[:, :m_length, :]
    motion = motion * std + mean

    motion_tensor = torch.from_numpy(motion).float()
    joints = recover_from_ric(motion_tensor, JOINTS_NUM)
    joints = joints.squeeze(0).numpy()

    # Apply foot skating fix
    joints = fix_foot_skating(joints)

    video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    plot_3d_motion(video_path, joints, text, fps=20)

    bvh_path = tempfile.NamedTemporaryFile(suffix=".bvh", delete=False).name
    joints_to_bvh(joints, bvh_path, fps=20)

    return joints, video_path, bvh_path


# ============ Gradio Interface ============

def create_demo():
    import gradio as gr

    def generate_fn(text, length, seed):
        if not text or text.strip() == "":
            return None, None
        seed = int(seed) if seed else None
        length = float(length) if length else 0
        joints, video_path, bvh_path = generate_motion(text, length, seed)
        return video_path, bvh_path

    with gr.Blocks(title="MoMask") as demo:
        gr.Markdown("## [MoMask](https://github.com/EricGuo5513/momask-codes) - Text to Motion")
        gr.Markdown("Generate 3D human skeleton animations from text descriptions. Download BVH for Blender!")
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="Motion Prompt", placeholder="A person walks forward",
                                  value="A person walks forward", lines=2)
                with gr.Row():
                    length = gr.Number(label="Duration (sec)", value=0, info="0 = auto-estimate")
                    seed = gr.Number(label="Seed", value=42, info="For reproducibility")
                btn = gr.Button("Generate", variant="primary")
            with gr.Column():
                video = gr.Video(label="Generated Motion")
                bvh_file = gr.File(label="BVH Download (for Blender)")

        gr.Examples(
            examples=[
                ["A person walks forward", 0, 42],
                ["A person is running on a treadmill", 0, 123],
                ["A person jumps up and then lands", 0, 456],
                ["A person does a salsa dance", 0, 789],
                ["A person kicks with their right leg", 0, 101],
            ],
            inputs=[text, length, seed],
            outputs=[video, bvh_file],
            fn=generate_fn,
            cache_examples=False,
        )

        btn.click(fn=generate_fn, inputs=[text, length, seed], outputs=[video, bvh_file])

    return demo


# ============ CLI ============
if __name__ == "__main__":
    if len(sys.argv) > 1:
        text = sys.argv[1]
        length = float(sys.argv[2]) if len(sys.argv) > 2 else 0
        seed = int(sys.argv[3]) if len(sys.argv) > 3 else 42
        joints, video_path, bvh_path = generate_motion(text, length, seed)
        print(f"Video: {video_path}")
        print(f"BVH: {bvh_path}")
        print(f"Joints shape: {joints.shape}")
    else:
        demo = create_demo()
        demo.launch()
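
# Example CLI session (a sketch: the script name, temp paths, and printed
# output below are illustrative, derived from the code above rather than a
# captured run). A 4-second request maps to int(4.0 * 20 / 4) = 20 tokens,
# i.e. 80 frames at 20 fps:
#   $ python app.py "A person walks forward" 4.0 42
#   Generating motion: 'A person walks forward' (80 frames, 4.0s)
#   Video: /tmp/tmpXXXXXXXX.mp4
#   BVH: /tmp/tmpXXXXXXXX.bvh
#   Joints shape: (80, 22, 3)
# The BVH file can be imported into Blender via File > Import > Motion Capture (.bvh).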