# linoyts's picture
# linoyts HF Staff
# Create app.py
# af28872 verified
import io
import os
import shutil
import tempfile
import zipfile
from pathlib import Path

import gradio as gr
from datasets import load_dataset
from huggingface_hub import HfApi, login
from PIL import Image
def process_hf_dataset(dataset_name: str, image_col: str, caption_col: str, lora_name: str, progress=gr.Progress()):
    """Convert a HuggingFace dataset into AI-Toolkit image/caption pairs.

    Each row becomes ``NNNNN.png`` plus ``NNNNN.txt`` inside a fresh temp
    directory, the layout ostris/ai-toolkit expects.

    Args:
        dataset_name: Hub dataset id, e.g. ``"Norod78/Yarn-art-style"``.
        image_col: Image column name; auto-detected when empty.
        caption_col: Caption column name; auto-detected when empty.
        lora_name: Passed through unchanged so callers can keep it in state.
        progress: Gradio progress tracker.

    Returns:
        Tuple of ``(output_dir, status_message, lora_name)``;
        ``output_dir`` is None on any failure.
    """
    if not dataset_name.strip():
        return None, "Please enter a dataset name", lora_name
    try:
        progress(0, desc="Loading dataset...")
        ds = load_dataset(dataset_name, split="train")
        if len(ds) == 0:
            return None, f"Could not detect columns. Available: {ds.column_names}", lora_name
        # Create temp directory for output
        output_dir = tempfile.mkdtemp()
        # Auto-detect columns when not supplied. An Image feature's dtype is
        # "PIL.Image.Image" (never the bare string "image"), and some feature
        # types have no .dtype at all, so use getattr + substring matching.
        if not image_col:
            for col in ds.column_names:
                dtype = str(getattr(ds.features[col], "dtype", ""))
                if "image" in dtype.lower() or "image" in col.lower():
                    image_col = col
                    break
        if not caption_col:
            for col in ds.column_names:
                if any(key in col.lower() for key in ("text", "caption", "prompt")):
                    caption_col = col
                    break
        if not image_col or not caption_col:
            return None, f"Could not detect columns. Available: {ds.column_names}", lora_name
        progress(0.1, desc=f"Processing {len(ds)} images...")
        total = len(ds)
        for i, item in enumerate(ds):
            progress((i + 1) / total, desc=f"Processing image {i+1}/{total}")
            # Save image. Undecoded image columns may yield a file path or a
            # {"path": ..., "bytes": ...} dict instead of a PIL image.
            img = item[image_col]
            if not isinstance(img, Image.Image):
                if isinstance(img, dict):
                    if img.get("bytes"):
                        img = Image.open(io.BytesIO(img["bytes"]))
                    else:
                        img = Image.open(img["path"])
                else:
                    img = Image.open(img)
            img_filename = f"{i:05d}.png"
            txt_filename = f"{i:05d}.txt"
            img.save(os.path.join(output_dir, img_filename))
            # Save caption; write "" (not the literal "None") for null captions.
            caption = item[caption_col]
            with open(os.path.join(output_dir, txt_filename), "w", encoding="utf-8") as f:
                f.write("" if caption is None else str(caption))
        return output_dir, f"Processed {len(ds)} images from {dataset_name}", lora_name
    except Exception as e:
        return None, f"Error: {str(e)}", lora_name
def process_uploaded_images(images: list, caption: str, lora_name: str, progress=gr.Progress()):
    """Write uploaded images plus a shared caption as AI-Toolkit pairs.

    Args:
        images: Gradio Gallery value — a list of filepaths or
            ``(filepath, caption)`` tuples.
        caption: Caption text written to every ``.txt`` file (may be empty).
        lora_name: Passed through unchanged so callers can keep it in state.
        progress: Gradio progress tracker.

    Returns:
        Tuple of ``(output_dir, status_message, lora_name)``;
        ``output_dir`` is None when nothing was uploaded.
    """
    if not images:
        return None, "Please upload some images", lora_name
    output_dir = tempfile.mkdtemp()
    used_names = set()
    for i, img_data in enumerate(progress.tqdm(images, desc="Processing images")):
        # Gallery returns tuples of (filepath, caption) or just filepath
        img_path = img_data[0] if isinstance(img_data, tuple) else img_data
        img = Image.open(img_path)
        # Keep the original stem, but disambiguate duplicates so two uploads
        # named e.g. "photo.jpg" cannot silently overwrite each other's pair.
        base = Path(img_path).stem
        name = base if base not in used_names else f"{base}_{i:05d}"
        used_names.add(name)
        img.save(os.path.join(output_dir, f"{name}.png"))
        with open(os.path.join(output_dir, f"{name}.txt"), "w", encoding="utf-8") as f:
            f.write(caption if caption else "")
    return output_dir, f"Processed {len(images)} images", lora_name
