import argparse
import math
import os

import numpy as np
import torch
import torchdiffeq
import yaml
from torchvision.transforms.functional import crop

import utils
from models.unet import DiffusionUNet
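
# This script wraps a pretrained conditional VP (variance-preserving) diffusion
# UNet as a continuous-time velocity field and samples it by integrating the
# probability-flow ODE with torchdiffeq. Full-resolution images are restored
# patch-by-patch, with overlapping patches blended by a 2D Hann window.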


def dict2namespace(config):
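    """Recursively convert a nested config dict into an argparse.Namespace."""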
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
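    """Return the per-step beta schedule as a float64 numpy array.

    Supported schedules: "quad", "linear", "const", "jsd", and "sigmoid".
    """
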
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)

    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start**0.5,
                beta_end**0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":
        # 1/T, 1/(T-1), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)
    return betas


class VPDiffusionFlow:
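    """Wraps a conditional DiffusionUNet as a continuous-time velocity field.

    In "vp" mode the network's noise prediction is converted into the drift of
    the VP probability-flow ODE; in "reflow" mode the network output is used
    directly as the velocity.
    """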

    def __init__(self, args, config):
        self.args = args
        self.flow_mode = getattr(args, "flow_mode", "vp")
        self.config = config
        self.device = config.device

        self.model = DiffusionUNet(config)
        self.model.to(self.device)

        self.num_timesteps = config.diffusion.num_diffusion_timesteps
        betas = get_beta_schedule(
            beta_schedule=config.diffusion.beta_schedule,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=self.num_timesteps,
        )
        self.betas = torch.from_numpy(betas).float().to(self.device)

        # Schedule endpoints, used by the continuous-time helpers below.
        self.beta_start = config.diffusion.beta_start
        self.beta_end = config.diffusion.beta_end

        alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)

    def load_ckpt(self, load_path):
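        """Load model weights, tolerating DataParallel-style checkpoints."""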
        checkpoint = torch.load(load_path, map_location=self.device)

        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        # Strip the "module." prefix added by nn.DataParallel, if present.
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                new_state_dict[k[7:]] = v
            else:
                new_state_dict[k] = v
        state_dict = new_state_dict

        self.model.load_state_dict(state_dict, strict=True)
        print(f"=> loaded checkpoint '{load_path}'")
        self.model.eval()

    def get_beta_t(self, t):
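        """Linearly interpolate the discrete beta schedule at continuous t in [0, 1]."""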
        scalar_t = t.item() if isinstance(t, torch.Tensor) else t
        # Guard against solver evaluations slightly outside [0, 1].
        scalar_t = max(0.0, min(1.0, scalar_t))
        return self.beta_start + scalar_t * (self.beta_end - self.beta_start)

    def get_alpha_bar_t(self, t):
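        """Closed-form alpha_bar(t) = exp(-N * integral_0^t beta(s) ds) for the linear schedule."""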
        scalar_t = t.item() if isinstance(t, torch.Tensor) else t
        scalar_t = max(0.0, min(1.0, scalar_t))

        N = self.num_timesteps
        b0 = self.beta_start
        b1 = self.beta_end
        # Exact integral of the linear schedule beta(s) = b0 + s * (b1 - b0).
        integral = N * (b0 * scalar_t + 0.5 * (b1 - b0) * scalar_t**2)
        return math.exp(-integral)

    def overlapping_grid_indices(self, x_cond, output_size, r=None):
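        """Top-left corner indices of overlapping patches taken at stride r."""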
        _, c, h, w = x_cond.shape
        r = 16 if r is None else r
        h_list = [i for i in range(0, h - output_size + 1, r)]
        w_list = [i for i in range(0, w - output_size + 1, r)]
        return h_list, w_list

    def get_blending_window(self, patch_size):
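        """Separable 2D Hann window of shape (1, 1, patch_size, patch_size) for feathered blending."""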
        w = torch.hann_window(patch_size, periodic=False, device=self.device)
        w2d = w.unsqueeze(0) * w.unsqueeze(1)
        return w2d.view(1, 1, patch_size, patch_size)

    def get_velocity(self, x, t, x_cond, patch_size=None, r_stride=16):
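        """Velocity field v(x, t) for the ODE solver.

        Images whose spatial size equals patch_size (or when patch_size is
        None) are evaluated in one pass; larger images are processed as
        overlapping patches whose outputs are Hann-window-blended before
        being converted to a velocity.
        """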
        if patch_size is None or (
            x.shape[2] == patch_size and x.shape[3] == patch_size
        ):
            return self._get_velocity_single(x, t, x_cond)

        # Map continuous t in [0, 1] to a discrete timestep index for the UNet.
        N = self.num_timesteps
        t_idx = min(int(t * N), N - 1)
        t_input_scalar = t_idx

        # Reflect-pad so border pixels are covered by full patches.
        pad_size = patch_size // 2
        x_padded = torch.nn.functional.pad(
            x, (pad_size, pad_size, pad_size, pad_size), mode="reflect"
        )
        x_cond_padded = torch.nn.functional.pad(
            x_cond, (pad_size, pad_size, pad_size, pad_size), mode="reflect"
        )

        h_list, w_list = self.overlapping_grid_indices(x_padded, patch_size, r=r_stride)
        corners = [(i, j) for i in h_list for j in w_list]

        window = self.get_blending_window(patch_size)

        # Accumulated window weight per pixel, used to normalize the blend.
        x_grid_mask = torch.zeros_like(x_padded)
        for hi, wi in corners:
            x_grid_mask[:, :, hi : hi + patch_size, wi : wi + patch_size] += window

        output_accum = torch.zeros_like(x_padded)

        # Number of patches fed through the UNet per forward pass.
        batch_size = 64

        if self.flow_mode == "vp":
            # Continuous-time beta(t) and alpha_bar(t) for the VP drift below.
            beta_discrete = self.get_beta_t(t)
            beta_cont = beta_discrete * N
            ab = self.alphas_cumprod[t_idx]

        for i in range(0, len(corners), batch_size):
            batch_corners = corners[i : i + batch_size]

            x_batch = torch.cat(
                [
                    crop(x_padded, hi, wi, patch_size, patch_size)
                    for (hi, wi) in batch_corners
                ],
                dim=0,
            )
            cond_batch = torch.cat(
                [
                    crop(x_cond_padded, hi, wi, patch_size, patch_size)
                    for (hi, wi) in batch_corners
                ],
                dim=0,
            )
            t_batch = torch.tensor(
                [t_input_scalar] * x_batch.size(0), device=self.device
            )

            with torch.no_grad():
                model_output = self.model(
                    torch.cat([cond_batch, x_batch], dim=1), t_batch
                )

            # Feather each patch before accumulation (assumes x has batch size 1).
            weighted_output = model_output * window

            for idx, (hi, wi) in enumerate(batch_corners):
                output_accum[0, :, hi : hi + patch_size, wi : wi + patch_size] += (
                    weighted_output[idx]
                )

        # Normalize by the accumulated window weights (eps guards unvisited pixels).
        model_output_full = torch.div(output_accum, x_grid_mask + 1e-8)

        # Remove the reflection padding.
        if pad_size > 0:
            model_output_full = model_output_full[
                :, :, pad_size:-pad_size, pad_size:-pad_size
            ]

        if self.flow_mode == "reflow":
            # Reflow models predict the velocity directly.
            v = model_output_full
        else:
            # VP probability-flow ODE drift:
            #   v = -0.5 * beta(t) * x + 0.5 * beta(t) / sqrt(1 - alpha_bar) * eps
            epsilon = model_output_full
            coeff1 = -0.5 * beta_cont
            coeff2 = 0.5 * beta_cont / torch.sqrt(1 - ab)
            v = coeff1 * x + coeff2 * epsilon

        return v

    def _get_velocity_single(self, x, t, x_cond):
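        """Single-pass velocity for images that already match the model's input size.

        In "vp" mode the UNet's noise prediction eps is mapped to the
        probability-flow drift

            v(x, t) = -0.5 * beta(t) * x + 0.5 * beta(t) / sqrt(1 - alpha_bar_t) * eps,

        where beta(t) is the interpolated per-step schedule scaled by N. In
        "reflow" mode the UNet output is returned unchanged as the velocity.
        """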
        N = self.num_timesteps
        t_idx = min(int(t * N), N - 1)
        t_input = torch.tensor([t_idx] * x.size(0), device=self.device)

        with torch.no_grad():
            model_output = self.model(torch.cat([x_cond, x], dim=1), t_input)

        if self.flow_mode == "reflow":
            return model_output
        else:
            epsilon = model_output
            beta_discrete = self.get_beta_t(t)
            beta_cont = beta_discrete * N
            ab = self.alphas_cumprod[t_idx]

            coeff1 = -0.5 * beta_cont
            coeff2 = 0.5 * beta_cont / torch.sqrt(1 - ab)

            v = coeff1 * x + coeff2 * epsilon
            return v


def ode_solve(
    flow_model,
    x_init,
    x_cond,
    steps=100,
    method="dopri5",
    patch_size=64,
    atol=1e-4,
    rtol=1e-4,
):
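    """Integrate dx/dt = v(x, t) from t = 1 (noise) down to t = 0 (sample).

    `steps` only sets the evaluation grid; adaptive solvers such as dopri5
    choose their own internal step sizes subject to atol/rtol.
    """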
    step = 0

    print(f"ODE Solve: Method={method}, Steps={steps}, atol={atol}, rtol={rtol}")

    def drift_func(t, x):
        nonlocal step
        step += 1
        print(f"Step {step}, t={t.item():.6f}")
        return flow_model.get_velocity(x, t, x_cond, patch_size=patch_size)

    # Decreasing time grid: torchdiffeq integrates backward from 1 to 0.
    t_eval = torch.linspace(1.0, 0.0, steps + 1, device=x_init.device)
    out = torchdiffeq.odeint(
        drift_func, x_init, t_eval, method=method, atol=atol, rtol=rtol
    )
    return out[-1]
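

# Example invocation (hypothetical script and checkpoint names; the flags are
# those defined in main() below):
#   python sample_flow.py --config raindrop.yml --resume ckpts/diffusion.pth \
#       --flow_mode vp --method dopri5 --steps 100 --patch_size 64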
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, required=True)
    parser.add_argument(
        "--data_dir", type=str, default=None, help="Override data_dir in config"
    )
    parser.add_argument(
        "--dataset", type=str, default=None, help="Override dataset name"
    )
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--output", type=str, default="results/diff2flow")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument(
        "--patch_size", type=int, default=64, help="Patch size for model"
    )
    parser.add_argument(
        "--method", type=str, default="dopri5", help="ODE solver method"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-4, help="Absolute tolerance for ODE solver"
    )
    parser.add_argument(
        "--rtol", type=float, default=1e-4, help="Relative tolerance for ODE solver"
    )
    parser.add_argument(
        "--flow_mode",
        type=str,
        default="vp",
        choices=["vp", "reflow"],
        help="Flow mode: vp (default) or reflow",
    )
    args = parser.parse_args()

    with open(os.path.join("configs", args.config), "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)

    if args.data_dir:
        config.data.data_dir = args.data_dir
    if args.dataset:
        config.data.dataset = args.dataset

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    config.device = device

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print("Initializing VPDiffusionFlow...")
    flow = VPDiffusionFlow(args, config)
    flow.load_ckpt(args.resume)

    os.makedirs(args.output, exist_ok=True)

    import datasets

    print(f"Loading dataset {config.data.dataset}...")
    DATASET = datasets.__dict__[config.data.dataset](config)

    _, val_loader = DATASET.get_loaders(
        parse_patches=False,
        validation=config.data.dataset if args.dataset else "raindrop",
    )

    for i, (x_batch, img_id) in enumerate(val_loader):
        print(f"Processing image {img_id}...")

        x_batch = x_batch.to(device)

        # The first three channels are the degraded (conditioning) image.
        x_cond = x_batch[:, :3, :, :]
        x_cond = utils.sampling.data_transform(x_cond)

        # Start from pure Gaussian noise at t = 1.
        B, C, H, W = x_cond.shape
        x_init = torch.randn(B, 3, H, W, device=device)

        print(f"Starting flow matching inference for image {img_id}, shape {H}x{W}...")
        x_pred = ode_solve(
            flow,
            x_init,
            x_cond,
            steps=args.steps,
            patch_size=args.patch_size,
            method=args.method,
            atol=args.atol,
            rtol=args.rtol,
        )

        x_pred = utils.sampling.inverse_data_transform(x_pred)
        x_cond_img = utils.sampling.inverse_data_transform(x_cond)

        if isinstance(img_id, (tuple, list)):
            idx = img_id[0]
        else:
            idx = img_id

        utils.logging.save_image(
            x_cond_img[0], os.path.join(args.output, f"{idx}_input.png")
        )
        utils.logging.save_image(
            x_pred[0], os.path.join(args.output, f"{idx}_flow.png")
        )

    print("Done.")


if __name__ == "__main__":
    main()