|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import os |
|
|
import time |
|
|
from collections import defaultdict |
|
|
from typing import Dict, Optional, Tuple |
|
|
|
|
|
import cv2 |
|
|
import imageio |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import tqdm |
|
|
import tyro |
|
|
import yaml |
|
|
from fused_ssim import fused_ssim |
|
|
from gsplat.distributed import cli |
|
|
from gsplat.rendering import rasterization |
|
|
from gsplat.strategy import DefaultStrategy, MCMCStrategy |
|
|
from torch import Tensor |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from torchmetrics.image import ( |
|
|
PeakSignalNoiseRatio, |
|
|
StructuralSimilarityIndexMeasure, |
|
|
) |
|
|
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity |
|
|
from typing_extensions import Literal, assert_never |
|
|
from embodied_gen.data.datasets import PanoGSplatDataset |
|
|
from embodied_gen.utils.config import GsplatTrainConfig |
|
|
from embodied_gen.utils.gaussian import ( |
|
|
create_splats_with_optimizers, |
|
|
export_splats, |
|
|
resize_pinhole_intrinsics, |
|
|
set_random_seed, |
|
|
) |
|
|
|
|
|
|
|
|
class Runner: |
|
|
"""Engine for training and testing from gsplat example. |
|
|
|
|
|
Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py |
|
|
""" |
|
|
|
|
|
    def __init__(
        self,
        local_rank: int,
        world_rank: int,
        world_size: int,
        cfg: GsplatTrainConfig,
    ) -> None:
        """Set up output directories, datasets, splat parameters and metrics.

        Args:
            local_rank: GPU index on this node; also selects the CUDA device.
            world_rank: Global rank of this process across all nodes.
            world_size: Total number of distributed processes.
            cfg: Training configuration.
        """
        # Offset the seed by rank so each process draws different randomness.
        set_random_seed(42 + local_rank)

        self.cfg = cfg
        self.world_rank = world_rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = f"cuda:{local_rank}"

        os.makedirs(cfg.result_dir, exist_ok=True)

        # Output locations: checkpoints, metric JSON files, rendered
        # images/videos, and exported PLY point clouds.
        self.ckpt_dir = f"{cfg.result_dir}/ckpts"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.stats_dir = f"{cfg.result_dir}/stats"
        os.makedirs(self.stats_dir, exist_ok=True)
        self.render_dir = f"{cfg.result_dir}/renders"
        os.makedirs(self.render_dir, exist_ok=True)
        self.ply_dir = f"{cfg.result_dir}/ply"
        os.makedirs(self.ply_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
        self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
        # NOTE(review): validation reuses a small subset of the *train* split;
        # the held-out split lives in ``self.testset`` (split="eval").
        self.valset = PanoGSplatDataset(
            cfg.data_dir, split="train", max_sample_num=6
        )
        self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
        self.scene_scale = cfg.scene_scale

        # Initialize Gaussian parameters (means/scales/quats/opacities/SH)
        # and one optimizer per parameter group.
        self.splats, self.optimizers = create_splats_with_optimizers(
            self.trainset.points,
            self.trainset.points_rgb,
            init_num_pts=cfg.init_num_pts,
            init_extent=cfg.init_extent,
            init_opacity=cfg.init_opa,
            init_scale=cfg.init_scale,
            means_lr=cfg.means_lr,
            scales_lr=cfg.scales_lr,
            opacities_lr=cfg.opacities_lr,
            quats_lr=cfg.quats_lr,
            sh0_lr=cfg.sh0_lr,
            shN_lr=cfg.shN_lr,
            scene_scale=self.scene_scale,
            sh_degree=cfg.sh_degree,
            sparse_grad=cfg.sparse_grad,
            visible_adam=cfg.visible_adam,
            batch_size=cfg.batch_size,
            feature_dim=None,  # no per-image appearance features
            device=self.device,
            world_rank=world_rank,
            world_size=world_size,
        )
        print("Model initialized. Number of GS:", len(self.splats["means"]))

        self.cfg.strategy.check_sanity(self.splats, self.optimizers)

        # Densification-strategy state; its contents depend on strategy type.
        if isinstance(self.cfg.strategy, DefaultStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state(
                scene_scale=self.scene_scale
            )
        elif isinstance(self.cfg.strategy, MCMCStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state()
        else:
            assert_never(self.cfg.strategy)

        # Image-quality metrics consumed by eval().
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
            self.device
        )
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

        if cfg.lpips_net == "alex":
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="alex", normalize=True
            ).to(self.device)
        elif cfg.lpips_net == "vgg":
            # normalize=False means the metric expects inputs in [-1, 1];
            # NOTE(review): eval() feeds images in [0, 1] — confirm this is
            # the intended LPIPS protocol for the vgg backbone.
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="vgg", normalize=False
            ).to(self.device)
        else:
            raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
