|
|
import gradio as gr |
|
|
import os |
|
|
import shutil |
|
|
import tempfile |
|
|
import zipfile |
|
|
from pathlib import Path |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import HfApi, login |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
def process_hf_dataset(dataset_name: str, image_col: str, caption_col: str, lora_name: str, progress=gr.Progress()):
    """Download a HuggingFace dataset and materialize it as image/caption pairs.

    Each example is written as ``NNNNN.png`` plus a matching ``NNNNN.txt``
    (the layout ostris AI Toolkit expects) into a fresh temp directory.

    Args:
        dataset_name: Hub dataset id, e.g. "Norod78/Yarn-art-style".
        image_col: Image column name; auto-detected when empty.
        caption_col: Caption column name; auto-detected when empty.
        lora_name: Passed through unchanged so downstream steps can reuse it.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        (output_dir, status_message, lora_name); output_dir is None on failure.
    """
    if not dataset_name.strip():
        return None, "Please enter a dataset name", lora_name

    try:
        progress(0, desc="Loading dataset...")
        ds = load_dataset(dataset_name, split="train")

        output_dir = tempfile.mkdtemp()

        # Auto-detect the image column: prefer a real datasets.Image feature,
        # fall back to a column whose name contains "image".
        # (The previous check `feature.dtype == "image"` never matched — an
        # Image feature's dtype is "PIL.Image.Image", not "image".)
        if not image_col:
            for col in ds.column_names:
                if type(ds.features[col]).__name__ == "Image" or "image" in col.lower():
                    image_col = col
                    break

        # Auto-detect the caption column by common naming conventions.
        if not caption_col:
            for col in ds.column_names:
                if "text" in col.lower() or "caption" in col.lower() or "prompt" in col.lower():
                    caption_col = col
                    break

        if not image_col or not caption_col:
            return None, f"Could not detect columns. Available: {ds.column_names}", lora_name

        progress(0.1, desc=f"Processing {len(ds)} images...")

        for i, item in enumerate(ds):
            progress((i + 1) / len(ds), desc=f"Processing image {i+1}/{len(ds)}")

            img = item[image_col]
            if not isinstance(img, Image.Image):
                # presumably a file path or file-like object — Image.open handles both
                img = Image.open(img)
            # PNG cannot encode some modes (e.g. CMYK); normalize those to RGB
            # so img.save() below cannot raise mid-run.
            if img.mode not in ("RGB", "RGBA", "L", "LA", "P", "1", "I", "I;16"):
                img = img.convert("RGB")

            img_filename = f"{i:05d}.png"
            txt_filename = f"{i:05d}.txt"

            img.save(os.path.join(output_dir, img_filename))

            # A missing caption becomes an empty file rather than the literal
            # string "None".
            caption = item[caption_col]
            with open(os.path.join(output_dir, txt_filename), "w", encoding="utf-8") as f:
                f.write("" if caption is None else str(caption))

        return output_dir, f"Processed {len(ds)} images from {dataset_name}", lora_name

    except Exception as e:
        # Surface any failure to the UI status box instead of crashing the app.
        return None, f"Error: {str(e)}", lora_name
|
|
|
|
|
|
|
|
def process_uploaded_images(images: list, caption: str, lora_name: str, progress=gr.Progress()):
    """Write uploaded images plus a shared caption as image/txt pairs.

    Args:
        images: Gallery items — file paths or (path, caption) tuples.
        caption: Caption written verbatim to every .txt file (may be empty).
        lora_name: Passed through unchanged for downstream steps.
        progress: Gradio progress tracker.

    Returns:
        (output_dir, status_message, lora_name); output_dir is None when no
        images were provided.
    """
    if not images:
        return None, "Please upload some images", lora_name

    output_dir = tempfile.mkdtemp()
    used_stems = set()  # guards against two uploads sharing the same stem

    for i, img_data in enumerate(progress.tqdm(images, desc="Processing images")):
        # Gradio Gallery may hand back (path, caption) tuples or bare paths.
        img_path = img_data[0] if isinstance(img_data, tuple) else img_data

        # Keep the original stem, but disambiguate duplicates so one upload
        # cannot silently overwrite another (e.g. two "photo.jpg" files from
        # different folders).
        stem = Path(img_path).stem
        if stem in used_stems:
            stem = f"{stem}_{i:05d}"
        used_stems.add(stem)

        img_filename = f"{stem}.png"
        txt_filename = f"{stem}.txt"

        # Context manager closes the file handle promptly (avoids fd leaks
        # when many images are processed).
        with Image.open(img_path) as img:
            img.save(os.path.join(output_dir, img_filename))

        with open(os.path.join(output_dir, txt_filename), "w", encoding="utf-8") as f:
            f.write(caption if caption else "")

    return output_dir, f"Processed {len(images)} images", lora_name
