import cv2 import numpy as np import torch import tempfile import gradio as gr import time import io from contextlib import redirect_stdout device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[INFO] Using device: {device}") if device.type == "cuda": torch.backends.cudnn.benchmark = True try: print("[INFO] Attempting to load RAFT model from torch.hub...") raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True) raft_model = raft_model.to(device) raft_model.eval() print("[INFO] RAFT model loaded successfully.") except Exception as e: print("[ERROR] Error loading RAFT model:", e) print("[INFO] Falling back to OpenCV Farneback optical flow.") raft_model = None gr.Warning("Falling back to OpenCV Farneback optical flow.") def _resize(frame, w, h): if frame.shape[1] == w and frame.shape[0] == h: return frame return cv2.resize( frame, (w, h), interpolation=cv2.INTER_AREA if (w < frame.shape[1] or h < frame.shape[0]) else cv2.INTER_LINEAR, ) def _frame_to_raft_tensor_bgr(frame_bgr): frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) t = torch.from_numpy(frame_rgb).permute(2, 0, 1).contiguous().float().unsqueeze(0).div_(255.0) return t.to(device, non_blocking=(device.type == "cuda")) def compute_offsets( video_file, out_w, out_h, motion_scale=0.5, raft_iters=12, progress=gr.Progress(), progress_offset=0.0, progress_scale=0.55, ): cap = cv2.VideoCapture(video_file) if not cap.isOpened(): raise gr.Error("Could not open video file for motion estimation.") total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0 mw = max(64, int(out_w * float(motion_scale))) mh = max(64, int(out_h * float(motion_scale))) sx = float(out_w) / float(mw) sy = float(out_h) / float(mh) ret, prev = cap.read() if not ret: cap.release() raise gr.Error("Cannot read first frame from video.") prev_out = _resize(prev, out_w, out_h) prev_small = _resize(prev_out, mw, mh) use_raft = raft_model is not None use_amp = device.type == "cuda" if use_raft: prev_t = _frame_to_raft_tensor_bgr(prev_small) else: prev_g = cv2.cvtColor(prev_small, cv2.COLOR_BGR2GRAY) offsets = [(0.0, 0.0)] cum_dx = 0.0 cum_dy = 0.0 idx = 1 while True: ret, frame = cap.read() if not ret: break frame_out = _resize(frame, out_w, out_h) curr_small = _resize(frame_out, mw, mh) if use_raft: curr_t = _frame_to_raft_tensor_bgr(curr_small) with torch.no_grad(): if use_amp: with torch.cuda.amp.autocast(True): _, flow_up = raft_model(prev_t, curr_t, iters=int(raft_iters), test_mode=True) else: _, flow_up = raft_model(prev_t, curr_t, iters=int(raft_iters), test_mode=True) flow = flow_up[0] dx = float(flow[0].median().item()) dy = float(flow[1].median().item()) prev_t = curr_t else: curr_g = cv2.cvtColor(curr_small, cv2.COLOR_BGR2GRAY) flow = cv2.calcOpticalFlowFarneback( prev_g, curr_g, None, pyr_scale=0.5, levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0, ) dx = float(np.median(flow[..., 0])) dy = float(np.median(flow[..., 1])) prev_g = curr_g dx *= sx dy *= sy cum_dx += dx cum_dy += dy offsets.append((-cum_dx, -cum_dy)) if total > 0 and (idx % 5 == 0 or idx == total - 1): progress(progress_offset + (idx / max(1, total - 1)) * progress_scale, desc="Estimating Motion") idx += 1 cap.release() return offsets def compute_auto_zoom(offsets, width, height): dxs = [o[0] for o in offsets] or [0.0] dys = [o[1] for o in offsets] or [0.0] left = max(0.0, -min(dxs)) right = max(0.0, max(dxs)) top = max(0.0, -min(dys)) bottom = max(0.0, max(dys)) safe_w = float(width) - (left + right) safe_h = float(height) - (top + bottom) zx = (float(width) / safe_w) if safe_w > 1.0 else 1.0 zy = (float(height) / safe_h) if safe_h > 1.0 else 1.0 return max(1.0, zx, zy) def stabilize_stream( video_file, offsets, zoom=1.0, vertical_only=False, out_w=None, out_h=None, progress=gr.Progress(), progress_offset=0.55, progress_scale=0.45, output_file=None, ): cap = cv2.VideoCapture(video_file) if not cap.isOpened(): raise gr.Error("Could not open video file for stabilization.") fps = cap.get(cv2.CAP_PROP_FPS) in_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) in_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) if out_w is None: out_w = in_w if out_h is None: out_h = in_h if output_file is None: temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") output_file = temp_file.name temp_file.close() fourcc = cv2.VideoWriter_fourcc(*"mp4v") out = cv2.VideoWriter(output_file, fourcc, fps, (int(out_w), int(out_h))) center = (float(out_w) / 2.0, float(out_h) / 2.0) base = cv2.getRotationMatrix2D(center, 0.0, float(zoom)) total = len(offsets) i = 0 while i < total: ret, frame = cap.read() if not ret: break frame_out = _resize(frame, int(out_w), int(out_h)) dx, dy = offsets[i] if vertical_only: dx = 0.0 M = base.copy() M[0, 2] += float(dx) M[1, 2] += float(dy) stabilized = cv2.warpAffine(frame_out, M, (int(out_w), int(out_h)), borderMode=cv2.BORDER_REPLICATE) out.write(stabilized) if total > 0 and (i % 5 == 0 or i == total - 1): progress(progress_offset + (i / max(1, total - 1)) * progress_scale, desc="Stabilizing Video") i += 1 cap.release() out.release() return output_file def process_video_ai( video_file, zoom, max_zoom, vertical_only, compress_mode, target_width, target_height, auto_zoom, progress=gr.Progress(track_tqdm=True), ): gr.Info("Starting AI-powered video processing...") log_buffer = io.StringIO() with redirect_stdout(log_buffer): if isinstance(video_file, dict): video_file = video_file.get("name", None) if video_file is None: raise gr.Error("Please upload a video file.") cap = cv2.VideoCapture(video_file) if not cap.isOpened(): raise gr.Error("Could not open uploaded video.") in_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) in_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() if compress_mode: out_w = int(target_width) out_h = int(target_height) else: out_w = in_w out_h = in_h offsets = compute_offsets( video_file, out_w, out_h, motion_scale=0.5, raft_iters=12, progress=progress, progress_offset=0.0, progress_scale=0.55, ) gr.Info("Motion estimated successfully.") if auto_zoom: z = compute_auto_zoom(offsets, out_w, out_h) if max_zoom is not None: try: mz = float(max_zoom) if mz > 0: z = min(z, mz) except Exception: pass gr.Info(f"Auto zoom factor computed: {z:.2f}") zoom = z else: if max_zoom is not None: try: mz = float(max_zoom) if mz > 0: zoom = min(float(zoom), mz) except Exception: zoom = float(zoom) stabilized_path = stabilize_stream( video_file, offsets, zoom=float(zoom), vertical_only=bool(vertical_only), out_w=out_w, out_h=out_h, progress=progress, progress_offset=0.55, progress_scale=0.45, ) gr.Info("Video stabilization complete.") print("[INFO] Video processing complete.") logs = log_buffer.getvalue() return video_file, stabilized_path, logs with gr.Blocks() as demo: gr.Markdown("# AI-Powered Video Stabilization") gr.Markdown( "Upload a video, select a zoom factor (or use Auto Zoom Mode), optionally cap the maximum zoom, choose whether to apply only vertical stabilization, and optionally compress the output resolution. " "The system estimates motion using RAFT if available (otherwise Farneback) and stabilizes the video with progress updates." ) with gr.Row(): with gr.Column(): video_input = gr.Video(label="Input Video") zoom_slider = gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Zoom Factor (ignored if Auto Zoom enabled)") max_zoom_slider = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, value=3.0, label="Max Zoom (caps manual + auto zoom)") auto_zoom_checkbox = gr.Checkbox(label="Auto Zoom Mode", value=False) vertical_checkbox = gr.Checkbox(label="Vertical Stabilization Only", value=False) compress_checkbox = gr.Checkbox(label="Compress Output Resolution", value=False) target_width = gr.Number(label="Target Width (px)", value=640) target_height = gr.Number(label="Target Height (px)", value=360) process_button = gr.Button("Process Video") with gr.Column(): original_video = gr.Video(label="Original Video") stabilized_video = gr.Video(label="Stabilized Video") logs_output = gr.Textbox(label="Logs", lines=10) process_button.click( fn=process_video_ai, inputs=[video_input, zoom_slider, max_zoom_slider, vertical_checkbox, compress_checkbox, target_width, target_height, auto_zoom_checkbox], outputs=[original_video, stabilized_video, logs_output], ) if __name__ == "__main__": demo.launch()