|
|
|
|
|
def rasterize_splats( |
|
|
self, |
|
|
camtoworlds: Tensor, |
|
|
Ks: Tensor, |
|
|
width: int, |
|
|
height: int, |
|
|
masks: Optional[Tensor] = None, |
|
|
rasterize_mode: Optional[Literal["classic", "antialiased"]] = None, |
|
|
camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None, |
|
|
**kwargs, |
|
|
) -> Tuple[Tensor, Tensor, Dict]: |
|
|
means = self.splats["means"] |
|
|
|
|
|
|
|
|
quats = self.splats["quats"] |
|
|
scales = torch.exp(self.splats["scales"]) |
|
|
opacities = torch.sigmoid(self.splats["opacities"]) |
|
|
image_ids = kwargs.pop("image_ids", None) |
|
|
|
|
|
colors = torch.cat( |
|
|
[self.splats["sh0"], self.splats["shN"]], 1 |
|
|
) |
|
|
|
|
|
if rasterize_mode is None: |
|
|
rasterize_mode = ( |
|
|
"antialiased" if self.cfg.antialiased else "classic" |
|
|
) |
|
|
if camera_model is None: |
|
|
camera_model = self.cfg.camera_model |
|
|
|
|
|
render_colors, render_alphas, info = rasterization( |
|
|
means=means, |
|
|
quats=quats, |
|
|
scales=scales, |
|
|
opacities=opacities, |
|
|
colors=colors, |
|
|
viewmats=torch.linalg.inv(camtoworlds), |
|
|
Ks=Ks, |
|
|
width=width, |
|
|
height=height, |
|
|
packed=self.cfg.packed, |
|
|
absgrad=( |
|
|
self.cfg.strategy.absgrad |
|
|
if isinstance(self.cfg.strategy, DefaultStrategy) |
|
|
else False |
|
|
), |
|
|
sparse_grad=self.cfg.sparse_grad, |
|
|
rasterize_mode=rasterize_mode, |
|
|
distributed=self.world_size > 1, |
|
|
camera_model=self.cfg.camera_model, |
|
|
with_ut=self.cfg.with_ut, |
|
|
with_eval3d=self.cfg.with_eval3d, |
|
|
**kwargs, |
|
|
) |
|
|
if masks is not None: |
|
|
render_colors[~masks] = 0 |
|
|
return render_colors, render_alphas, info |
|
|
|
|
|
    def train(self):
        """Run the full optimization loop for ``cfg.max_steps`` iterations."""
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        # Dump the resolved config once (rank 0 only) for reproducibility.
        if world_rank == 0:
            with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
                yaml.dump(vars(cfg), f)

        max_steps = cfg.max_steps
        init_step = 0

        # Exponentially decay the means LR to 1% of its initial value by the
        # final step; the other parameter groups keep a constant LR.
        schedulers = [
            torch.optim.lr_scheduler.ExponentialLR(
                self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
            ),
        ]
        trainloader = torch.utils.data.DataLoader(
            self.trainset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=True,
        )
        trainloader_iter = iter(trainloader)

        # Main training loop, driven by step count rather than epochs.
        global_tic = time.time()
        pbar = tqdm.tqdm(range(init_step, max_steps))
        for step in pbar:
            try:
                data = next(trainloader_iter)
            except StopIteration:
                # Epoch exhausted: restart the loader and keep going.
                trainloader_iter = iter(trainloader)
                data = next(trainloader_iter)

            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            pixels = data["image"].to(device) / 255.0  # uint8 -> [0, 1]
            image_ids = data["image_id"].to(device)
            masks = (
                data["mask"].to(device) if "mask" in data else None
            )
            if cfg.depth_loss:
                # Sparse pixel locations and their ground-truth depths.
                points = data["points"].to(device)
                depths_gt = data["depths"].to(device)

            height, width = pixels.shape[1:3]

            # Progressively enable higher SH bands as training proceeds.
            sh_degree_to_use = min(
                step // cfg.sh_degree_interval, cfg.sh_degree
            )

            # Forward pass; "RGB+ED" adds an expected-depth channel when the
            # depth loss is enabled.
            renders, alphas, info = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=sh_degree_to_use,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                image_ids=image_ids,
                render_mode="RGB+ED" if cfg.depth_loss else "RGB",
                masks=masks,
            )
            if renders.shape[-1] == 4:
                # Last channel is the rendered expected depth.
                colors, depths = renders[..., 0:3], renders[..., 3:4]
            else:
                colors, depths = renders, None

            if cfg.random_bkgd:
                # Composite over a random background so transparent regions
                # cannot overfit to a fixed background color.
                bkgd = torch.rand(1, 3, device=device)
                colors = colors + bkgd * (1.0 - alphas)

            # Strategy hook that must run before backward (e.g. to retain
            # gradients on screen-space quantities).
            self.cfg.strategy.step_pre_backward(
                params=self.splats,
                optimizers=self.optimizers,
                state=self.strategy_state,
                step=step,
                info=info,
            )

            # Photometric loss: L1 + (1 - SSIM), blended by ssim_lambda.
            l1loss = F.l1_loss(colors, pixels)
            ssimloss = 1.0 - fused_ssim(
                colors.permute(0, 3, 1, 2),
                pixels.permute(0, 3, 1, 2),
                padding="valid",
            )
            loss = (
                l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
            )
            if cfg.depth_loss:
                # Normalize the query pixel coordinates to [-1, 1] as
                # required by grid_sample.
                points = torch.stack(
                    [
                        points[:, :, 0] / (width - 1) * 2 - 1,
                        points[:, :, 1] / (height - 1) * 2 - 1,
                    ],
                    dim=-1,
                )
                grid = points.unsqueeze(2)
                # Sample the rendered depth map at the sparse locations.
                depths = F.grid_sample(
                    depths.permute(0, 3, 1, 2), grid, align_corners=True
                )
                depths = depths.squeeze(3).squeeze(1)

                # Compare in disparity (inverse-depth) space; zero-depth
                # samples contribute zero disparity instead of inf.
                disp = torch.where(
                    depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
                )
                disp_gt = 1.0 / depths_gt  # assumes depths_gt > 0 — TODO confirm
                depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
                loss += depthloss * cfg.depth_lambda

            # Regularizers (nonzero e.g. in the MCMC preset).
            if cfg.opacity_reg > 0.0:
                loss += (
                    cfg.opacity_reg
                    * torch.sigmoid(self.splats["opacities"]).mean()
                )
            if cfg.scale_reg > 0.0:
                loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()

            loss.backward()

            desc = (
                f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
            )
            if cfg.depth_loss:
                desc += f"depth loss={depthloss.item():.6f}| "
            pbar.set_description(desc)

            # TensorBoard logging (rank 0 only, every tb_every steps).
            if (
                world_rank == 0
                and cfg.tb_every > 0
                and step % cfg.tb_every == 0
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3  # GiB
                self.writer.add_scalar("train/loss", loss.item(), step)
                self.writer.add_scalar("train/l1loss", l1loss.item(), step)
                self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
                self.writer.add_scalar(
                    "train/num_GS", len(self.splats["means"]), step
                )
                self.writer.add_scalar("train/mem", mem, step)
                if cfg.depth_loss:
                    self.writer.add_scalar(
                        "train/depthloss", depthloss.item(), step
                    )
                if cfg.tb_save_image:
                    # Side-by-side GT | render canvas, stacked over batch.
                    canvas = (
                        torch.cat([pixels, colors], dim=2)
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    canvas = canvas.reshape(-1, *canvas.shape[2:])
                    self.writer.add_image("train/render", canvas, step)
                self.writer.flush()

            # Periodic checkpointing: stats JSON + model state dict.
            if (
                step in [i - 1 for i in cfg.save_steps]
                or step == max_steps - 1
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3
                stats = {
                    "mem": mem,
                    "ellipse_time": time.time() - global_tic,
                    "num_GS": len(self.splats["means"]),
                }
                print("Step: ", step, stats)
                with open(
                    f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
                    "w",
                ) as f:
                    json.dump(stats, f)
                data = {"step": step, "splats": self.splats.state_dict()}
                torch.save(
                    data,
                    f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
                )
            # Periodic PLY export of the raw Gaussian parameters.
            if (
                step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
            ) and cfg.save_ply:
                sh0 = self.splats["sh0"]
                shN = self.splats["shN"]
                means = self.splats["means"]
                scales = self.splats["scales"]
                quats = self.splats["quats"]
                opacities = self.splats["opacities"]
                export_splats(
                    means=means,
                    scales=scales,
                    quats=quats,
                    opacities=opacities,
                    sh0=sh0,
                    shN=shN,
                    format="ply",
                    save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
                )

            # Optionally convert dense grads into sparse grads restricted to
            # the Gaussians actually touched this step (packed mode only).
            if cfg.sparse_grad:
                assert (
                    cfg.packed
                ), "Sparse gradients only work with packed mode."
                gaussian_ids = info["gaussian_ids"]
                for k in self.splats.keys():
                    grad = self.splats[k].grad
                    if grad is None or grad.is_sparse:
                        continue
                    self.splats[k].grad = torch.sparse_coo_tensor(
                        indices=gaussian_ids[None],
                        values=grad[gaussian_ids],
                        size=self.splats[k].size(),
                        is_coalesced=len(Ks) == 1,
                    )

            if cfg.visible_adam:
                # Build a per-Gaussian visibility mask so the optimizer only
                # updates Gaussians rendered at this step.
                gaussian_cnt = self.splats.means.shape[0]  # NOTE: unused below
                if cfg.packed:
                    visibility_mask = torch.zeros_like(
                        self.splats["opacities"], dtype=bool
                    )
                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
                else:
                    visibility_mask = (info["radii"] > 0).all(-1).any(0)

            # Optimizer and scheduler updates.
            for optimizer in self.optimizers.values():
                if cfg.visible_adam:
                    optimizer.step(visibility_mask)
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for scheduler in schedulers:
                scheduler.step()

            # Strategy hook after backward: densification / pruning
            # (DefaultStrategy) or relocation (MCMCStrategy).
            if isinstance(self.cfg.strategy, DefaultStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    packed=cfg.packed,
                )
            elif isinstance(self.cfg.strategy, MCMCStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    lr=schedulers[0].get_last_lr()[0],
                )
            else:
                assert_never(self.cfg.strategy)

            # Periodic evaluation plus a preview video.
            if step in [i - 1 for i in cfg.eval_steps]:
                self.eval(step)
                self.render_video(step)
|
|
|
|
|
    @torch.no_grad()
    def eval(
        self,
        step: int,
        stage: str = "val",
        canvas_h: int = 512,
        canvas_w: int = 1024,
    ):
        """Evaluate PSNR/SSIM/LPIPS on the validation set and save renders.

        Args:
            step: Training step used to tag output files and TB scalars.
            stage: Name prefix for output files and TensorBoard tags.
            canvas_h: Height both GT and render are resized to for metrics.
            canvas_w: Full canvas width; each half (GT | render) gets
                ``canvas_w // 2``.
        """
        print("Running evaluation...")
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=1, shuffle=False, num_workers=1
        )
        ellipse_time = 0  # accumulated pure-render wall time
        metrics = defaultdict(list)
        for i, data in enumerate(valloader):
            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            pixels = data["image"].to(device) / 255.0  # uint8 -> [0, 1]
            height, width = pixels.shape[1:3]
            masks = data["mask"].to(device) if "mask" in data else None

            # Resize GT to the comparison resolution (NCHW for interpolate).
            pixels = pixels.permute(0, 3, 1, 2)
            pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))

            # Synchronize around the render so the timing excludes queued
            # asynchronous CUDA work.
            torch.cuda.synchronize()
            tic = time.time()
            colors, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                masks=masks,
            )
            torch.cuda.synchronize()
            ellipse_time += max(time.time() - tic, 1e-10)

            # Resize the render to match the GT comparison resolution.
            colors = colors.permute(0, 3, 1, 2)
            colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
            colors = torch.clamp(colors, 0.0, 1.0)
            canvas_list = [pixels, colors]

            if world_rank == 0:
                # Save a side-by-side GT | render PNG (cv2 expects BGR,
                # hence the channel flip).
                canvas = torch.cat(canvas_list, dim=2).squeeze(0)
                canvas = canvas.permute(1, 2, 0)
                canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
                cv2.imwrite(
                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                    canvas[..., ::-1],
                )
                metrics["psnr"].append(self.psnr(colors, pixels))
                metrics["ssim"].append(self.ssim(colors, pixels))
                metrics["lpips"].append(self.lpips(colors, pixels))

        if world_rank == 0:
            # Average render time per image.
            ellipse_time /= len(valloader)

            stats = {
                k: torch.stack(v).mean().item() for k, v in metrics.items()
            }
            stats.update(
                {
                    "ellipse_time": ellipse_time,
                    "num_GS": len(self.splats["means"]),
                }
            )
            print(
                f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
                f"Time: {stats['ellipse_time']:.3f}s/image "
                f"Number of GS: {stats['num_GS']}"
            )

            # Persist the metrics to disk and TensorBoard.
            with open(
                f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
            ) as f:
                json.dump(stats, f)

            for k, v in stats.items():
                self.writer.add_scalar(f"{stage}/{k}", v, step)
            self.writer.flush()
