From b44d38de0e965b632f28a648ff53bfb10d5ce6d1 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 13 Apr 2025 12:02:17 +0200
Subject: [PATCH] Add the Orpheus TTS. (#2886)

* Add the Orpheus TTS.

* Add a small readme.

* Token fix.

* Support more voices.

* Clippy fixes.
---
 candle-examples/examples/orpheus/README.md |  14 +
 candle-examples/examples/orpheus/main.rs   | 329 +++++++++++++++++++++
 2 files changed, 343 insertions(+)
 create mode 100644 candle-examples/examples/orpheus/README.md
 create mode 100644 candle-examples/examples/orpheus/main.rs

diff --git a/candle-examples/examples/orpheus/README.md b/candle-examples/examples/orpheus/README.md
new file mode 100644
index 00000000..fde3cb91
--- /dev/null
+++ b/candle-examples/examples/orpheus/README.md
@@ -0,0 +1,14 @@
+# Orpheus
+
+Orpheus is a 3B text-to-speech model based on Llama.
+
+- Weights on HuggingFace
+  [canopylabs/orpheus-3b-0.1-ft](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft).
+- Code on GitHub [canopyai/Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS).
+
+
+```bash
+cargo run --example orpheus --features cuda -r
+```
+
+
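
Before the `main.rs` diff, a note on how the text prompt reaches the Llama backbone: the tokenized `voice: text` string is wrapped in Orpheus' special token ids (the ids below are taken verbatim from the example and the upstream `engine_class.py` it links; the `frame_prompt` helper name and the demo ids are illustrative only, not part of the patch).

```rust
// Minimal sketch of the prompt framing used in main.rs below: one special
// token before the encoded text ids, four after. The ids come straight from
// the example; their individual roles are not named upstream, so none are
// assumed here.
fn frame_prompt(text_ids: &[u32]) -> Vec<u32> {
    [&[128259], text_ids, &[128009, 128260, 128261, 128257]].concat()
}

fn main() {
    assert_eq!(
        frame_prompt(&[1, 2, 3]),
        vec![128259, 1, 2, 3, 128009, 128260, 128261, 128257]
    );
}
```
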
diff --git a/candle-examples/examples/orpheus/main.rs b/candle-examples/examples/orpheus/main.rs
new file mode 100644
index 00000000..706e08ca
--- /dev/null
+++ b/candle-examples/examples/orpheus/main.rs
@@ -0,0 +1,329 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::Parser;
+
+use candle::{DType, Device, IndexOp, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::llama::{Cache, Llama, LlamaConfig};
+use candle_transformers::models::snac::{Config as SnacConfig, Model as SnacModel};
+use tokenizers::Tokenizer;
+
+// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
+const STOP_TOKEN_ID: u32 = 128258;
+
+#[derive(Parser)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// Display the token for the specified prompt.
+    #[arg(long)]
+    verbose_prompt: bool,
+
+    #[arg(long, default_value = "Hey, how are you doing today?")]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long, default_value_t = 0.6)]
+    temperature: f64,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    #[arg(long)]
+    model_file: Option<String>,
+
+    #[arg(long)]
+    tokenizer_file: Option<String>,
+
+    #[arg(long)]
+    config_file: Option<String>,
+
+    /// The output wav file.
+    #[arg(long, default_value = "out.wav")]
+    out_file: String,
+
+    #[arg(long, default_value = "3b-0.1-ft")]
+    which: Which,
+
+    #[arg(long, default_value = "tara")]
+    voice: Voice,
+
+    #[arg(long)]
+    use_flash_attn: bool,
+}
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Voice {
+    #[value(name = "tara")]
+    Tara,
+    #[value(name = "leah")]
+    Leah,
+    #[value(name = "jess")]
+    Jess,
+    #[value(name = "leo")]
+    Leo,
+    #[value(name = "dan")]
+    Dan,
+    #[value(name = "mia")]
+    Mia,
+    #[value(name = "zac")]
+    Zac,
+    #[value(name = "zoe")]
+    Zoe,
+}
+
+impl Voice {
+    fn as_str(&self) -> &'static str {
+        match self {
+            Voice::Tara => "tara",
+            Voice::Leah => "leah",
+            Voice::Jess => "jess",
+            Voice::Leo => "leo",
+            Voice::Dan => "dan",
+            Voice::Mia => "mia",
+            Voice::Zac => "zac",
+            Voice::Zoe => "zoe",
+        }
+    }
+}
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "3b-0.1-ft")]
+    ThreeB0_1Ft,
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    let prompt = args.prompt.clone();
+    let mut model = Model::load(args)?;
+    model.run(&prompt)?;
+    Ok(())
+}
+
+struct Model {
+    model: Llama,
+    tokenizer: Tokenizer,
+    logits_processor: candle_transformers::generation::LogitsProcessor,
+    cache: Cache,
+    device: Device,
+    verbose_prompt: bool,
+    snac: SnacModel,
+    out_file: String,
+    voice: Voice,
+}
+
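
The `load_snac` function that follows fetches the SNAC 24 kHz vocoder, which consumes three codebook streams rather than one flat token stream. For orientation, this is how `run` (further down) regroups each 7-token frame into SNAC's three levels, 1, 2, and 4 codes per frame from coarse to fine; a standalone sketch, with the `split_frames` name being illustrative:

```rust
// Mirrors the chunks_exact(7) split in `run` below: per frame, index 0 feeds
// the first SNAC codebook, indices 1 and 4 the second, and 2, 3, 5, 6 the
// third.
fn split_frames(audio_tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
    let (mut c0, mut c1, mut c2) = (vec![], vec![], vec![]);
    for frame in audio_tokens.chunks_exact(7) {
        c0.push(frame[0]);
        c1.extend([frame[1], frame[4]]);
        c2.extend([frame[2], frame[3], frame[5], frame[6]]);
    }
    (c0, c1, c2)
}

fn main() {
    let toks: Vec<u32> = (0..7).collect();
    let (c0, c1, c2) = split_frames(&toks);
    assert_eq!((c0, c1, c2), (vec![0], vec![1, 4], vec![2, 3, 5, 6]));
}
```
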
+fn load_snac(device: &Device) -> Result<SnacModel> {
+    let api = hf_hub::api::sync::Api::new()?;
+    let m = api.model("hubertsiuzdak/snac_24khz".to_string());
+    let config = m.get("config.json")?;
+    let config: SnacConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
+    let m = api.model("lmz/candle-snac".to_string());
+    let model = m.get("snac_24khz.safetensors")?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
+    let model = SnacModel::new(&config, vb)?;
+    Ok(model)
+}
+
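
In the generation loop of `run` below, every `<custom_token_N>` the model emits is mapped back to a codebook index by undoing the upstream offset encoding; the constant 10 and the per-position 4096 stride come from the `decoder.py` link in the code. A worked example, with the hypothetical `decode_custom_token` helper introduced only for illustration:

```rust
// The arithmetic used in `run`: raw custom-token value, minus the offset 10,
// minus the codebook base for the token's position within its 7-token frame.
fn decode_custom_token(raw: u32, position_in_stream: usize) -> u32 {
    raw - 10 - ((position_in_stream as u32 % 7) * 4096)
}

fn main() {
    // Third token of a frame (position 2): 8300 - 10 - 2 * 4096 = 98.
    assert_eq!(decode_custom_token(8300, 2), 98);
}
```
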
+impl Model {
+    fn load(args: Args) -> Result<Self> {
+        let start = std::time::Instant::now();
+        let api = hf_hub::api::sync::Api::new()?;
+        let model_id = match args.model_id {
+            Some(model_id) => model_id.to_string(),
+            None => match args.which {
+                Which::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
+            },
+        };
+        let revision = match args.revision {
+            Some(r) => r,
+            None => "main".to_string(),
+        };
+        let repo = api.repo(hf_hub::Repo::with_revision(
+            model_id,
+            hf_hub::RepoType::Model,
+            revision,
+        ));
+        let model_files = match args.model_file {
+            Some(m) => vec![m.into()],
+            None => match args.which {
+                Which::ThreeB0_1Ft => {
+                    candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+                }
+            },
+        };
+        let config = match args.config_file {
+            Some(m) => m.into(),
+            None => repo.get("config.json")?,
+        };
+        let tokenizer = match args.tokenizer_file {
+            Some(m) => m.into(),
+            None => repo.get("tokenizer.json")?,
+        };
+        println!("retrieved the files in {:?}", start.elapsed());
+        let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+
+        let start = std::time::Instant::now();
+        let device = candle_examples::device(args.cpu)?;
+        let dtype = device.bf16_default_to_f32();
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device)? };
+        let config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
+        let config = config.into_config(args.use_flash_attn);
+        let model = Llama::load(vb, &config)?;
+        let logits_processor = {
+            use candle_transformers::generation::{LogitsProcessor, Sampling};
+            let temperature = args.temperature;
+            let sampling = if temperature <= 0. {
+                Sampling::ArgMax
+            } else {
+                match (args.top_k.as_ref(), args.top_p.as_ref()) {
+                    (None, None) => Sampling::All { temperature },
+                    (Some(&k), None) => Sampling::TopK { k, temperature },
+                    (None, Some(&p)) => Sampling::TopP { p, temperature },
+                    (Some(&k), Some(&p)) => Sampling::TopKThenTopP { k, p, temperature },
+                }
+            };
+            LogitsProcessor::from_sampling(args.seed, sampling)
+        };
+
+        println!("loaded the model in {:?}", start.elapsed());
+        let cache = Cache::new(true, dtype, &config, &device)?;
+        let snac = load_snac(&device)?;
+        Ok(Self {
+            model,
+            tokenizer,
+            logits_processor,
+            cache,
+            device,
+            verbose_prompt: args.verbose_prompt,
+            snac,
+            voice: args.voice,
+            out_file: args.out_file,
+        })
+    }
+
+    fn run(&mut self, prompt: &str) -> Result<()> {
+        println!("running the model on '{}'", prompt);
+        let device = &self.device;
+        let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str());
+        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
+        // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
+        let mut tokens = [
+            &[128259],
+            tokens.get_ids(),
+            &[128009, 128260, 128261, 128257],
+        ]
+        .concat();
+        if self.verbose_prompt {
+            println!("{:?}", tokens);
+        }
+        let mut cache = self.cache.clone();
+
+        println!("starting the inference loop");
+        let mut index_pos = 0;
+        let mut audio_tokens = vec![];
+        for index in 0..2000 {
+            let (context_size, context_index) = if index > 0 {
+                (1, index_pos)
+            } else {
+                (tokens.len(), 0)
+            };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input, context_index, &mut cache)?;
+            let logits = logits.squeeze(0)?;
+            index_pos += ctxt.len();
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            if let Some(tok) = self.tokenizer.id_to_token(next_token) {
+                match tok.strip_prefix("<custom_token_") {
+                    Some(tok) => match tok.strip_suffix('>') {
+                        Some(tok) => {
+                            let tok = tok.parse::<u32>()?;
+                            // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
+                            let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
+                            audio_tokens.push(tok);
+                        }
+                        None => {
+                            println!("{index}: unexpected custom token {next_token} {tok}");
+                        }
+                    },
+                    None => {
+                        println!("{index}: unexpected token {next_token} {tok}");
+                    }
+                }
+            }
+            if next_token == STOP_TOKEN_ID {
+                println!("reached stop token");
+                break;
+            }
+            tokens.push(next_token);
+        }
+        println!("generated {} audio tokens", audio_tokens.len());
+        let mut codes0 = vec![];
+        let mut codes1 = vec![];
+        let mut codes2 = vec![];
+        for audio_tokens in audio_tokens.chunks_exact(7) {
+            codes0.push(audio_tokens[0]);
+            for i in [1, 4] {
+                codes1.push(audio_tokens[i]);
+            }
+            for i in [2, 3, 5, 6] {
+                codes2.push(audio_tokens[i]);
+            }
+        }
+        let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
+        let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
+        let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
+        let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;
+        println!("decoded to pcm {pcm:?}");
+        let mut output = std::fs::File::create(&self.out_file)?;
+        let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
+        candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24000)?;
+        Ok(())
+    }
+}
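
As a closing note, the wav-output plumbing can be sanity-checked without downloading any model weights by feeding synthetic PCM to the same `candle_examples::wav::write_pcm_as_wav` helper the example uses; a sketch, assuming the `(writer, samples, sample_rate)` signature seen in the patch and `candle-examples` as a dependency:

```rust
use std::f32::consts::PI;

// Write one second of a 440 Hz sine at SNAC's 24 kHz rate using the same
// helper as the example above; if sine.wav plays cleanly, the output path
// works.
fn main() -> std::io::Result<()> {
    let sample_rate = 24000u32;
    let pcm: Vec<f32> = (0..sample_rate)
        .map(|i| (2.0 * PI * 440.0 * i as f32 / sample_rate as f32).sin() * 0.5)
        .collect();
    let mut out = std::fs::File::create("sine.wav")?;
    candle_examples::wav::write_pcm_as_wav(&mut out, &pcm, sample_rate)?;
    Ok(())
}
```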