|
|
import argparse |
|
|
import os |
|
|
import yaml |
|
|
import time |
|
|
from pathlib import Path |
|
|
|
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
from cdim.noise import get_noise |
|
|
from cdim.operators import get_operator |
|
|
from cdim.image_utils import save_to_image |
|
|
from cdim.dps_model.dps_unet import create_model |
|
|
from cdim.diffusion.scheduling_ddim import DDIMScheduler |
|
|
from cdim.diffusion.diffusion_pipeline import run_diffusion |
|
|
|
|
|
|
|
|
def load_image(path, size=256):
    """
    Load an image from disk as a normalized RGB tensor in [-1, 1].

    The image is converted to 3-channel RGB, so grayscale, palette and RGBA
    inputs are all accepted. It is then resized to ``size`` x ``size`` with
    bicubic resampling and rescaled from [0, 255] to [-1, 1].

    Args:
        path: Path to the image file.
        size: Side length in pixels of the square output. Defaults to 256,
            the resolution this script has always used.

    Returns:
        A float32 tensor of shape (1, 3, size, size) with values in [-1, 1].
    """
    # A context manager releases the file handle promptly. Without it, a
    # directory run keeps one open handle per processed image.
    with Image.open(path) as img:
        # convert("RGB") drops an alpha channel and expands single-channel or
        # palette images to three channels. np.array() of an "L"/"P" image is
        # 2-D and would make the permute below fail.
        rgb = img.convert("RGB").resize((size, size), Image.BICUBIC)

    pixels = np.array(rgb)  # (size, size, 3) uint8, writable copy
    # HWC -> NCHW with a leading batch dimension of 1.
    tensor = torch.from_numpy(pixels).unsqueeze(0).permute(0, 3, 1, 2)
    return (tensor / 127.5 - 1.0).to(torch.float)
|
|
|
|
|
|
|
|
def load_yaml(file_path: str) -> dict:
    """
    Read a YAML configuration file and return the parsed document.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed document. An empty file yields ``{}`` rather than ``None``,
        so callers can always unpack the result with ``**``.
    """
    with open(file_path, encoding="utf-8") as f:
        # NOTE: FullLoader is kept for compatibility with existing configs.
        # These files are expected to be trusted local configuration. If they
        # can come from an untrusted source, switch to yaml.safe_load.
        config = yaml.load(f, Loader=yaml.FullLoader)
    # yaml.load returns None for an empty document. Unpacking that with **
    # would raise a confusing TypeError at the call site.
    return {} if config is None else config
|
|
|
|
|
|
|
|
def process_image(image_path, output_dir, model, ddim_scheduler, operator, noise_function,
                  device, args, model_type):
    """
    Simulate a noisy measurement of one image and reconstruct it.

    The image is loaded and moved to ``device``. A measurement is formed as
    ``noise_function(operator(image))``, and the image is reconstructed with
    ``run_diffusion``. Two PNGs are written into ``output_dir``:
    ``<stem>_noisy_measurement.png`` and ``<stem>_output.png``.

    Args:
        image_path: Path to the input image.
        output_dir: Directory that receives the output images.
        model: Denoising network forwarded to ``run_diffusion``.
        ddim_scheduler: Scheduler forwarded to ``run_diffusion``.
        operator: Forward operator applied to the clean image.
        noise_function: Callable applied to the operator output.
        device: Device the image is moved to (``torch.device`` or string).
        args: Parsed CLI namespace. This function reads ``stopping_sigma``,
            ``T`` and ``K``.
        model_type: Model family string forwarded to ``run_diffusion``.
    """
    original_image = load_image(image_path).to(device)
    base_name = Path(image_path).stem

    noisy_measurement = noise_function(operator(original_image))
    save_to_image(noisy_measurement, os.path.join(output_dir, f"{base_name}_noisy_measurement.png"))

    # CUDA kernels run asynchronously. Synchronize around the timed region so
    # the reported time covers the actual computation, not just kernel launches.
    is_cuda = torch.device(device).type == "cuda"
    if is_cuda:
        torch.cuda.synchronize(device)
    t0 = time.perf_counter()

    output_image = run_diffusion(
        model, ddim_scheduler,
        noisy_measurement, operator, noise_function, device,
        args.stopping_sigma,
        num_inference_steps=args.T,
        K=args.K,
        model_type=model_type,
        original_image=original_image)

    if is_cuda:
        torch.cuda.synchronize(device)
    print(f"Processing time for {base_name}: {time.perf_counter() - t0:.2f}s")

    save_to_image(output_image, os.path.join(output_dir, f"{base_name}_output.png"))
