import cv2
import numpy as np
import torch
import tempfile
import gradio as gr
import time
import io
from contextlib import redirect_stdout
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
try:
    print("[INFO] Attempting to load RAFT model from torch.hub...")
    raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
    raft_model = raft_model.to(device)
    raft_model.eval()
    print("[INFO] RAFT model loaded successfully.")
except Exception as e:
    print("[ERROR] Error loading RAFT model:", e)
    print("[INFO] Falling back to OpenCV Farneback optical flow.")
    raft_model = None
    gr.Warning("Falling back to OpenCV Farneback optical flow.")
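# raft_model doubles as a feature flag: None routes compute_offsets to the
# OpenCV Farneback fallback below.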
def _resize(frame, w, h):
    if frame.shape[1] == w and frame.shape[0] == h:
        return frame
    return cv2.resize(
        frame,
        (w, h),
        # INTER_AREA for downscaling, INTER_LINEAR for upscaling.
        interpolation=cv2.INTER_AREA if (w < frame.shape[1] or h < frame.shape[0]) else cv2.INTER_LINEAR,
    )
def _frame_to_raft_tensor_bgr(frame_bgr):
    # BGR -> RGB, HWC -> 1x3xHxW float tensor in [0, 1].
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    t = torch.from_numpy(frame_rgb).permute(2, 0, 1).contiguous().float().unsqueeze(0).div_(255.0)
    return t.to(device, non_blocking=(device.type == "cuda"))
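# Note (assumption): RAFT variants differ on the expected input range; the original
# princeton-vl code normalizes from [0, 255] internally, so if estimated flows look
# near-zero with this [0, 1] scaling, try feeding 0-255 floats instead.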
def compute_offsets(
    video_file,
    out_w,
    out_h,
    motion_scale=0.5,
    raft_iters=12,
    progress=gr.Progress(),
    progress_offset=0.0,
    progress_scale=0.55,
):
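    """Estimate per-frame global translation offsets for stabilization.

    Flow is computed on a copy downscaled by `motion_scale` (floored at 64 px per
    side); the per-frame median flow is rescaled to output resolution, accumulated,
    and negated, so applying offsets[i] to frame i cancels drift relative to frame 0.
    """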
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Could not open video file for motion estimation.")
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    # Estimate motion on a downscaled copy, then rescale the offsets back up.
    mw = max(64, int(out_w * float(motion_scale)))
    mh = max(64, int(out_h * float(motion_scale)))
    sx = float(out_w) / float(mw)
    sy = float(out_h) / float(mh)
    ret, prev = cap.read()
    if not ret:
        cap.release()
        raise gr.Error("Cannot read first frame from video.")
    prev_out = _resize(prev, out_w, out_h)
    prev_small = _resize(prev_out, mw, mh)
    use_raft = raft_model is not None
    use_amp = device.type == "cuda"
    if use_raft:
        prev_t = _frame_to_raft_tensor_bgr(prev_small)
    else:
        prev_g = cv2.cvtColor(prev_small, cv2.COLOR_BGR2GRAY)
    offsets = [(0.0, 0.0)]
    cum_dx = 0.0
    cum_dy = 0.0
    idx = 1
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_out = _resize(frame, out_w, out_h)
        curr_small = _resize(frame_out, mw, mh)
        if use_raft:
            curr_t = _frame_to_raft_tensor_bgr(curr_small)
            with torch.no_grad():
                if use_amp:
                    with torch.cuda.amp.autocast(enabled=True):
                        _, flow_up = raft_model(prev_t, curr_t, iters=int(raft_iters), test_mode=True)
                else:
                    _, flow_up = raft_model(prev_t, curr_t, iters=int(raft_iters), test_mode=True)
            flow = flow_up[0]
            # The median flow is robust to moving foreground objects; treat it
            # as the global camera shift for this frame pair.
            dx = float(flow[0].median().item())
            dy = float(flow[1].median().item())
            prev_t = curr_t
        else:
            curr_g = cv2.cvtColor(curr_small, cv2.COLOR_BGR2GRAY)
            flow = cv2.calcOpticalFlowFarneback(
                prev_g,
                curr_g,
                None,
                pyr_scale=0.5,
                levels=3,
                winsize=15,
                iterations=3,
                poly_n=5,
                poly_sigma=1.2,
                flags=0,
            )
            dx = float(np.median(flow[..., 0]))
            dy = float(np.median(flow[..., 1]))
            prev_g = curr_g
        # Rescale from motion-estimation resolution to output resolution.
        dx *= sx
        dy *= sy
        cum_dx += dx
        cum_dy += dy
        # Negate the cumulative drift so the warp cancels the camera motion.
        offsets.append((-cum_dx, -cum_dy))
        if total > 0 and (idx % 5 == 0 or idx == total - 1):
            progress(progress_offset + (idx / max(1, total - 1)) * progress_scale, desc="Estimating Motion")
        idx += 1
    cap.release()
    return offsets
def compute_auto_zoom(offsets, width, height):
    dxs = [o[0] for o in offsets] or [0.0]
    dys = [o[1] for o in offsets] or [0.0]
    left = max(0.0, -min(dxs))
    right = max(0.0, max(dxs))
    top = max(0.0, -min(dys))
    bottom = max(0.0, max(dys))
    safe_w = float(width) - (left + right)
    safe_h = float(height) - (top + bottom)
    zx = (float(width) / safe_w) if safe_w > 1.0 else 1.0
    zy = (float(height) / safe_h) if safe_h > 1.0 else 1.0
    return max(1.0, zx, zy)
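# Worked example: offsets spanning dx in [-12, +8] px and dy in [-6, +4] px on a
# 640x360 output leave a (640-20) x (360-10) always-inside region, so the zoom
# needed to hide stabilization borders is max(640/620, 360/350) ≈ 1.032.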
def stabilize_stream(
    video_file,
    offsets,
    zoom=1.0,
    vertical_only=False,
    out_w=None,
    out_h=None,
    progress=gr.Progress(),
    progress_offset=0.55,
    progress_scale=0.45,
    output_file=None,
):
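    """Warp each frame by its offset (scaled about the center by `zoom`, borders
    replicated) and write the result to an mp4, returning the output path."""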
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Could not open video file for stabilization.")
    # Some containers report 0 fps; fall back to a sane default so VideoWriter works.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    in_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    in_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    if out_w is None:
        out_w = in_w
    if out_h is None:
        out_h = in_h
    if output_file is None:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        output_file = temp_file.name
        temp_file.close()
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_file, fourcc, fps, (int(out_w), int(out_h)))
    # A single rotation matrix encodes the zoom; each frame's shift is added to
    # its translation column.
    center = (float(out_w) / 2.0, float(out_h) / 2.0)
    base = cv2.getRotationMatrix2D(center, 0.0, float(zoom))
    total = len(offsets)
    i = 0
    while i < total:
        ret, frame = cap.read()
        if not ret:
            break
        frame_out = _resize(frame, int(out_w), int(out_h))
        dx, dy = offsets[i]
        if vertical_only:
            dx = 0.0
        M = base.copy()
        M[0, 2] += float(dx)
        M[1, 2] += float(dy)
        stabilized = cv2.warpAffine(frame_out, M, (int(out_w), int(out_h)), borderMode=cv2.BORDER_REPLICATE)
        out.write(stabilized)
        if total > 0 and (i % 5 == 0 or i == total - 1):
            progress(progress_offset + (i / max(1, total - 1)) * progress_scale, desc="Stabilizing Video")
        i += 1
    cap.release()
    out.release()
    return output_file
def process_video_ai(
    video_file,
    zoom,
    max_zoom,
    vertical_only,
    compress_mode,
    target_width,
    target_height,
    auto_zoom,
    progress=gr.Progress(track_tqdm=True),
):
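    """Gradio handler: estimate motion, resolve the zoom (auto or manual, capped
    by max_zoom), stabilize, and return (original, stabilized path, captured logs)."""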
gr.Info("Starting AI-powered video processing...")
log_buffer = io.StringIO()
with redirect_stdout(log_buffer):
if isinstance(video_file, dict):
video_file = video_file.get("name", None)
if video_file is None:
raise gr.Error("Please upload a video file.")
cap = cv2.VideoCapture(video_file)
if not cap.isOpened():
raise gr.Error("Could not open uploaded video.")
in_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
in_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
if compress_mode:
out_w = int(target_width)
out_h = int(target_height)
else:
out_w = in_w
out_h = in_h
offsets = compute_offsets(
video_file,
out_w,
out_h,
motion_scale=0.5,
raft_iters=12,
progress=progress,
progress_offset=0.0,
progress_scale=0.55,
)
gr.Info("Motion estimated successfully.")
if auto_zoom:
z = compute_auto_zoom(offsets, out_w, out_h)
if max_zoom is not None:
try:
mz = float(max_zoom)
if mz > 0:
z = min(z, mz)
except Exception:
pass
gr.Info(f"Auto zoom factor computed: {z:.2f}")
zoom = z
else:
if max_zoom is not None:
try:
mz = float(max_zoom)
if mz > 0:
zoom = min(float(zoom), mz)
except Exception:
zoom = float(zoom)
stabilized_path = stabilize_stream(
video_file,
offsets,
zoom=float(zoom),
vertical_only=bool(vertical_only),
out_w=out_w,
out_h=out_h,
progress=progress,
progress_offset=0.55,
progress_scale=0.45,
)
gr.Info("Video stabilization complete.")
print("[INFO] Video processing complete.")
logs = log_buffer.getvalue()
return video_file, stabilized_path, logs
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Video Stabilization")
    gr.Markdown(
        "Upload a video, select a zoom factor (or use Auto Zoom Mode), optionally cap the maximum zoom, "
        "choose whether to apply only vertical stabilization, and optionally compress the output resolution. "
        "The system estimates motion using RAFT if available (otherwise Farneback) and stabilizes the video "
        "with progress updates."
    )
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Input Video")
            zoom_slider = gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Zoom Factor (ignored if Auto Zoom enabled)")
            max_zoom_slider = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, value=3.0, label="Max Zoom (caps manual + auto zoom)")
            auto_zoom_checkbox = gr.Checkbox(label="Auto Zoom Mode", value=False)
            vertical_checkbox = gr.Checkbox(label="Vertical Stabilization Only", value=False)
            compress_checkbox = gr.Checkbox(label="Compress Output Resolution", value=False)
            target_width = gr.Number(label="Target Width (px)", value=640)
            target_height = gr.Number(label="Target Height (px)", value=360)
            process_button = gr.Button("Process Video")
        with gr.Column():
            original_video = gr.Video(label="Original Video")
            stabilized_video = gr.Video(label="Stabilized Video")
            logs_output = gr.Textbox(label="Logs", lines=10)
    process_button.click(
        fn=process_video_ai,
        inputs=[video_input, zoom_slider, max_zoom_slider, vertical_checkbox, compress_checkbox, target_width, target_height, auto_zoom_checkbox],
        outputs=[original_video, stabilized_video, logs_output],
    )
if __name__ == "__main__":
    demo.launch()