# vitra/datasets/human_dataset.py
import bisect
import copy
import json
import math
import os
import random
import time
from functools import lru_cache
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from scipy.spatial.transform import Rotation as R
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from vitra.datasets.augment_utils import (
augmentation_func,
center_crop_short_side,
project_to_image_space,
)
from vitra.datasets.interp_utils import interp_mano_state
from vitra.datasets.video_utils import load_video_decord
from vitra.datasets.dataset_utils import (
compute_new_intrinsics_crop,
compute_new_intrinsics_resize,
calculate_fov,
ActionFeature,
StateFeature,
)
from vitra.utils.data_utils import (
read_dataset_statistics,
GaussianNormalizer,
)
class EpisodicDatasetCore(object):
"""Core dataset class for episodic hand manipulation data.
Handles loading and processing of video frames, MANO hand parameters,
and action sequences for hand-centric manipulation tasks.
"""
def __init__(
self,
video_root,
annotation_file,
label_folder,
training_path=None,
statistics_path=None,
augmentation=True,
flip_augmentation=True,
set_none_ratio=0.0,
action_type="angle",
use_rel=False,
upsample_factor=1.0,
target_image_width=224,
clip_len=2000,
state_mask_prob=0.1,
action_past_window_size=0,
action_future_window_size=15,
image_past_window_size=0,
image_future_window_size=0,
rel_mode="step",
load_images=True,
):
self.video_root = video_root
annotation_dict = np.load(annotation_file, allow_pickle=True)
self.label_folder = label_folder
self.index_frame_pair = annotation_dict['index_frame_pair'].copy()
self.index_to_episode_id = annotation_dict['index_to_episode_id'].copy()
if training_path is not None:
self.training_idx = np.load(training_path, allow_pickle=True)
self.num_valid_frames = len(self.training_idx)
else:
self.training_idx = None
self.num_valid_frames = len(self.index_frame_pair)
if statistics_path is not None:
self.data_statistics = read_dataset_statistics(statistics_path)
self.global_data_statistics = None
self.clip_len = clip_len # Video clip length in frames
self.augmentation = augmentation
self.target_image_width = target_image_width
self.flip_augmentation = flip_augmentation
self.set_none_ratio = set_none_ratio
self.action_type = action_type # "angle" (Euler angles) or "keypoints" (3D joint positions)
self.use_rel = use_rel # Whether to use relative delta as actions for hand poses (MANO poses)
        assert upsample_factor >= 1.0, "only upsample_factor >= 1.0 is supported"
self.upsample_factor = upsample_factor
self.state_mask_prob = state_mask_prob # Probability of masking state input
self.action_past_window_size=action_past_window_size
self.action_future_window_size=action_future_window_size
self.image_past_window_size=image_past_window_size
self.image_future_window_size=image_future_window_size
self.rel_mode=rel_mode
self.load_images=load_images
def __len__(self):
return self.num_valid_frames
@staticmethod
@lru_cache(maxsize=256) # ~256 MB worst case if each npy ≈1 MB
def _load_episode_npy(episode_path: str):
"""Load episode data from .npy file with caching.
Uses LRU cache to keep up to 256 episodes in memory (~256 MB worst case).
The cache automatically purges old entries when full.
Args:
episode_path: Path to the .npy file containing episode data
Returns:
Dictionary containing episode information
"""
return np.load(episode_path, allow_pickle=True).item()
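    # The decorated staticmethod exposes the usual functools cache helpers, so
    # long-running jobs can inspect or drop cached episodes explicitly. A
    # minimal sketch:
    #
    #   EpisodicDatasetCore._load_episode_npy.cache_info()   # hits/misses/currsize
    #   EpisodicDatasetCore._load_episode_npy.cache_clear()  # free all cached episodes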
def _load_or_cache_episode(self, episode_id):
"""
Returns episode_result (raw dict) and the pre-extracted camera
extrinsics (R_w2c, t_w2c). No camera-space MANO tensors are cached.
"""
root = self.label_folder
epi_path = os.path.join(root, episode_id + '.npy')
epi = self._load_episode_npy(epi_path)
extr = epi['extrinsics'] # world to cam, (T,4,4)
R_w2c, t_w2c = extr[:, :3, :3], extr[:, :3, 3]
return epi, R_w2c, t_w2c
def _mat2euler(self, R_batch: np.ndarray) -> np.ndarray:
"""Batched XYZ-Euler conversion using SciPy."""
flat = R_batch.reshape(-1, 3, 3)
eul = R.from_matrix(flat).as_euler('xyz', degrees=False)
return eul
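    # Quick sanity check for _mat2euler (illustrative SciPy round-trip):
    #
    #   ang = np.array([0.1, -0.2, 0.3])
    #   M = R.from_euler('xyz', ang).as_matrix()[None]   # (1, 3, 3)
    #   assert np.allclose(self._mat2euler(M)[0], ang)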
def _prepare_side_window(self, side_dict,
R_w2c, t_w2c,
idx_window, idx_anchor,
*, anchor_frame=True, oob=None, start=None, end=None, upsample_factor=1.0):
        W = len(idx_window)
idx_window_extend = np.append(idx_window, np.clip(idx_window[-1] + 1, start, end))
kept_extend = side_dict['kept_frames'][idx_window_extend].astype(bool)
R_mano_extend = side_dict['global_orient_worldspace'][idx_window_extend]
t_mano_extend = side_dict['transl_worldspace'][idx_window_extend]
hand_P_extend = side_dict['hand_pose'][idx_window_extend]
joints_worldspace_extend = side_dict['joints_worldspace'][idx_window_extend]
        oob_indices = np.where(oob)[0]
        if len(oob_indices) > 0:
            # Mark out-of-bounds window slots as not kept
            kept_extend[oob_indices] = False
        # The extra look-ahead frame may itself be out of bounds
        if idx_window[-1] + 1 > end:
            kept_extend[-1] = False
if not np.all(kept_extend):
identity = np.eye(3, dtype=hand_P_extend.dtype)
identity_block = np.broadcast_to(identity, (hand_P_extend.shape[1], 3, 3))
hand_P_extend[~kept_extend] = identity_block
R_mano_extend[~kept_extend] = identity
# -------- camera-space (anchor camera) ----------------------------
R_cam_extend = R_w2c[idx_anchor] @ R_mano_extend
t_cam_extend = (R_w2c[idx_anchor] @ t_mano_extend[..., None])[..., 0] + t_w2c[idx_anchor]
# -------- finger Euler (batched) ----------------------------------
pose_euler_extend = R.from_matrix(hand_P_extend.reshape(-1,3,3)) \
.as_euler('xyz', degrees=False).reshape(-1,45)
# -------- keypoints in mano space (batched) ----------------------------------
joints_manospace_extend = (R_mano_extend.transpose(0, 2, 1) @ (joints_worldspace_extend.transpose(0, 2, 1) - t_mano_extend[..., None])).transpose(0,2,1) # (W+1,21,3)
if upsample_factor > 1:
R_cam_extend, t_cam_extend, hand_P_extend, joints_manospace_extend, kept_extend = \
interp_mano_state(R_cam_extend, t_cam_extend, hand_P_extend, joints_manospace_extend, kept_extend, upsample_factor, method="pchip")
pose_euler_extend = R.from_matrix(hand_P_extend.reshape(-1,3,3)) \
.as_euler('xyz', degrees=False).reshape(-1,45)
# set length back to W+1
R_cam_extend = R_cam_extend[:W+1]
t_cam_extend = t_cam_extend[:W+1]
hand_P_extend = hand_P_extend[:W+1]
pose_euler_extend = pose_euler_extend[:W+1]
joints_manospace_extend = joints_manospace_extend[:W+1]
kept_extend = kept_extend[:W+1]
R_cam = R_cam_extend[:-1]
t_cam = t_cam_extend[:-1]
pose_euler = pose_euler_extend[:-1]
hand_P = hand_P_extend[:-1]
joints_manospace = joints_manospace_extend[:-1]
kept = kept_extend[:-1]
R_cam_next = R_cam_extend[1:]
t_cam_next = t_cam_extend[1:]
pose_euler_next = pose_euler_extend[1:]
hand_P_next = hand_P_extend[1:]
joints_manospace_next = joints_manospace_extend[1:]
kept_next = kept_extend[1:]
return dict(
# current frame tensors
R_cam=R_cam.astype(np.float32),
t_cam=t_cam.astype(np.float32),
pose_euler=pose_euler.astype(np.float32),
hand_P=hand_P.astype(np.float32),
joints_manospace = joints_manospace.astype(np.float32),
kept=kept,
# next-frame tensors (same length W)
R_cam_next = R_cam_next.astype(np.float32),
t_cam_next = t_cam_next.astype(np.float32),
pose_euler_next = pose_euler_next.astype(np.float32),
hand_P_next = hand_P_next.astype(np.float32),
joints_manospace_next = joints_manospace_next.astype(np.float32),
kept_next = kept_next,
)
# ============================================================
# 4. Vectorised action-window constructor (ONE hand)
# ============================================================
def _make_action_window_vec(self, win, anchor_idx: int, *, rel_mode="step", action_type="angle"):
"""
anchor_idx : the position of t0 inside the window (usually = past)
rel_mode : "step" → Δ(t→t+1)
"anchor" → Δ(t→t0)
action_type: "angle" → Euler angles (xyz)
"keypoints " → root keypoints (21x3=63)
"""
R_cur, t_cur = win['R_cam'], win['t_cam']
R_nxt, t_nxt = win['R_cam_next'], win['t_cam_next']
P_cur, P_nxt = win['hand_P'], win['hand_P_next']
pose_next = win['pose_euler_next']
kpoints_root_next = win['joints_manospace_next']
kept, kept_n = win['kept'], win['kept_next']
W = len(t_cur)
# absolute pose of t+1
        if action_type == "keypoints":
            abs_next = kpoints_root_next.reshape(W, -1)
        elif action_type == "angle":
            abs_next = pose_next
        else:
            raise ValueError('action_type must be "angle" or "keypoints"')
action_abs = np.concatenate(
[t_nxt,
self._mat2euler(R_nxt),
abs_next],
axis=-1).astype(np.float32)
action_abs = action_abs.reshape(W, -1)
# choose relative formulation
if rel_mode == "step":
t_rel = t_nxt - t_cur
R_rel = R_nxt @ R_cur.transpose(0,2,1)
P_rel = np.matmul(P_nxt, P_cur.transpose(0,1,3,2))
valid = kept & kept_n
elif rel_mode == "anchor":
t_anchor = t_cur[anchor_idx]
R_anchor = R_cur[anchor_idx]
P_anchor = P_cur[anchor_idx]
# broadcast to all W rows
t_rel = t_nxt - t_anchor
R_rel = R_nxt @ R_anchor.T
P_rel = np.matmul(P_nxt, P_anchor.transpose(0,2,1))
valid = kept_n & kept[anchor_idx]
else:
raise ValueError('rel_mode must be "step" or "anchor"')
        pose_rel = R.from_matrix(P_rel.reshape(-1, 3, 3)) \
                    .as_euler('xyz', degrees=False).reshape(W, 45)
action_rel = np.concatenate(
[t_rel,
self._mat2euler(R_rel),
pose_rel],
axis=-1).astype(np.float32)
action_abs[~valid] = 0.0
action_rel[~valid] = 0.0
return action_abs, action_rel, valid
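    # Toy illustration of the two relative modes on rotations (sketch only):
    # with R_cur = Rz(a) and R_nxt = Rz(b),
    #   "step"   gives R_rel = R_nxt @ R_cur.T    = Rz(b - a)   (frame-to-frame delta)
    #   "anchor" gives R_rel = R_nxt @ R_anchor.T = Rz(b - a0)  (delta w.r.t. t0)
    # Translations behave the same way: t_rel = t_nxt - t_cur or t_nxt - t_anchor.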
def _window_indices(self, frame_id, past, future, start, end):
"""
Returns:
idx_clip : (W,) indices clipped to [0, T-1]
oob : (W,) bool — slots that were originally OOB
"""
win = np.arange(-past, future + 1) + frame_id # (W,)
oob = (win < start) | (win > end)
return win.clip(start, end), oob
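    # Example: a window of 3 past and 2 future frames around frame_id=1 in a
    # 10-frame episode (start=0, end=9):
    #
    #   idx, oob = self._window_indices(1, past=3, future=2, start=0, end=9)
    #   # idx -> [0, 0, 0, 1, 2, 3]   (clipped to the valid range)
    #   # oob -> [True, True, False, False, False, False]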
    def _resolve_video_path(self, dataset_name: str = None, video_name: str = None, part_index: int = None) -> str:
        if dataset_name in ('Ego4D', 'EgoExo4D'):
            if self.clip_len is not None:
                return os.path.join(self.video_root, video_name + '_part' + str(part_index + 1) + '.mp4')
            return os.path.join(self.video_root, video_name + '.mp4')
        elif dataset_name == 'epic':
            if self.clip_len is not None:
                return os.path.join(self.video_root, video_name + '_part' + str(part_index + 1) + '.mp4')
            return os.path.join(self.video_root, video_name + '.MP4')
        elif dataset_name == 'somethingsomethingv2':
            if self.clip_len is not None:
                return os.path.join(self.video_root, video_name + '_part' + str(part_index + 1) + '.mp4')
            return os.path.join(self.video_root, video_name + '.webm')
        else:
            raise ValueError(f'Unknown dataset prefix {dataset_name}')
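    # Example resolved paths: with clip_len=2000 and part_index=0 every dataset
    # maps to <video_root>/<video_name>_part1.mp4; with clip_len=None the
    # extension depends on the source (Ego4D/EgoExo4D -> .mp4, epic -> .MP4,
    # somethingsomethingv2 -> .webm).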
    def _mano_forward(self, betas, pose_m):
        """Runs MANO once and returns (vertices, joints) on CPU NumPy.

        Note: relies on a module-level `mano` layer (e.g. a MANO model from
        the smplx package) being set up elsewhere; it is not defined in this
        file.
        """
        beta_t = torch.tensor(betas).unsqueeze(0).float().cuda()
        pose_t = torch.tensor(pose_m).unsqueeze(0).float().cuda()
        out = mano(betas=beta_t, hand_pose=pose_t)  # no global_orient
        return out.vertices[0].cpu().numpy(), out.joints[0].cpu().numpy()
# ------------------------------------------------------------
# Grab the (past + future + 1) frame window
# ------------------------------------------------------------
def _pack_state(self, R_cam, t_cam, pose_euler, idx):
return np.concatenate([t_cam[idx],
self._mat2euler(R_cam[idx][None,...])[0],
pose_euler[idx]])
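    # _pack_state concatenates [transl (3), global-orient euler (3), hand pose
    # (45 in "angle" mode or 63 in "keypoints" mode)], i.e. 51 or 69 dims per
    # hand; with the 10 MANO betas appended later this yields the 61-dim
    # (angle) per-hand state used in get_item_frame.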
def _grab_window_images(self,
episode_id: str,
epi: dict,
frame_id: int,
past: int,
future: int
):
"""
Returns
-------
images : (L, H, W, 3) uint8 – raw RGB frames
mask : (L,) bool True where real frame, False where pad
"""
dataset_name = episode_id.split('_')[0]
# video_path = self._resolve_video_path(dataset_name, epi['video_name'])
decode_table = epi['video_decode_frame'] # (T,)
T = len(decode_table)
frame_in_video = epi['video_decode_frame'][frame_id]
# ---------- build padded window indices ---------------------------
        # Multi-frame image windows are not supported yet
idx_win, oob = self._window_indices(frame_id, past, future, 0, T-1) # (W,)
if self.clip_len is not None:
part_idx = frame_in_video // self.clip_len # clip_len = 2000
frame_in_part = frame_in_video % self.clip_len
video_path = self._resolve_video_path(dataset_name, epi['video_name'], part_idx)
decode_ids = [frame_in_part]
else:
video_path = self._resolve_video_path(dataset_name, epi['video_name'])
decode_ids = decode_table[idx_win]
# ---------- read images --------------------
# Retry mechanism: try up to 3 times to load video frames
        for attempt in range(3):
            try:
                imgs, _ = load_video_decord(video_path, frame_index=decode_ids, rotation=False)
                break  # Success, exit the retry loop
            except Exception as e:
                print(f"Warning: failed to load video frames from {video_path} (attempt {attempt+1}/3): {e}")
                if attempt == 2:
                    raise  # Re-raise after the final attempt; otherwise `imgs` below would be undefined
                time.sleep(0.1)
images = np.stack(imgs, axis=0) # (L,H,W,3) uint8
mask = ~oob # (L,) bool
return images, mask
def _find_matching_texts(self, text_list, frame_id):
"""Find text annotations that overlap with the given frame.
Args:
text_list: List of tuples (text, (start_frame, end_frame))
frame_id: Current frame ID to check
Returns:
matching_texts: List of matching text annotations
matching_ranges: List of corresponding time ranges (start_frame, end_frame)
Note:
Uses half-open interval [start_frame, end_frame)
"""
matching_texts = []
matching_ranges = []
for text, (start_frame, end_frame) in text_list:
# Check if frame_id is in the half-open interval [start_frame, end_frame)
if start_frame <= frame_id < end_frame:
matching_texts.append(text)
matching_ranges.append((start_frame, end_frame))
return matching_texts, matching_ranges
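    # Example: with text_list = [('Pick up the cup.', (0, 26))], frame_id=10
    # matches (10 is in [0, 26)) while frame_id=26 does not, since the
    # interval is half-open.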
def _random_select_text(
self,
text,
text_rephrase,
hand_type,
clip_idx,
):
text_list = [text]
if text_rephrase and isinstance(text_rephrase[hand_type][clip_idx][0], list):
text_list.extend(text_rephrase[hand_type][clip_idx][0])
text_selected = random.choice(text_list).strip()
if not text_selected.endswith('.'):
text_selected += '.'
return text_selected
def _build_instruction(
self,
main_type,
text_clip,
text_rephrase,
idx_win,
oob,
epi_len, # T
frame_id,
action_past_window_size,
action_future_window_size,
):
sub_type = 'right' if main_type == 'left' else 'left'
# Build main text
# text_clip[main_type][0]: ('Place the pink cup on the table.', (0, 26))
main_text_selected = self._random_select_text(
text_clip[main_type][0][0],
text_rephrase,
main_type,
clip_idx=0,
)
# Build sub text
sub_text_list = text_clip[sub_type]
has_sub_text = len(sub_text_list) > 0
sub_oob, sub_idx_win = oob, idx_win
sub_text_selected = "None."
sub_win = (0, epi_len) # Default to the full range if no text available
if has_sub_text:
sub_matching_texts, sub_matching_ranges = self._find_matching_texts(sub_text_list, frame_id)
if len(sub_matching_texts) > 0:
selected_idx = random.randrange(len(sub_matching_texts))
sub_win = sub_matching_ranges[selected_idx]
sub_idx_win, sub_oob = self._window_indices(
frame_id,
action_past_window_size,
action_future_window_size, sub_win[0], sub_win[1]-1
) # (W,)
sub_text_selected = self._random_select_text(
sub_matching_texts[selected_idx].strip(),
text_rephrase,
sub_type,
clip_idx=selected_idx,
)
# Assign left/right based on main_type
is_main_left = (main_type == 'left')
idx_win_left = idx_win if is_main_left else sub_idx_win
idx_win_right = sub_idx_win if is_main_left else idx_win
oob_left = oob if is_main_left else sub_oob
oob_right = sub_oob if is_main_left else oob
text_left = main_text_selected if is_main_left else sub_text_selected
text_right = sub_text_selected if is_main_left else main_text_selected
start_left = 0 if is_main_left else (sub_win[0] if has_sub_text else 0)
start_right = (sub_win[0] if has_sub_text else 0) if is_main_left else 0
end_left = epi_len - 1 if is_main_left else (sub_win[1] - 1 if has_sub_text else epi_len - 1)
end_right = (sub_win[1] - 1 if has_sub_text else epi_len - 1) if is_main_left else epi_len - 1
instruction = f"Left hand: {text_left} Right hand: {text_right}"
return instruction, idx_win_left, oob_left, idx_win_right, oob_right, start_left, end_left, start_right, end_right
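    # A typical instruction produced above looks like
    #   "Left hand: Place the pink cup on the table. Right hand: None."
    # where the sub hand falls back to "None." when no overlapping text clip
    # exists at frame_id.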
def _get_2d_traj_cur_to_end(self, idx_frame, epi, intrinsics, hand_type, image_size):
"""Get the 2D trajectory of the hand palm from current frame to episode end.
Args:
idx_frame: Current frame index
epi: Episode data dictionary
intrinsics: Camera intrinsic matrix
hand_type: 'left' or 'right' hand
image_size: (H, W) tuple of image dimensions
Returns:
Normalized 2D palm trajectory in image space [0, 1]
"""
H, W = image_size
# intrinsics = epi['intrinsics'].copy()
intrinsics = intrinsics.copy()
        # Normalize intrinsics to unit image coordinates, assuming the
        # principal point (cx, cy) lies at the image center
        intrinsics[0] /= intrinsics[0, 2] * 2
        intrinsics[1] /= intrinsics[1, 2] * 2
hand_joints_cur_to_end = epi[hand_type]['joints_worldspace'][idx_frame:] # (N, 21, 3)
hand_palm_cur_to_end = np.mean(hand_joints_cur_to_end[:, [0,2,5,9,13,17], :], axis=1, keepdims=True) # (N, 1, 3)
extrinsics = epi['extrinsics'].copy()
extrinsics_cur = extrinsics[idx_frame] # world to cam
R_world_to_cam = extrinsics_cur[None, :3, :3].repeat(len(hand_palm_cur_to_end), axis=0)
t_world_to_cam = extrinsics_cur[None, :3, 3:].repeat(len(hand_palm_cur_to_end), axis=0)
hand_palm_cur_to_end_cam = (R_world_to_cam @ hand_palm_cur_to_end.transpose(0, 2, 1) + t_world_to_cam).transpose(0, 2, 1)
uv_palm_cur_to_end = project_to_image_space(hand_palm_cur_to_end_cam, intrinsics, (H, W)) # (N, M, 2)
uv_palm_cur_to_end[..., 0] = np.clip(uv_palm_cur_to_end[..., 0], 0, W)
uv_palm_cur_to_end[..., 1] = np.clip(uv_palm_cur_to_end[..., 1], 0, H)
uv_palm_cur_to_end = uv_palm_cur_to_end.reshape(-1, 2)
uv_palm_cur_to_end = uv_palm_cur_to_end.astype(np.float32)
uv_palm_cur_to_end[:,0] /= W
uv_palm_cur_to_end[:,1] /= H
return uv_palm_cur_to_end
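    # Shape walk-through (sketch): with N frames left in the episode, joints
    # (N, 21, 3) -> palm mean over joints [0, 2, 5, 9, 13, 17] gives
    # (N, 1, 3); projecting with the anchor-frame extrinsics and normalizing
    # by (W, H) yields an (N, 2) float32 trajectory in [0, 1].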
def get_item_frame(
self, episode_id, frame_id,
action_past_window_size=0,
action_future_window_size=0,
image_past_window_size=0,
image_future_window_size=0,
rel_mode: str = "step",
load_images: bool = True,
):
"""
Vectorised dataloader.
"""
# ------------------------------------------------------------------
# 1. Load episode dict + extrinsics
# ------------------------------------------------------------------
epi, R_w2c, t_w2c = self._load_or_cache_episode(episode_id)
        T = len(epi['extrinsics'])  # number of frames in the episode
# ------------------------------------------------------------------
# 2. Build frame-window indices
# ------------------------------------------------------------------
idx_win, oob = self._window_indices(frame_id,
action_past_window_size,
action_future_window_size, 0, T-1) # (W,)
W = len(idx_win)
main_type = epi['anno_type']
sub_type = 'right' if main_type == 'left' else 'left'
# ------------------------------------------------------------------
# 3. Build instruction text
# ------------------------------------------------------------------
instruction, idx_win_left, oob_left, idx_win_right, oob_right, \
start_left, end_left, start_right, end_right = self._build_instruction(
main_type = main_type,
text_clip = epi['text'],
text_rephrase = epi.get('text_rephrase'),
idx_win = idx_win,
oob = oob,
epi_len = T,
frame_id = frame_id,
action_past_window_size = action_past_window_size,
action_future_window_size = action_future_window_size,
)
# ------------------------------------------------------------------
# 4. Vectorised actions (left + right)
# ------------------------------------------------------------------
win_left = self._prepare_side_window(
epi['left'], R_w2c, t_w2c, idx_win_left, frame_id, anchor_frame=True,
oob=oob_left, start=start_left, end=end_left, upsample_factor=self.upsample_factor
)
win_right = self._prepare_side_window(
epi['right'], R_w2c, t_w2c, idx_win_right, frame_id, anchor_frame=True,
oob=oob_right, start=start_right, end=end_right, upsample_factor=self.upsample_factor
)
idx_center = action_past_window_size # local index of t0 in window
# rel_mode: "step" or "anchor" / action_type: "angle" or "keypoints"
# step: relative to previous frame, anchor: relative to t0
abs_L, rel_L, msk_L = self._make_action_window_vec(
win_left, anchor_idx=idx_center, rel_mode=rel_mode, action_type=self.action_type
)
abs_R, rel_R, msk_R = self._make_action_window_vec(
win_right, anchor_idx=idx_center, rel_mode=rel_mode, action_type=self.action_type
)
action_abs = np.concatenate([abs_L, abs_R], axis=1) # (W,action_dim)
action_rel = np.concatenate([rel_L, rel_R], axis=1) # (W,102)
action_mask = np.stack([msk_L, msk_R], axis=1) # (W,2)
cur_L = self._pack_state(win_left['R_cam'],
win_left['t_cam'],
win_left['pose_euler'] if self.action_type=='angle' else win_left['joints_manospace'].reshape(W, -1),
idx_center)
cur_R = self._pack_state(win_right['R_cam'],
win_right['t_cam'],
win_right['pose_euler'] if self.action_type=='angle' else win_right['joints_manospace'].reshape(W, -1),
idx_center)
betas_L = epi['left']['beta']
betas_R = epi['right']['beta']
        current_state = np.concatenate([cur_L, betas_L, cur_R, betas_R])  # per hand: 6 (transl + rot) + 45/63 (pose) + 10 (betas)
# current_state_mask = np.array([msk_L[idx_center],
# msk_R[idx_center]])
current_state_mask = np.array([win_left['kept'][idx_center], win_right['kept'][idx_center]])
# ------------------------------------------------------------------
# 5. RGB window
# ------------------------------------------------------------------
if load_images:
image_list, image_mask = self._grab_window_images(
episode_id, epi,
frame_id,
image_past_window_size,
image_future_window_size
)
            H = image_list[0].shape[0]
            W = image_list[0].shape[1]  # note: shadows the window length W above
else:
image_list = None
image_mask = None
H, W = epi['intrinsics'][1,2]*2, epi['intrinsics'][0,2]*2
# ------------------------------------------------------------------
# 6. Calculate New_intrinsics
# ------------------------------------------------------------------
dataset_name = episode_id.split('_')[0]
intrinsics = epi['intrinsics']
if dataset_name == 'EgoExo4D':
# For EgoExo4D, the fisheye camera images contain black borders after undistortion.
# We remove these borders using a center crop. Specifically, the video frames are
# first resized from 1408 to 448, and then center-cropped to 256.
new_intrinsics = compute_new_intrinsics_crop(intrinsics, 1408, 256/448*1408, H)
else:
new_intrinsics = compute_new_intrinsics_resize(intrinsics, (H, W))
# ------------------------------------------------------------------
# 7. Do augmentation
# ------------------------------------------------------------------
if self.augmentation:
try:
# randomly sample aspect ratio for augmentation
aspect_ratio = np.exp(random.uniform(np.log(1.0), np.log(2.0)))
target_size = (int(self.target_image_width * aspect_ratio), self.target_image_width) # (W, H)
augment_params = {
'tgt_aspect': aspect_ratio,
'flip_augmentation': self.flip_augmentation,
'set_none_ratio': self.set_none_ratio,
}
uv_traj = self._get_2d_traj_cur_to_end(frame_id, epi, new_intrinsics, main_type, (H, W))
image_list, new_intrinsics, (action_abs, action_rel, action_mask), \
(current_state, current_state_mask), instruction = \
augmentation_func(
image = image_list,
intrinsics = new_intrinsics,
actions = (action_abs, action_rel, action_mask),
states = (current_state, current_state_mask),
captions = instruction,
uv_traj = uv_traj,
target_size = target_size,
augment_params = augment_params,
sub_type = sub_type,
)
            except Exception as e:
                import traceback
                print(f"Warning: augmentation failed for episode {episode_id}, frame {frame_id}; falling back to center crop only")
                print(f"Exception: {type(e).__name__}: {e}")
                print(f"Traceback:\n{traceback.format_exc()}")
image_list = center_crop_short_side(image_list[0])[None, ...]
new_intrinsics[0][2] = 0.5 * image_list[0].shape[1] # update the principal point
new_intrinsics[1][2] = 0.5 * image_list[0].shape[0] # update the principal point
if random.random() < self.state_mask_prob:
current_state_mask = np.array([False, False])
current_state[:] = 0.0
        fov = calculate_fov(2 * new_intrinsics[1][2], 2 * new_intrinsics[0][2], new_intrinsics)
if self.use_rel:
action_list = action_rel
else:
# use abs action for hand pose only
rel_L = action_rel[:, :action_rel.shape[1]//2]
rel_R = action_rel[:, action_rel.shape[1]//2:]
abs_L = action_abs[:, :action_abs.shape[1]//2]
abs_R = action_abs[:, action_abs.shape[1]//2:]
action_list = np.concatenate([rel_L[:, :6], abs_L[:, 6:], rel_R[:, :6], abs_R[:, 6:]], axis=1)
# ------------------------------------------------------------------
# 8. Return to caller
# ------------------------------------------------------------------
result_dict = dict(
instruction = instruction,
action_list = action_list, # (W,2*51) float32
action_mask = action_mask, # (W,2) bool
current_state = current_state, # (2*61,) float32
current_state_mask = current_state_mask, # (2,) bool
fov = fov, # (2,) float32
intrinsics = new_intrinsics, # (3,3) float32
)
        if image_list is not None:
            result_dict['image_list'] = image_list  # (L, H, W, 3) uint8
        if image_mask is not None:
            result_dict['image_mask'] = image_mask  # (L,) bool
return result_dict
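    # Minimal usage sketch (the paths and episode id below are hypothetical,
    # for illustration only):
    #
    #   core = EpisodicDatasetCore(
    #       video_root='/data/videos',
    #       annotation_file='/data/annotations.npy',
    #       label_folder='/data/labels',
    #   )
    #   sample = core.get_item_frame('Ego4D_xxx', 0,
    #                                action_future_window_size=15)
    #   # sample keys: instruction, action_list, action_mask, current_state,
    #   # current_state_mask, fov, intrinsics (+ image_list / image_mask)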
def set_global_data_statistics(self, global_data_statistics):
self.global_data_statistics = global_data_statistics
if not hasattr(self, 'gaussian_normalizer'):
self.gaussian_normalizer = GaussianNormalizer(self.global_data_statistics)
def transform_trajectory(
self,
sample_dict: dict = None,
normalization: bool = True,
):
"""Pad action and state dimensions to match a unified size."""
action_np = sample_dict["action_list"]
state_np = sample_dict["current_state"]
action_dim = action_np.shape[1]
state_dim = state_np.shape[0]
if normalization:
# Normalize left and right hand actions and states separately
action_np = self.gaussian_normalizer.normalize_action(action_np)
state_np = self.gaussian_normalizer.normalize_state(state_np)
# ===== Pad to unified dimensions =====
unified_action_dim = ActionFeature.ALL_FEATURES[1] # 192
unified_state_dim = StateFeature.ALL_FEATURES[1] # 212
unified_state, unified_state_mask = pad_state_human(
state_np,
sample_dict["current_state_mask"],
action_dim,
state_dim,
unified_state_dim
)
unified_action, unified_action_mask = pad_action(
action_np,
sample_dict["action_mask"],
action_dim,
unified_action_dim
)
sample_dict["action_list"] = unified_action
sample_dict["action_mask"] = unified_action_mask
sample_dict["current_state"] = unified_state
sample_dict["current_state_mask"] = unified_state_mask
return sample_dict
def __getitem__(self, idx):
if self.training_idx is not None:
data_id = self.training_idx[idx]
else:
data_id = idx
corr = self.index_frame_pair[data_id]
episode_id = self.index_to_episode_id[corr[0]]
sample = self.get_item_frame(
episode_id, int(corr[1]),
action_past_window_size=self.action_past_window_size,
action_future_window_size=self.action_future_window_size,
image_past_window_size=self.image_past_window_size,
image_future_window_size=self.image_future_window_size,
rel_mode=self.rel_mode, # 'step'
load_images=self.load_images
)
return sample
def pad_state_human(
    state,
    state_mask,
    action_dim: int,
    state_dim: int,
    unified_state_dim: int
):
    """
    Expand the per-hand state mask to per-dimension, zero out invalid state
    dims, and pad the state to a unified size (the MANO betas are dropped).
    Args:
        state (array-like): original state vector, shape [state_dim]
        state_mask (array-like): per-hand state mask, shape [2]
        action_dim (int): original action dimension
        state_dim (int): original state dimension
        unified_state_dim (int): target padded state dimension
    Returns:
        Tuple[Tensor, Tensor]:
            padded state [unified_state_dim],
            padded state mask [unified_state_dim]
    """
current_state = torch.tensor(state, dtype=torch.float32)
current_state_mask = torch.tensor(state_mask, dtype=torch.bool)
# Expand state mask from per-hand to per-dim
expanded_state_mask = current_state_mask.repeat_interleave(state_dim // 2)
# Mask out invalid state dimensions
current_state_masked = current_state * expanded_state_mask.to(current_state.dtype)
# Initialize output tensors
padded_state = torch.zeros(unified_state_dim, dtype=current_state.dtype)
padded_mask = torch.zeros(unified_state_dim, dtype=torch.bool)
# Fill first half of state_dim (left hand), skipping MANO betas
padded_state[:action_dim//2] = current_state_masked[:action_dim//2].clone()
padded_mask[:action_dim//2] = expanded_state_mask[:action_dim//2].clone()
# Fill second half of state_dim (right hand), skipping MANO betas
padded_state[action_dim//2:action_dim] = current_state_masked[state_dim//2:state_dim//2+action_dim//2].clone()
padded_mask[action_dim//2:action_dim] = expanded_state_mask[state_dim//2:state_dim//2+action_dim//2].clone()
return padded_state, padded_mask
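# Worked example (angle mode): per-hand state is 61 dims (3 transl + 3 rot +
# 45 pose + 10 betas), so state_dim = 122 and action_dim = 102. The function
# copies left-hand dims state[0:51] into slots [0:51] and right-hand dims
# state[61:112] into slots [51:102] of the 212-dim padded vector, dropping
# both beta blocks.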
def pad_action(
    actions=None,
    action_mask=None,
    action_dim: int = None,
    unified_action_dim: int = None
):
    """
    Expand the per-hand action mask to per-dimension, zero out invalid actions,
    and pad actions to a unified size.
    Args:
        actions (array-like or None): original actions, shape [T, action_dim], or None.
        action_mask (array-like): per-hand action mask, shape [T, 2].
        action_dim (int): original action dimension.
        unified_action_dim (int): target padded action dimension.
    Returns:
        Tuple[Optional[Tensor], Tensor]:
            padded actions [T, unified_action_dim] or None,
            padded action mask [T, unified_action_dim]
    """
action_mask = torch.tensor(action_mask, dtype=torch.bool)
# Expand mask from per-hand to per-dimension
mask_left = action_mask[:, 0].unsqueeze(1).expand(-1, action_dim // 2)
mask_right = action_mask[:, 1].unsqueeze(1).expand(-1, action_dim // 2)
expanded_action_mask = torch.cat((mask_left, mask_right), dim=1)
# ---------------------------
# Case 1: actions is None
# ---------------------------
if actions is None:
padding_mask = torch.zeros(
(action_mask.shape[0], unified_action_dim - action_dim),
dtype=torch.bool
)
action_mask_padded = torch.cat((expanded_action_mask, padding_mask), dim=1)
return None, action_mask_padded
# ---------------------------
# Case 2: actions exists
# ---------------------------
actions = torch.tensor(actions, dtype=torch.float32)
# Mask invalid action dims
actions_masked = actions * expanded_action_mask.to(actions.dtype)
# Pad both actions and mask
padding = torch.zeros(
(actions.shape[0], unified_action_dim - action_dim),
dtype=actions.dtype
)
actions_padded = torch.cat((actions_masked, padding), dim=1)
action_mask_padded = torch.cat((expanded_action_mask, padding.bool()), dim=1)
return actions_padded, action_mask_padded
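if __name__ == "__main__":
    # Smoke test for the padding helpers (illustrative only; dims follow the
    # "angle" layout: 51 dims per hand -> action_dim=102, padded to 192/212).
    acts = np.random.randn(16, 102).astype(np.float32)
    mask = np.ones((16, 2), dtype=bool)
    padded, padded_mask = pad_action(acts, mask, action_dim=102, unified_action_dim=192)
    print(padded.shape, padded_mask.shape)  # torch.Size([16, 192]) twice

    state = np.random.randn(122).astype(np.float32)
    state_mask = np.array([True, False])
    s, m = pad_state_human(state, state_mask, action_dim=102, state_dim=122, unified_state_dim=212)
    print(s.shape, int(m.sum()))  # torch.Size([212]), 51 valid left-hand dims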