"""
Convert IndexTTS-2 PyTorch models to ONNX format for Rust inference!

This script converts the three main models:
1. GPT model (gpt.pth) - Autoregressive text-to-semantic generation
2. S2Mel model (s2mel.pth) - Semantic-to-mel spectrogram conversion
3. BigVGAN - Mel-to-waveform vocoder (already available as ONNX from NVIDIA)

Usage:
    python tools/convert_to_onnx.py

Output:
    models/gpt.onnx
    models/s2mel.onnx
    models/bigvgan.onnx (if needed, otherwise use NVIDIA's)

Why ONNX?
- Cross-platform: works on Windows, Linux, macOS (including M1/M2 Macs)
- Fast: ONNX Runtime is highly optimized
- Rust-native: the ort crate provides excellent ONNX Runtime bindings
- No Python: production inference without Python dependency hell!

Author: Aye & Hue @ 8b.is
"""

import os
import sys

# Run from the project root so the relative paths below resolve.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
os.chdir(project_root)

# Keep HuggingFace downloads inside the project's checkpoints directory.
os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'

print("=" * 70)
print("  IndexTTS-2 PyTorch to ONNX Converter")
print("  For Rust inference with the ort crate!")
print("=" * 70)
print()

# Bail out early if the model checkpoints have not been downloaded yet.
if not os.path.exists("checkpoints/gpt.pth"):
    print("ERROR: Models not found!")
    print("Run: python tools/download_files.py -s huggingface")
    sys.exit(1)
|
| | import torch |
| | import torch.onnx |
| | import numpy as np |
| | from pathlib import Path |
| |
|
| | |
| | sys.path.insert(0, "indextts - REMOVING - REF ONLY") |
| |
|
| | |
| | output_dir = Path("models") |
| | output_dir.mkdir(exist_ok=True) |
| |
|
| | print(f"PyTorch version: {torch.__version__}") |
| | print(f"Output directory: {output_dir}") |
| | print() |
| |
|

def export_speaker_encoder():
    """
    Export the CAM++ speaker encoder to ONNX.

    This model extracts speaker embeddings from reference audio.
    Input:  mel features [batch, time, n_mels]
    Output: speaker embedding [batch, 192]
    """
    print("\n" + "=" * 50)
    print("Exporting Speaker Encoder (CAM++)")
    print("=" * 50)
|
    try:
        from omegaconf import OmegaConf
        from indextts.s2mel.modules.campplus.DTDNN import CAMPPlus

        cfg = OmegaConf.load("checkpoints/config.yaml")

        # CAM++: 80-dim mel features in, 192-dim speaker embedding out.
        model = CAMPPlus(feat_dim=80, embedding_size=192)

        # Weights come from the HuggingFace cache populated by download_files.py.
        weights_path = "./checkpoints/hf_cache/models--funasr--campplus/snapshots/fb71fe990cbf6031ae6987a2d76fe64f94377b7e/campplus_cn_common.bin"
        if os.path.exists(weights_path):
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict)
            print(f"Loaded weights from: {weights_path}")
        else:
            print(f"WARNING: weights not found at {weights_path}")
            print("Exporting with randomly initialized weights!")

        model.eval()
|
        # CAM++ expects [batch, time, n_mels] features.
        dummy_input = torch.randn(1, 100, 80)

        # Sanity-check the forward pass before exporting.
        with torch.no_grad():
            test_output = model(dummy_input)
            print(f"Forward pass works! Output shape: {test_output.shape}")

        output_path = output_dir / "speaker_encoder.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            str(output_path),
            input_names=['mel_spectrogram'],
            output_names=['speaker_embedding'],
            dynamic_axes={
                'mel_spectrogram': {0: 'batch', 1: 'time'},
                'speaker_embedding': {0: 'batch'}
            },
            opset_version=18,
            do_constant_folding=True,
        )

        # Verify the exported graph is structurally valid.
        import onnx
        onnx_model = onnx.load(str(output_path))
        onnx.checker.check_model(onnx_model)
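
        # Optional parity check -- a minimal sketch, assuming onnxruntime
        # is installed; not part of the original export flow.
        try:
            import onnxruntime
            sess = onnxruntime.InferenceSession(
                str(output_path), providers=['CPUExecutionProvider'])
            (onnx_out,) = sess.run(
                None, {'mel_spectrogram': dummy_input.numpy()})
            max_diff = float(np.max(np.abs(onnx_out - test_output.numpy())))
            print(f"  Max |PyTorch - ONNX| difference: {max_diff:.2e}")
        except ImportError:
            print("  onnxruntime not installed; skipping parity check")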
|
| | print(f"✓ Exported: {output_path}") |
| | print(f" Input: mel_spectrogram [batch, time, 80]") |
| | print(f" Output: speaker_embedding [batch, 192]") |
| | print(f"✓ ONNX model verified!") |
| | return True |
| |
|
| | except Exception as e: |
| | print(f"✗ Failed to export speaker encoder: {e}") |
| | import traceback |
| | traceback.print_exc() |
| | return False |
| |
|
| |
|
def export_gpt_model():
    """
    Export the GPT autoregressive model to ONNX.

    This is the most complex model: it generates semantic tokens from text.
    We may need to export it in parts due to KV caching.

    Input:  text_tokens [batch, seq_len], speaker_embedding [batch, 192]
    Output: semantic_codes [batch, code_len]
    """
    print("\n" + "=" * 50)
    print("Exporting GPT Model (Autoregressive)")
    print("=" * 50)
|
    try:
        from omegaconf import OmegaConf

        cfg = OmegaConf.load("checkpoints/config.yaml")

        print("GPT model export is complex due to:")
        print("  - Autoregressive generation with KV caching")
        print("  - Dynamic sequence lengths")
        print("  - Multiple internal components")
        print()
        print("Options:")
        print("  A) Export without KV cache (slower but simpler)")
        print("  B) Export encoder + single-step decoder (efficient; sketch below)")
        print("  C) Use torch.compile + ONNX tracing")
        print()
|
        # Load the full pipeline to get a handle on the GPT module.
        from infer_v2 import IndexTTS2

        tts = IndexTTS2(
            cfg_path="checkpoints/config.yaml",
            model_dir="checkpoints",
            use_fp16=False,
            device="cpu"
        )

        gpt = tts.gpt
        gpt.eval()

        print(f"GPT model loaded: {type(gpt)}")
        print(f"Parameters: {sum(p.numel() for p in gpt.parameters()):,}")
|
        output_path = output_dir / "gpt_encoder.onnx"

        # Dummy text tokens for tracing (vocab size assumed < 30000).
        text_tokens = torch.randint(0, 30000, (1, 32), dtype=torch.int64)

        print("Attempting GPT export (may require modifications)...")
        print()
        print("Note: Full GPT export requires modifying the model code")
        print("to remove dynamic control flow. Creating a wrapper...")
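
        # A minimal sketch of option B, assuming a hypothetical
        # `decoder_step` callable that takes one token plus the cached
        # keys/values and returns logits and updated caches. The real
        # attribute names on `gpt` differ; this only illustrates the shape
        # of an export-friendly single-step module.
        class GPTStepWrapper(torch.nn.Module):
            def __init__(self, decoder_step):
                super().__init__()
                self.decoder_step = decoder_step

            def forward(self, token, past_k, past_v):
                # One decode step; the sampling loop runs in Rust, feeding
                # the returned caches back in on the next call.
                logits, new_k, new_v = self.decoder_step(token, past_k, past_v)
                return logits, new_k, new_v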
|
        return False

    except Exception as e:
        print(f"✗ Failed to export GPT: {e}")
        import traceback
        traceback.print_exc()
        return False


def export_s2mel_model():
    """
    Export the Semantic-to-Mel model (flow matching).

    This converts semantic codes to mel spectrograms.
    Input:  semantic_codes [batch, code_len], speaker_embedding [batch, 192]
    Output: mel_spectrogram [batch, 80, mel_len]
    """
    print("\n" + "=" * 50)
    print("Exporting S2Mel Model (Flow Matching)")
    print("=" * 50)
|
    try:
        from omegaconf import OmegaConf

        cfg = OmegaConf.load("checkpoints/config.yaml")

        print("S2Mel model (Diffusion/Flow Matching) is also complex:")
        print("  - Multiple denoising steps (iterative)")
        print("  - CFM (Conditional Flow Matching) requires ODE solving")
        print()
        print("Export strategy (see the single-step sketch below):")
        print("  1. Export the single denoising step")
        print("  2. Run the iteration loop in Rust")
        print()
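
        # A minimal sketch of that strategy, assuming a hypothetical
        # `estimator` network that predicts the CFM velocity field. One
        # explicit Euler step of the ODE is what gets exported; Rust then
        # iterates x <- x + dt * v for the chosen number of steps.
        class CFMStepWrapper(torch.nn.Module):
            def __init__(self, estimator):
                super().__init__()
                self.estimator = estimator

            def forward(self, x, t, dt, cond):
                # x_{t+dt} = x_t + dt * v_theta(x_t, t, cond)
                v = self.estimator(x, t, cond)
                return x + dt * v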
|
        return False

    except Exception as e:
        print(f"✗ Failed to export S2Mel: {e}")
        import traceback
        traceback.print_exc()
        return False

|
def export_bigvgan():
    """
    Export the BigVGAN vocoder to ONNX.

    Good news: NVIDIA provides pre-trained BigVGAN models!
    Even better: they're designed for easy ONNX export.

    Input:  mel_spectrogram [batch, 80, mel_len]
    Output: waveform [batch, 1, wave_len]
    """
    print("\n" + "=" * 50)
    print("Exporting BigVGAN Vocoder")
    print("=" * 50)
|
    try:
        print("BigVGAN options:")
        print("  1. Use NVIDIA's pre-exported ONNX (recommended)")
        print("     https://github.com/NVIDIA/BigVGAN")
        print()
        print("  2. Export from PyTorch weights (we'll do this)")
        print()
|
        # Try to pull the pretrained vocoder straight from HuggingFace.
        try:
            from bigvgan import bigvgan
            model = bigvgan.BigVGAN.from_pretrained(
                'nvidia/bigvgan_v2_22khz_80band_256x',
                use_cuda_kernel=False
            )
            model.eval()
            # Weight norm must be removed before export.
            model.remove_weight_norm()

            print("BigVGAN loaded from HuggingFace")
            print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

            # BigVGAN consumes [batch, n_mels, frames] mel spectrograms.
            dummy_mel = torch.randn(1, 80, 100)
|
            output_path = output_dir / "bigvgan.onnx"
            torch.onnx.export(
                model,
                dummy_mel,
                str(output_path),
                input_names=['mel_spectrogram'],
                output_names=['waveform'],
                dynamic_axes={
                    'mel_spectrogram': {0: 'batch', 2: 'mel_length'},
                    'waveform': {0: 'batch', 2: 'wave_length'}
                },
                opset_version=18,
                do_constant_folding=True,
            )

            print(f"✓ Exported: {output_path}")
            print("  Input:  mel_spectrogram [batch, 80, mel_len]")
            print("  Output: waveform [batch, 1, wave_len]")
|
            # Verify the exported graph is structurally valid.
            import onnx
            onnx_model = onnx.load(str(output_path))
            onnx.checker.check_model(onnx_model)
            print("✓ ONNX model verified!")
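
            # Optional smoke test -- a minimal sketch, assuming onnxruntime
            # is installed; runs the exported graph once on the dummy mel.
            try:
                import onnxruntime
                sess = onnxruntime.InferenceSession(
                    str(output_path), providers=['CPUExecutionProvider'])
                (wave,) = sess.run(
                    None, {'mel_spectrogram': dummy_mel.numpy()})
                print(f"  Smoke test OK, waveform shape: {wave.shape}")
            except ImportError:
                print("  onnxruntime not installed; skipping smoke test")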
|
            return True

        except ImportError:
            print("bigvgan package not installed; installing...")
            os.system(f"{sys.executable} -m pip install bigvgan")
            print("Please re-run this script.")
            return False
|
    except Exception as e:
        print(f"✗ Failed to export BigVGAN: {e}")
        import traceback
        traceback.print_exc()
        return False

|
def main():
    print("\nStarting ONNX conversion...\n")

    results = {}

    # Run each export in turn; failures are summarized at the end.
    results['speaker_encoder'] = export_speaker_encoder()
    results['gpt'] = export_gpt_model()
    results['s2mel'] = export_s2mel_model()
    results['bigvgan'] = export_bigvgan()
|
    print("\n" + "=" * 70)
    print("  CONVERSION SUMMARY")
    print("=" * 70)

    for name, success in results.items():
        status = "✓ SUCCESS" if success else "✗ NEEDS WORK"
        print(f"  {name:20} {status}")

    print()
|
    if all(results.values()):
        print("All models converted! Ready for Rust inference.")
    else:
        print("Some models need manual intervention.")
        print()
        print("For the complex models (GPT, S2Mel), consider:")
        print("  1. Modifying the Python code to remove dynamic control flow")
        print("  2. Using torch.jit.trace with concrete inputs (sketch below)")
        print("  3. Exporting subcomponents separately")
        print("  4. Using ONNX Runtime's transformer optimizations")

    print()
    print("Output directory:", output_dir.absolute())


if __name__ == "__main__":
    main()
|