"""AI-powered video stabilization demo: RAFT (with Farneback fallback) + Gradio UI."""

import cv2
import numpy as np
import torch
import tempfile
import gradio as gr
import io
from contextlib import redirect_stdout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
if device.type == "cuda":
    # cuDNN autotuner helps when input sizes stay fixed across frames
    torch.backends.cudnn.benchmark = True
try:
    print("[INFO] Attempting to load RAFT model from torch.hub...")
    raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
    raft_model = raft_model.to(device)
    raft_model.eval()
    print("[INFO] RAFT model loaded successfully.")
except Exception as e:
    # The hub load can fail (no network, missing hubconf in the pinned revision,
    # etc.); the app then runs entirely on Farneback optical flow.
    print("[ERROR] Error loading RAFT model:", e)
    print("[INFO] Falling back to OpenCV Farneback optical flow.")
    raft_model = None
    # At import time there is no Gradio event for this warning to attach to;
    # the console prints above are the reliable signal.
    gr.Warning("Falling back to OpenCV Farneback optical flow.")
def _resize(frame, w, h):
    if frame.shape[1] == w and frame.shape[0] == h:
        return frame
    return cv2.resize(
        frame,
        (w, h),
        interpolation=cv2.INTER_AREA if (w < frame.shape[1] or h < frame.shape[0]) else cv2.INTER_LINEAR,
    )
def _frame_to_raft_tensor_bgr(frame_bgr):
    # RAFT's forward() normalizes inputs from [0, 255] itself (2 * x / 255 - 1),
    # so the tensor must stay in the 0-255 range rather than being pre-divided.
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    t = torch.from_numpy(frame_rgb).permute(2, 0, 1).contiguous().float().unsqueeze(0)
    return t.to(device, non_blocking=(device.type == "cuda"))
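# Quick sanity check (illustrative only):
#   t = _frame_to_raft_tensor_bgr(np.zeros((360, 640, 3), np.uint8))
#   assert t.shape == (1, 3, 360, 640) and t.dtype == torch.float32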
def compute_offsets(
    video_file,
    out_w,
    out_h,
    motion_scale=0.5,
    raft_iters=12,
    progress=gr.Progress(),
    progress_offset=0.0,
    progress_scale=0.55,
):
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Could not open video file for motion estimation.")
    total = max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
    # Motion is estimated on a downscaled copy for speed. RAFT requires
    # dimensions divisible by 8, so round the working size down accordingly.
    mw = max(64, (int(out_w * float(motion_scale)) // 8) * 8)
    mh = max(64, (int(out_h * float(motion_scale)) // 8) * 8)
    # scale factors to map motion back to the output resolution
    sx = float(out_w) / float(mw)
    sy = float(out_h) / float(mh)
    ret, prev = cap.read()
    if not ret:
        cap.release()
        raise gr.Error("Cannot read first frame from video.")
    prev_out = _resize(prev, out_w, out_h)
    prev_small = _resize(prev_out, mw, mh)
    use_raft = raft_model is not None
    use_amp = device.type == "cuda"
    if use_raft:
        prev_t = _frame_to_raft_tensor_bgr(prev_small)
    else:
        prev_g = cv2.cvtColor(prev_small, cv2.COLOR_BGR2GRAY)
    offsets = [(0.0, 0.0)]  # the first frame needs no correction
    cum_dx = 0.0
    cum_dy = 0.0
    idx = 1
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_out = _resize(frame, out_w, out_h)
        curr_small = _resize(frame_out, mw, mh)
        if use_raft:
            curr_t = _frame_to_raft_tensor_bgr(curr_small)
            with torch.no_grad():
                if use_amp:
                    # torch.autocast supersedes the deprecated torch.cuda.amp.autocast
                    with torch.autocast("cuda"):
                        _, flow_up = raft_model(prev_t, curr_t, iters=int(raft_iters), test_mode=True)
                else:
                    _, flow_up = raft_model(prev_t, curr_t, iters=int(raft_iters), test_mode=True)
            flow = flow_up[0]
            # the median flow is a robust estimate of the dominant (global) translation
            dx = float(flow[0].median().item())
            dy = float(flow[1].median().item())
            prev_t = curr_t
        else:
            curr_g = cv2.cvtColor(curr_small, cv2.COLOR_BGR2GRAY)
            flow = cv2.calcOpticalFlowFarneback(
                prev_g,
                curr_g,
                None,
                pyr_scale=0.5,
                levels=3,
                winsize=15,
                iterations=3,
                poly_n=5,
                poly_sigma=1.2,
                flags=0,
            )
            dx = float(np.median(flow[..., 0]))
            dy = float(np.median(flow[..., 1]))
            prev_g = curr_g
        # map the motion from the working resolution back to the output resolution
        dx *= sx
        dy *= sy
        cum_dx += dx
        cum_dy += dy
        # the correction is the negative of the accumulated motion
        offsets.append((-cum_dx, -cum_dy))
        if total > 0 and (idx % 5 == 0 or idx == total - 1):
            progress(progress_offset + (idx / max(1, total - 1)) * progress_scale, desc="Estimating Motion")
        idx += 1
    cap.release()
    return offsets
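# Example of the offset convention: if the per-frame median flow is a constant
# dx = +2 px (content drifting right), the stored corrections become
# (0, 0), (-2, 0), (-4, 0), ...: each frame is later shifted back by the
# accumulated drift so the content holds still.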
def compute_auto_zoom(offsets, width, height):
    dxs = [o[0] for o in offsets] or [0.0]
    dys = [o[1] for o in offsets] or [0.0]
    left = max(0.0, -min(dxs))
    right = max(0.0, max(dxs))
    top = max(0.0, -min(dys))
    bottom = max(0.0, max(dys))
    safe_w = float(width) - (left + right)
    safe_h = float(height) - (top + bottom)
    zx = (float(width) / safe_w) if safe_w > 1.0 else 1.0
    zy = (float(height) / safe_h) if safe_h > 1.0 else 1.0
    return max(1.0, zx, zy)
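# Worked example (illustrative numbers): corrections spanning [-12, +8] px on a
# 640x360 frame give left = 12, right = 8, safe_w = 620, zx = 640 / 620 ≈ 1.032;
# vertical corrections spanning [-6, +6] give safe_h = 348, zy ≈ 1.034, so the
# auto zoom is max(1.0, 1.032, 1.034) ≈ 1.034.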
def stabilize_stream(
    video_file,
    offsets,
    zoom=1.0,
    vertical_only=False,
    out_w=None,
    out_h=None,
    progress=gr.Progress(),
    progress_offset=0.55,
    progress_scale=0.45,
    output_file=None,
):
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Could not open video file for stabilization.")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # some containers report 0 fps; fall back to a sane default
    in_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    in_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    if out_w is None:
        out_w = in_w
    if out_h is None:
        out_h = in_h
    if output_file is None:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        output_file = temp_file.name
        temp_file.close()
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_file, fourcc, fps, (int(out_w), int(out_h)))
    # zoom about the frame center; the per-frame translation is added on top
    center = (float(out_w) / 2.0, float(out_h) / 2.0)
    base = cv2.getRotationMatrix2D(center, 0.0, float(zoom))
    total = len(offsets)
    i = 0
    while i < total:
        ret, frame = cap.read()
        if not ret:
            break
        frame_out = _resize(frame, int(out_w), int(out_h))
        dx, dy = offsets[i]
        if vertical_only:
            dx = 0.0  # preserve horizontal motion (e.g., intentional pans)
        M = base.copy()
        M[0, 2] += float(dx)
        M[1, 2] += float(dy)
        stabilized = cv2.warpAffine(frame_out, M, (int(out_w), int(out_h)), borderMode=cv2.BORDER_REPLICATE)
        out.write(stabilized)
        if total > 0 and (i % 5 == 0 or i == total - 1):
            progress(progress_offset + (i / max(1, total - 1)) * progress_scale, desc="Stabilizing Video")
        i += 1
    cap.release()
    out.release()
    return output_file
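# Headless usage sketch (illustrative; file names are placeholders, and the
# gr.Progress defaults may need an active Gradio event to report progress):
#   offs = compute_offsets("input.mp4", 1280, 720)
#   z = compute_auto_zoom(offs, 1280, 720)
#   out_path = stabilize_stream("input.mp4", offs, zoom=z, out_w=1280, out_h=720)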
def process_video_ai(
    video_file,
    zoom,
    max_zoom,
    vertical_only,
    compress_mode,
    target_width,
    target_height,
    auto_zoom,
    progress=gr.Progress(track_tqdm=True),
):
    gr.Info("Starting AI-powered video processing...")
    # capture print() output so it can be shown in the Logs textbox
    log_buffer = io.StringIO()
    with redirect_stdout(log_buffer):
        # older Gradio versions pass uploads as {"name": path} dicts
        if isinstance(video_file, dict):
            video_file = video_file.get("name", None)
        if video_file is None:
            raise gr.Error("Please upload a video file.")
        cap = cv2.VideoCapture(video_file)
        if not cap.isOpened():
            raise gr.Error("Could not open uploaded video.")
        in_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        in_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()
        if compress_mode:
            out_w = int(target_width)
            out_h = int(target_height)
        else:
            out_w = in_w
            out_h = in_h
        offsets = compute_offsets(
            video_file,
            out_w,
            out_h,
            motion_scale=0.5,
            raft_iters=12,
            progress=progress,
            progress_offset=0.0,
            progress_scale=0.55,
        )
        gr.Info("Motion estimated successfully.")
        if auto_zoom:
            z = compute_auto_zoom(offsets, out_w, out_h)
            if max_zoom is not None:
                try:
                    mz = float(max_zoom)
                    if mz > 0:
                        z = min(z, mz)
                except Exception:
                    pass
            gr.Info(f"Auto zoom factor computed: {z:.2f}")
            zoom = z
        else:
            if max_zoom is not None:
                try:
                    mz = float(max_zoom)
                    if mz > 0:
                        zoom = min(float(zoom), mz)
                except Exception:
                    zoom = float(zoom)
        stabilized_path = stabilize_stream(
            video_file,
            offsets,
            zoom=float(zoom),
            vertical_only=bool(vertical_only),
            out_w=out_w,
            out_h=out_h,
            progress=progress,
            progress_offset=0.55,
            progress_scale=0.45,
        )
        gr.Info("Video stabilization complete.")
        print("[INFO] Video processing complete.")
    logs = log_buffer.getvalue()
    return video_file, stabilized_path, logs
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Video Stabilization")
    gr.Markdown(
        "Upload a video, select a zoom factor (or use Auto Zoom Mode), optionally cap the maximum zoom, "
        "choose whether to apply only vertical stabilization, and optionally compress the output resolution. "
        "The system estimates motion using RAFT if available (otherwise Farneback) and stabilizes the video "
        "with progress updates."
    )
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Input Video")
            zoom_slider = gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Zoom Factor (ignored if Auto Zoom enabled)")
            max_zoom_slider = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, value=3.0, label="Max Zoom (caps manual + auto zoom)")
            auto_zoom_checkbox = gr.Checkbox(label="Auto Zoom Mode", value=False)
            vertical_checkbox = gr.Checkbox(label="Vertical Stabilization Only", value=False)
            compress_checkbox = gr.Checkbox(label="Compress Output Resolution", value=False)
            target_width = gr.Number(label="Target Width (px)", value=640)
            target_height = gr.Number(label="Target Height (px)", value=360)
            process_button = gr.Button("Process Video")
        with gr.Column():
            original_video = gr.Video(label="Original Video")
            stabilized_video = gr.Video(label="Stabilized Video")
            logs_output = gr.Textbox(label="Logs", lines=10)
    # the inputs order here must match the process_video_ai signature
    process_button.click(
        fn=process_video_ai,
        inputs=[video_input, zoom_slider, max_zoom_slider, vertical_checkbox, compress_checkbox, target_width, target_height, auto_zoom_checkbox],
        outputs=[original_video, stabilized_video, logs_output],
    )

if __name__ == "__main__":
    demo.launch()