|
|
|
|
|
|
|
|
def _load_model(model_config, device):
    """Load the denoiser described by ``model_config`` and return ``(model, model_type)``.

    A path ending in ``.yaml`` is read as a DPS UNet config and built with
    ``create_model``. Any other value is passed to
    ``DiffusionPipeline.from_pretrained`` and that pipeline's UNet is used.
    """
    if model_config.endswith(".yaml"):
        model = create_model(**load_yaml(model_config)).to(device)
        model.eval()
        return model, "dps"
    model = DiffusionPipeline.from_pretrained(model_config).to(device).unet
    return model, "diffusers"


def _build_scheduler():
    """Return the DDIM scheduler (1000 linear-beta train steps, epsilon prediction) used for every image."""
    return DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        prediction_type="epsilon",
        timestep_spacing="leading",
        steps_offset=0,
    )


def _list_images(input_dir):
    """Return the non-hidden ``.png``/``.jpg``/``.jpeg`` regular files directly inside ``input_dir``, sorted."""
    image_suffixes = {".png", ".jpg", ".jpeg"}
    return sorted(
        entry for entry in input_dir.iterdir()
        # is_file() excludes directories that happen to be named "*.png".
        if entry.is_file()
        and not entry.name.startswith(".")
        and entry.suffix.lower() in image_suffixes
    )


def main(args):
    """
    Reconstruct either a single image or every image in a directory.

    Args:
        args: Parsed CLI namespace. Reads ``input``, ``cuda``, ``output_dir``,
            ``noise_config``, ``operator_config`` and ``model_config``, plus the
            fields ``process_image`` reads (``stopping_sigma``, ``T``, ``K``).

    Raises:
        ValueError: If ``args.input`` is neither a file nor a directory. This
            is checked before any config or model is loaded, so a bad path
            fails fast.
    """
    input_path = Path(args.input)
    if not (input_path.is_file() or input_path.is_dir()):
        raise ValueError(f"Input path '{input_path}' is neither a file nor a directory")

    device_str = "cuda" if args.cuda and torch.cuda.is_available() else 'cpu'
    print(f"Using device {device_str}")
    device = torch.device(device_str)

    os.makedirs(args.output_dir, exist_ok=True)

    noise_function = get_noise(**load_yaml(args.noise_config))

    operator_config = load_yaml(args.operator_config)
    operator_config["device"] = device
    # Created before the model is loaded, in the original order, in case
    # operator construction or model init consume the global RNG.
    operator = get_operator(**operator_config)

    model, model_type = _load_model(args.model_config, device)
    ddim_scheduler = _build_scheduler()

    if input_path.is_file():
        print(f"Processing single image: {input_path.name}")
        process_image(
            str(input_path), args.output_dir, model, ddim_scheduler,
            operator, noise_function, device, args, model_type
        )
        return

    image_files = _list_images(input_path)
    print(f"Found {len(image_files)} images to process")
    for image_file in image_files:
        print(f"Processing {image_file.name}...")
        # A fresh operator is built per image. Presumably this redraws any
        # randomized operator state (e.g. a sampled mask) -- TODO confirm.
        operator = get_operator(**operator_config)
        process_image(
            str(image_file), args.output_dir, model, ddim_scheduler,
            operator, noise_function, device, args, model_type
        )
|
|
|
|
|
|
|
|
def build_parser():
    """
    Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace is what ``main``
        consumes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Path to input image or folder containing input images")
    parser.add_argument("T", type=int, help="Number of DDIM inference steps")
    parser.add_argument("operator_config", type=str, help="YAML config for the forward operator")
    parser.add_argument("noise_config", type=str, help="YAML config for the measurement noise")
    parser.add_argument("model_config", type=str,
                        help="Model YAML config (DPS UNet) or a diffusers pretrained model id/path")
    parser.add_argument("--stopping-sigma", type=float, default=0.1, help="How many std deviations away to stop")
    # NOTE(review): --lambda-val is parsed but never forwarded to
    # run_diffusion anywhere in this file. Confirm whether it should be.
    parser.add_argument("--lambda-val", type=float,
                        default=None, help="Constant to scale learning rate. Leave empty to use a heuristic best guess.")
    parser.add_argument("--output-dir", default="output", type=str)
    parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument("--K", type=int, default=20,
                        help="Cap the number of steps K at any iteration. Helps avoid edge cases or cap NFEs.")
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())