|
|
|
|
|
@torch.no_grad() |
|
|
def render_video( |
|
|
self, step: int, canvas_h: int = 512, canvas_w: int = 1024 |
|
|
): |
|
|
testloader = torch.utils.data.DataLoader( |
|
|
self.testset, batch_size=1, shuffle=False, num_workers=1 |
|
|
) |
|
|
|
|
|
images_cache = [] |
|
|
depth_global_min, depth_global_max = float("inf"), -float("inf") |
|
|
for data in testloader: |
|
|
camtoworlds = data["camtoworld"].to(self.device) |
|
|
Ks = resize_pinhole_intrinsics( |
|
|
data["K"].squeeze(), |
|
|
raw_hw=(data["image_h"].item(), data["image_w"].item()), |
|
|
new_hw=(canvas_h, canvas_w // 2), |
|
|
).to(self.device) |
|
|
renders, _, _ = self.rasterize_splats( |
|
|
camtoworlds=camtoworlds, |
|
|
Ks=Ks[None, ...], |
|
|
width=canvas_w // 2, |
|
|
height=canvas_h, |
|
|
sh_degree=self.cfg.sh_degree, |
|
|
near_plane=self.cfg.near_plane, |
|
|
far_plane=self.cfg.far_plane, |
|
|
render_mode="RGB+ED", |
|
|
) |
|
|
colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) |
|
|
colors = (colors * 255).to(torch.uint8).cpu().numpy() |
|
|
depths = renders[0, ..., 3:4] |
|
|
images_cache.append([colors, depths]) |
|
|
depth_global_min = min(depth_global_min, depths.min().item()) |
|
|
depth_global_max = max(depth_global_max, depths.max().item()) |
|
|
|
|
|
video_path = f"{self.render_dir}/video_step{step}.mp4" |
|
|
writer = imageio.get_writer(video_path, fps=30) |
|
|
for rgb, depth in images_cache: |
|
|
depth_normalized = torch.clip( |
|
|
(depth - depth_global_min) |
|
|
/ (depth_global_max - depth_global_min + 1e-8), |
|
|
0, |
|
|
1, |
|
|
) |
|
|
depth_normalized = ( |
|
|
(depth_normalized * 255).to(torch.uint8).cpu().numpy() |
|
|
) |
|
|
depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET) |
|
|
image = np.concatenate([rgb, depth_map], axis=1) |
|
|
writer.append_data(image) |
|
|
|
|
|
writer.close() |
|
|
|
|
|
|
|
|
def entrypoint(
    local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig
):
    """Per-process entry: evaluate from checkpoint(s) or train from scratch.

    Args:
        local_rank: GPU index on this node.
        world_rank: Global rank of this process.
        world_size: Total number of distributed processes.
        cfg: Training configuration; ``cfg.ckpt`` selects eval-only mode.
    """
    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is None:
        # Fresh training run, followed by a final preview video.
        runner.train()
        runner.render_video(step=cfg.max_steps - 1)
        return

    # Eval-only path: load one checkpoint file per rank and concatenate
    # the per-rank splat shards back into a single model.
    snapshots = []
    for ckpt_path in cfg.ckpt:
        snapshots.append(
            torch.load(
                ckpt_path, map_location=runner.device, weights_only=True
            )
        )
    for param_name in runner.splats.keys():
        shards = [snap["splats"][param_name] for snap in snapshots]
        runner.splats[param_name].data = torch.cat(shards)
    resume_step = snapshots[0]["step"]
    runner.eval(step=resume_step)
    runner.render_video(step=resume_step)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Named CLI presets selectable via tyro, e.g.:
    #   python simple_trainer.py default --data_dir ...
    #   python simple_trainer.py mcmc --data_dir ...
    # Each entry maps a preset name to (help text, base config).
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            GsplatTrainConfig(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            GsplatTrainConfig(
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    # Rescale step-based schedules (save/eval/ply steps etc.) by the
    # configured steps_scaler before launching.
    cfg.adjust_steps(cfg.steps_scaler)

    # Launch one process per GPU via gsplat's distributed CLI helper.
    cli(entrypoint, cfg, verbose=True)
|
|
|