def create_zip(output_dir: str, lora_name: str = None):
    """Bundle every file in *output_dir* into a flat ZIP archive.

    Args:
        output_dir: Directory of image/caption pairs; may be None/missing.
        lora_name: Optional archive base name ("my lora" -> "my_lora.zip");
            when absent a random temp name is used.

    Returns:
        Path to the created ZIP, or None when there is nothing to zip.
    """
    if not output_dir or not os.path.exists(output_dir):
        return None
    # Use lora_name for zip filename if provided
    if lora_name and lora_name.strip():
        zip_filename = f"{lora_name.strip().replace(' ', '_')}.zip"
        zip_path = os.path.join(tempfile.gettempdir(), zip_filename)
    else:
        # tempfile.mktemp() is deprecated and race-prone; mkstemp() actually
        # creates (reserves) the file, so no other process can grab the name.
        fd, zip_path = tempfile.mkstemp(suffix=".zip")
        os.close(fd)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        # sorted() gives a deterministic archive order across filesystems.
        for file in sorted(os.listdir(output_dir)):
            zf.write(os.path.join(output_dir, file), file)
    return zip_path
def push_to_hub(output_dir: str, repo_name: str, token: str, private: bool, progress=gr.Progress()):
    """Upload the prepared dataset folder to a HuggingFace Hub dataset repo.

    Args:
        output_dir: Directory produced by one of the processing functions.
        repo_name: Target repo id, e.g. ``"user/my-dataset"``.
        token: HuggingFace write token.
        private: Whether to create the repo as private.
        progress: Gradio progress tracker.

    Returns:
        A human-readable status string (success URL or error text).
    """
    # Validate inputs up front; each failure yields a friendly message.
    if not (output_dir and os.path.exists(output_dir)):
        return "No data to push. Process a dataset first."
    if not (repo_name and repo_name.strip()):
        return "Please enter a repository name (or provide a LoRA name when processing)"
    if not (token and token.strip()):
        return "Please enter your HuggingFace token"
    try:
        progress(0, desc="Logging in...")
        hub = HfApi(token=token)
        progress(0.2, desc="Creating repository...")
        hub.create_repo(repo_name, repo_type="dataset", private=private, exist_ok=True)
        progress(0.4, desc="Uploading files...")
        hub.upload_folder(folder_path=output_dir, repo_id=repo_name, repo_type="dataset")
    except Exception as e:
        return f"Error: {str(e)}"
    return f"Successfully pushed to https://huggingface.co/datasets/{repo_name}"
# Global state for output directory
# Shared mutable dict instead of gr.State: holds the last processed temp
# directory ("path") and the user-supplied LoRA name ("lora_name") so the
# push handler can find them later.
current_output_dir = {"path": None, "lora_name": None}
def process_dataset_wrapper(dataset_name, image_col, caption_col, lora_name, progress=gr.Progress()):
    """Run the HF-dataset conversion, stash the result globally, and zip it."""
    result_dir, status, final_lora = process_hf_dataset(
        dataset_name, image_col, caption_col, lora_name, progress
    )
    current_output_dir["path"] = result_dir
    current_output_dir["lora_name"] = final_lora
    if result_dir:
        return status, create_zip(result_dir, final_lora)
    return status, None
def process_images_wrapper(images, caption, lora_name, progress=gr.Progress()):
    """Run the upload conversion, stash the result globally, and zip it."""
    result_dir, status, final_lora = process_uploaded_images(images, caption, lora_name, progress)
    current_output_dir["path"] = result_dir
    current_output_dir["lora_name"] = final_lora
    if result_dir:
        return status, create_zip(result_dir, final_lora)
    return status, None
