import argparse import os import yaml import time from pathlib import Path from PIL import Image import numpy as np import torch from diffusers import DiffusionPipeline from cdim.noise import get_noise from cdim.operators import get_operator from cdim.image_utils import save_to_image from cdim.dps_model.dps_unet import create_model from cdim.diffusion.scheduling_ddim import DDIMScheduler from cdim.diffusion.diffusion_pipeline import run_diffusion def load_image(path): """ Load the image and normalize to [-1, 1] """ original_image = Image.open(path) # Resize if needed original_image = np.array(original_image.resize((256, 256), Image.BICUBIC)) original_image = torch.from_numpy(original_image).unsqueeze(0).permute(0, 3, 1, 2) return (original_image / 127.5 - 1.0).to(torch.float)[:, :3] def load_yaml(file_path: str) -> dict: with open(file_path) as f: config = yaml.load(f, Loader=yaml.FullLoader) return config def process_image(image_path, output_dir, model, ddim_scheduler, operator, noise_function, device, args, model_type): """ Process a single image with the given model and parameters """ original_image = load_image(image_path).to(device) # Get the base filename without extension base_name = Path(image_path).stem noisy_measurement = noise_function(operator(original_image)) save_to_image(noisy_measurement, os.path.join(output_dir, f"{base_name}_noisy_measurement.png")) t0 = time.time() output_image = run_diffusion( model, ddim_scheduler, noisy_measurement, operator, noise_function, device, args.stopping_sigma, num_inference_steps=args.T, K=args.K, model_type=model_type, original_image=original_image) print(f"Processing time for {base_name}: {time.time() - t0:.2f}s") save_to_image(output_image, os.path.join(output_dir, f"{base_name}_output.png")) def main(args): device_str = f"cuda" if args.cuda and torch.cuda.is_available() else 'cpu' print(f"Using device {device_str}") device = torch.device(device_str) os.makedirs(args.output_dir, exist_ok=True) # Load the noise function noise_config = load_yaml(args.noise_config) noise_function = get_noise(**noise_config) # Load the measurement function A operator_config = load_yaml(args.operator_config) operator_config["device"] = device operator = get_operator(**operator_config) if args.model_config.endswith(".yaml"): # Local model from DPS model_type = "dps" model_config = load_yaml(args.model_config) model = create_model(**model_config) model = model.to(device) model.eval() else: # Huggingface diffusers model model_type = "diffusers" model = DiffusionPipeline.from_pretrained(args.model_config).to(device).unet # All the models have the same scheduler. # you can change this for different models ddim_scheduler = DDIMScheduler( num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule="linear", prediction_type="epsilon", timestep_spacing="leading", steps_offset=0, ) # Process input (either a single image or all images in a directory) input_path = Path(args.input) if input_path.is_file(): # Process a single image print(f"Processing single image: {input_path.name}") process_image( str(input_path), args.output_dir, model, ddim_scheduler, operator, noise_function, device, args, model_type ) elif input_path.is_dir(): # Process all images in the directory image_files = [ f for f in input_path.iterdir() if not f.name.startswith('.') and f.suffix.lower() in ['.png', '.jpg', '.jpeg'] ] image_files = sorted(image_files) print(f"Found {len(image_files)} images to process") for image_file in image_files: print(f"Processing {image_file.name}...") # Optional, recreate operator (uncomment to use same operator) operator = get_operator(**operator_config) process_image( str(image_file), args.output_dir, model, ddim_scheduler, operator, noise_function, device, args, model_type ) else: raise ValueError(f"Input path '{input_path}' is neither a file nor a directory") if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("input", type=str, help="Path to input image or folder containing input images") parser.add_argument("T", type=int) parser.add_argument("operator_config", type=str) parser.add_argument("noise_config", type=str) parser.add_argument("model_config", type=str) parser.add_argument("--stopping-sigma", type=float, default=0.1, help="How many std deviations away to stop") parser.add_argument("--lambda-val", type=float, default=None, help="Constant to scale learning rate. Leave empty to use a heuristic best guess.") parser.add_argument("--output-dir", default="output", type=str) parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction) parser.add_argument("--K", type=int, default=20, help="Cap the number of steps K at any iteration. Helps avoid edge cases or cap NFEs.") main(parser.parse_args())