|
|
|
|
|
|
|
|
def create_zip(output_dir: str, lora_name: str = None):
    """Zip every file in *output_dir* (flat archive, no subdirectories).

    Args:
        output_dir: Directory containing the processed image/txt pairs.
        lora_name: Optional archive name; spaces become underscores. When
            empty, a random temp-file name is used.

    Returns:
        Path to the created .zip, or None when output_dir does not exist.
    """
    if not output_dir or not os.path.exists(output_dir):
        return None

    if lora_name and lora_name.strip():
        zip_filename = f"{lora_name.strip().replace(' ', '_')}.zip"
        zip_path = os.path.join(tempfile.gettempdir(), zip_filename)
    else:
        # mkstemp instead of the deprecated, race-prone mktemp: the file is
        # actually created, so no other process can claim the path first.
        fd, zip_path = tempfile.mkstemp(suffix=".zip")
        os.close(fd)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        # Store each entry under its bare filename (flat layout expected by
        # AI Toolkit).
        for file in os.listdir(output_dir):
            zf.write(os.path.join(output_dir, file), file)

    return zip_path
|
|
|
|
|
|
|
|
def push_to_hub(output_dir: str, repo_name: str, token: str, private: bool, progress=gr.Progress()):
    """Upload the processed image/txt pairs to a HuggingFace dataset repo.

    Validates inputs first; every failure path returns a user-facing status
    string rather than raising, so the message lands in the Gradio textbox.
    """
    # Guard clauses: bail out early with a helpful message.
    if not output_dir or not os.path.exists(output_dir):
        return "No data to push. Process a dataset first."
    if not repo_name or not repo_name.strip():
        return "Please enter a repository name (or provide a LoRA name when processing)"
    if not token or not token.strip():
        return "Please enter your HuggingFace token"

    try:
        progress(0, desc="Logging in...")
        hub_api = HfApi(token=token)

        # Idempotent: exist_ok=True makes re-pushing to the same repo safe.
        progress(0.2, desc="Creating repository...")
        hub_api.create_repo(repo_name, repo_type="dataset", private=private, exist_ok=True)

        progress(0.4, desc="Uploading files...")
        hub_api.upload_folder(folder_path=output_dir, repo_id=repo_name, repo_type="dataset")

        return f"Successfully pushed to https://huggingface.co/datasets/{repo_name}"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
current_output_dir = {"path": None, "lora_name": None} |
|
|
|
|
|
|
|
|
def process_dataset_wrapper(dataset_name, image_col, caption_col, lora_name, progress=gr.Progress()):
    """Gradio handler: process a Hub dataset and return (status, zip_path)."""
    result_dir, status, resolved_lora = process_hf_dataset(
        dataset_name, image_col, caption_col, lora_name, progress
    )
    # Stash the outcome so the later "Push to Hub" step can find it.
    current_output_dir["path"] = result_dir
    current_output_dir["lora_name"] = resolved_lora
    archive = create_zip(result_dir, resolved_lora) if result_dir else None
    return status, archive
|
|
|
|
|
|
|
|
def process_images_wrapper(images, caption, lora_name, progress=gr.Progress()):
    """Gradio handler: process uploaded images and return (status, zip_path)."""
    result_dir, status, resolved_lora = process_uploaded_images(images, caption, lora_name, progress)
    # Stash the outcome so the later "Push to Hub" step can find it.
    current_output_dir["path"] = result_dir
    current_output_dir["lora_name"] = resolved_lora
    archive = create_zip(result_dir, resolved_lora) if result_dir else None
    return status, archive
