"""Gradio demo that degrades a sample image and restores it with a diffusion model.

A sample image is passed through a selectable measurement operator, noise is
added using the settings in ``noise_configs/gaussian_noise_config.yaml``, and
``run_diffusion`` is then asked to reconstruct the image from that measurement.
"""

import gradio as gr
import spaces
import torch
import yaml
import numpy as np
from PIL import Image
from diffusers import DiffusionPipeline

from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.diffusion.diffusion_pipeline import run_diffusion

# Global variables for model and scheduler (initialized inside GPU-decorated function)
model = None
ddim_scheduler = None
model_type = None
curr_model_name = None

# Sample images offered in the UI, keyed by their dropdown label.
IMAGE_PATHS = {
    "CelebA HQ 1": "sample_images/celebhq_29999.jpg",
    "CelebA HQ 2": "sample_images/celebhq_00001.jpg",
    "CelebA HQ 3": "sample_images/celebhq_00000.jpg",
    "LSUN Church": "sample_images/lsun_church.png",
}

# Operator configuration files, keyed by the task label shown in the UI.
CONFIG_PATHS = {
    "Box Inpainting": "operator_configs/box_inpainting_config.yaml",
    "Random Inpainting": "operator_configs/random_inpainting_config.yaml",
    "Super Resolution": "operator_configs/super_resolution_config.yaml",
    "Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml",
}


def load_image(image_path: str) -> torch.Tensor:
    """Load an image file as a float tensor scaled to [-1, 1].

    Args:
        image_path: Path of the image file to read.

    Returns:
        A tensor of shape (1, 3, 256, 256) and dtype ``torch.float``.
    """
    # The context manager closes the file handle deterministically.
    with Image.open(image_path) as image:
        # Force three channels: grayscale or palette images would otherwise
        # yield a 2-D array and break the 4-D permute below.
        resized = image.convert("RGB").resize((256, 256), Image.BICUBIC)
    pixels = torch.from_numpy(np.array(resized))
    # (H, W, C) -> (1, C, H, W)
    pixels = pixels.unsqueeze(0).permute(0, 3, 1, 2)
    # Map uint8 values 0..255 onto -1..1.
    return pixels.to(torch.float) / 127.5 - 1.0


def load_yaml(file_path: str) -> dict:
    """Load configurations from a YAML file.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The parsed document.
    """
    # NOTE(review): FullLoader is more permissive than yaml.safe_load. It is
    # kept so existing config files keep parsing the same way; only point this
    # at trusted, repository-local files.
    with open(file_path, encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def convert_to_np(torch_image: torch.Tensor) -> np.ndarray:
    """Convert a channel-first image tensor in [-1, 1] to a uint8 array.

    Values are clamped to [-1, 1], moved to the CPU, transposed to
    channel-last order and rescaled to 0..255 (truncating towards zero).

    Args:
        torch_image: A 3-D tensor; its first axis is moved last.

    Returns:
        A ``numpy.uint8`` array.
    """
    clamped = torch_image.detach().clamp(-1, 1).cpu().numpy()
    channel_last = clamped.transpose(1, 2, 0)
    return ((channel_last + 1) * 127.5).astype(np.uint8)


@spaces.GPU
def process_image(image_choice, noise_sigma, operator_key, T, stopping_sigma):
    """Combined function to handle both generation and restoration.

    Args:
        image_choice: Key of ``IMAGE_PATHS`` selecting the sample image.
        noise_sigma: Value stored under ``"sigma"`` in the noise configuration.
        operator_key: Key of ``CONFIG_PATHS`` selecting the operator config.
        T: Number of inference steps handed to ``run_diffusion``.
        stopping_sigma: Passed through to ``run_diffusion`` unchanged.

    Returns:
        A ``(noisy_image, output_image)`` pair of PIL images: the degraded
        measurement and the result returned by ``run_diffusion``.

    Raises:
        KeyError: If ``image_choice`` or ``operator_key`` is not a known key.
    """
    global model, curr_model_name, ddim_scheduler, model_type

    # Resolve both selections before any expensive work so that an unknown
    # key fails immediately instead of after a model has been loaded.
    image_path = IMAGE_PATHS[image_choice]
    operator_config_path = CONFIG_PATHS[operator_key]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model inside GPU-decorated function; the loaded UNet is
    # cached in module globals and replaced only when the checkpoint changes.
    model_name = (
        "google/ddpm-celebahq-256"
        if "CelebA" in image_choice
        else "google/ddpm-church-256"
    )
    if model is None or curr_model_name != model_name:
        model_type = "diffusers"
        model = DiffusionPipeline.from_pretrained(model_name).to(device).unet
        curr_model_name = model_name
        ddim_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear",
        )

    # Generate noisy image
    original_image = load_image(image_path).to(device)

    noise_config = load_yaml("noise_configs/gaussian_noise_config.yaml")
    noise_config["sigma"] = noise_sigma  # the UI value overrides the file
    noise_function = get_noise(**noise_config)

    operator_config = load_yaml(operator_config_path)
    operator_config["device"] = device
    operator = get_operator(**operator_config)

    noisy_measurement = noise_function(operator(original_image))
    noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))

    # Run restoration
    output_image = run_diffusion(
        model,
        ddim_scheduler,
        noisy_measurement,
        operator,
        noise_function,
        device,
        stopping_sigma,
        # The slider may deliver a float; presumably an integer step count is
        # expected here (TODO confirm against run_diffusion).
        num_inference_steps=int(T),
        model_type=model_type,
    )
    output_image = Image.fromarray(convert_to_np(output_image[0]))

    return noisy_image, output_image


# Gradio interface
# NOTE(review): the original indentation of this section was lost, so which
# components sit inside gr.Row() is a reconstruction -- confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Noisy Image Restoration with Diffusion Models")

    with gr.Row():
        T = gr.Slider(4, 200, value=25, step=1, label="Number of Inference Steps (T)")
        stopping_sigma = gr.Slider(0.1, 5.0, value=0.1, step=0.1, label="Stopping Sigma (c)")
        noise_sigma = gr.Slider(0, 0.6, value=0.05, step=0.01, label="Measurement Noise Sigma (σ_y)")
        # Choices come from the lookup tables so labels and keys cannot drift.
        image_select = gr.Dropdown(
            choices=list(IMAGE_PATHS),
            value="CelebA HQ 1",
            label="Select Input Image",
        )
        operator_select = gr.Dropdown(
            choices=list(CONFIG_PATHS),
            value="Random Inpainting",
            label="Select Task",
        )

    run_button = gr.Button("Run Inference")
    noisy_image = gr.Image(label="Noisy Image")
    restored_image = gr.Image(label="Restored Image")

    run_button.click(
        fn=process_image,
        inputs=[image_select, noise_sigma, operator_select, T, stopping_sigma],
        outputs=[noisy_image, restored_image],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)