def push_wrapper(repo_name, token, private, progress=gr.Progress()):
    """Push the last processed dataset; fall back to the LoRA name as repo id."""
    target = repo_name.strip() or None
    if target is None and current_output_dir["lora_name"]:
        # Sanitize the LoRA name into a valid repo name ("my lora" -> "my_lora").
        target = current_output_dir["lora_name"].strip().replace(" ", "_")
    return push_to_hub(current_output_dir["path"], target, token, private, progress)
# Build the Gradio interface
# Two input tabs share the processing/zip pipeline; a common footer section
# pushes whatever was processed last (via current_output_dir) to the Hub.
with gr.Blocks(title="AI Toolkit Dataset Converter") as demo:
    gr.Markdown("# AI Toolkit Dataset Converter")
    gr.Markdown("""Convert your datasets to the format expected by [ostris AI Toolkit](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#dataset-preparation). You can either:
    1. provide a dataset name from the hub OR
    2. upload your images directly
    """)
    with gr.Tabs():
        # Tab 1: HuggingFace Dataset
        with gr.Tab("From HuggingFace Dataset"):
            dataset_name = gr.Textbox(
                label="Dataset Name",
                placeholder="e.g., Norod78/Yarn-art-style"
            )
            lora_name_ds = gr.Textbox(
                label="LoRA Name (optional)",
                placeholder="e.g., my-lora-style",
                info="Used for ZIP filename and Hub dataset name"
            )
            # Column overrides; empty values trigger auto-detection in
            # process_hf_dataset.
            with gr.Row():
                image_col = gr.Textbox(
                    label="Image Column (leave empty to auto-detect)",
                    placeholder="image"
                )
                caption_col = gr.Textbox(
                    label="Caption Column (leave empty to auto-detect)",
                    placeholder="text"
                )
            process_ds_btn = gr.Button("Process Dataset", variant="primary")
            ds_status = gr.Textbox(label="Status", interactive=False)
            ds_download = gr.File(label="Download ZIP")
            process_ds_btn.click(
                process_dataset_wrapper,
                inputs=[dataset_name, image_col, caption_col, lora_name_ds],
                outputs=[ds_status, ds_download]
            )
        # Tab 2: Upload Images
        with gr.Tab("From Uploaded Images"):
            # NOTE(review): file_types may not be a supported Gallery kwarg in
            # every gradio version — confirm against the pinned gradio release.
            images_input = gr.Gallery(
                label="Upload Images",
                file_types=["image"],
                interactive=True,
                columns=4,
                height="auto"
            )
            lora_name_img = gr.Textbox(
                label="LoRA Name (optional)",
                placeholder="e.g., my-lora-style",
                info="Used for ZIP filename and Hub dataset name"
            )
            # One caption applied to every uploaded image.
            shared_caption = gr.Textbox(
                label="Caption for all images",
                placeholder="Enter a caption to use for all uploaded images",
                lines=3
            )
            process_img_btn = gr.Button("Process Images", variant="primary")
            img_status = gr.Textbox(label="Status", interactive=False)
            img_download = gr.File(label="Download ZIP")
            process_img_btn.click(
                process_images_wrapper,
                inputs=[images_input, shared_caption, lora_name_img],
                outputs=[img_status, img_download]
            )
    # Push to Hub section
    # Operates on whatever either tab processed last (module-level state).
    gr.Markdown("---")
    gr.Markdown("### Push to HuggingFace Hub")
    with gr.Row():
        repo_name = gr.Textbox(
            label="Repository Name",
            placeholder="username/dataset-name (uses LoRA name if empty)",
            info="Leave empty to use LoRA name as dataset name"
        )
        hf_token = gr.Textbox(
            label="HuggingFace Token",
            type="password",
            placeholder="hf_..."
        )
    private_repo = gr.Checkbox(label="Private Repository", value=False)
    push_btn = gr.Button("Push to Hub", variant="secondary")
    push_status = gr.Textbox(label="Push Status", interactive=False)
    push_btn.click(
        push_wrapper,
        inputs=[repo_name, hf_token, private_repo],
        outputs=[push_status]
    )
if __name__ == "__main__":
    demo.launch()