|
|
|
|
|
|
|
|
def push_wrapper(repo_name, token, private, progress=gr.Progress()):
    """Gradio handler: push the last processed directory to the Hub.

    Falls back to the stored LoRA name (spaces -> underscores) when the
    repository field was left blank.
    """
    trimmed = repo_name.strip()
    final_repo_name = trimmed or None
    if final_repo_name is None and current_output_dir["lora_name"]:
        final_repo_name = current_output_dir["lora_name"].strip().replace(" ", "_")
    return push_to_hub(current_output_dir["path"], final_repo_name, token, private, progress)
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI layout ------------------------------------------------------
with gr.Blocks(title="AI Toolkit Dataset Converter") as demo:
    gr.Markdown("# AI Toolkit Dataset Converter")
    gr.Markdown("""Convert your datasets to the format expected by [ostris AI Toolkit](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#dataset-preparation). You can either:
1. provide a dataset name from the hub OR
2. upload your images directly
""")

    with gr.Tabs():

        # Tab 1: convert an existing HuggingFace Hub dataset.
        with gr.Tab("From HuggingFace Dataset"):
            dataset_name = gr.Textbox(
                label="Dataset Name",
                placeholder="e.g., Norod78/Yarn-art-style"
            )
            lora_name_ds = gr.Textbox(
                label="LoRA Name (optional)",
                placeholder="e.g., my-lora-style",
                info="Used for ZIP filename and Hub dataset name"
            )
            # Empty column fields trigger auto-detection in process_hf_dataset.
            with gr.Row():
                image_col = gr.Textbox(
                    label="Image Column (leave empty to auto-detect)",
                    placeholder="image"
                )
                caption_col = gr.Textbox(
                    label="Caption Column (leave empty to auto-detect)",
                    placeholder="text"
                )
            process_ds_btn = gr.Button("Process Dataset", variant="primary")
            ds_status = gr.Textbox(label="Status", interactive=False)
            ds_download = gr.File(label="Download ZIP")

            process_ds_btn.click(
                process_dataset_wrapper,
                inputs=[dataset_name, image_col, caption_col, lora_name_ds],
                outputs=[ds_status, ds_download]
            )

        # Tab 2: build a dataset from directly uploaded images, all sharing
        # one caption.
        with gr.Tab("From Uploaded Images"):
            images_input = gr.Gallery(
                label="Upload Images",
                file_types=["image"],
                interactive=True,
                columns=4,
                height="auto"
            )
            lora_name_img = gr.Textbox(
                label="LoRA Name (optional)",
                placeholder="e.g., my-lora-style",
                info="Used for ZIP filename and Hub dataset name"
            )
            shared_caption = gr.Textbox(
                label="Caption for all images",
                placeholder="Enter a caption to use for all uploaded images",
                lines=3
            )
            process_img_btn = gr.Button("Process Images", variant="primary")
            img_status = gr.Textbox(label="Status", interactive=False)
            img_download = gr.File(label="Download ZIP")

            process_img_btn.click(
                process_images_wrapper,
                inputs=[images_input, shared_caption, lora_name_img],
                outputs=[img_status, img_download]
            )

    # Shared push-to-Hub section below the tabs; operates on whichever
    # dataset was processed last (see current_output_dir).
    gr.Markdown("---")
    gr.Markdown("### Push to HuggingFace Hub")

    with gr.Row():
        repo_name = gr.Textbox(
            label="Repository Name",
            placeholder="username/dataset-name (uses LoRA name if empty)",
            info="Leave empty to use LoRA name as dataset name"
        )
        hf_token = gr.Textbox(
            label="HuggingFace Token",
            type="password",
            placeholder="hf_..."
        )

    private_repo = gr.Checkbox(label="Private Repository", value=False)
    push_btn = gr.Button("Push to Hub", variant="secondary")
    push_status = gr.Textbox(label="Push Status", interactive=False)

    push_btn.click(
        push_wrapper,
        inputs=[repo_name, hf_token, private_repo],
        outputs=[push_status]
    )
|
|
|
|
|
|
|
|
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()