mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
parler-tts support (#2431)
* Start sketching parler-tts support. * Implement the attention. * Add the example code. * Fix the example. * Add the description + t5 encode it. * More of the parler forward pass. * Fix the positional embeddings. * Support random sampling in generation. * Handle EOS. * Add the python decoder. * Proper causality mask.
This commit is contained in:
29
candle-examples/examples/parler-tts/decode.py
Normal file
29
candle-examples/examples/parler-tts/decode.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
from safetensors.torch import load_file
|
||||
from parler_tts import DACModel
|
||||
|
||||
tensors = load_file("out.safetensors")
|
||||
dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps")
|
||||
output_ids = tensors["codes"][None, None]
|
||||
print(output_ids, "\n", output_ids.shape)
|
||||
batch_size = 1
|
||||
with torch.no_grad():
|
||||
output_values = []
|
||||
for sample_id in range(batch_size):
|
||||
sample = output_ids[:, sample_id]
|
||||
sample_mask = (sample >= dac_model.config.codebook_size).sum(dim=(0, 1)) == 0
|
||||
if sample_mask.sum() > 0:
|
||||
sample = sample[:, :, sample_mask]
|
||||
sample = dac_model.decode(sample[None, ...], [None]).audio_values
|
||||
output_values.append(sample.transpose(0, 2))
|
||||
else:
|
||||
output_values.append(torch.zeros((1, 1, 1)).to(dac_model.device))
|
||||
output_lengths = [audio.shape[0] for audio in output_values]
|
||||
pcm = (
|
||||
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
|
||||
.squeeze(-1)
|
||||
.squeeze(-1)
|
||||
)
|
||||
print(pcm.shape, pcm.dtype)
|
||||
torchaudio.save("out.wav", pcm.cpu(), sample_rate=44100)
|
175
candle-examples/examples/parler-tts/main.rs
Normal file
175
candle-examples/examples/parler-tts/main.rs
Normal file
@ -0,0 +1,175 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::parler_tts::{Config, Model};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long, default_value = "Hey, how are you doing today?")]
|
||||
prompt: String,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
|
||||
)]
|
||||
description: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 0)]
|
||||
seed: u64,
|
||||
|
||||
#[arg(long, default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Use f16 precision for all the computations rather than f32.
|
||||
#[arg(long)]
|
||||
f16: bool,
|
||||
|
||||
#[arg(long)]
|
||||
model_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long, default_value_t = 512)]
|
||||
max_steps: usize,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => "parler-tts/parler-tts-large-v1".to_string(),
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(r) => r,
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(hf_hub::Repo::with_revision(
|
||||
model_id,
|
||||
hf_hub::RepoType::Model,
|
||||
revision,
|
||||
));
|
||||
let model_files = match args.model_file {
|
||||
Some(m) => vec![m.into()],
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
let config = match args.config_file {
|
||||
Some(m) => m.into(),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let tokenizer = match args.tokenizer_file {
|
||||
Some(m) => m.into(),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? };
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?;
|
||||
let mut model = Model::new(&config, vb)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let description_tokens = tokenizer
|
||||
.encode(args.description, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
|
||||
println!("{description_tokens}");
|
||||
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
|
||||
println!("{prompt_tokens}");
|
||||
|
||||
let lp = candle_transformers::generation::LogitsProcessor::new(
|
||||
args.seed,
|
||||
Some(args.temperature),
|
||||
args.top_p,
|
||||
);
|
||||
let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
|
||||
println!("{codes}");
|
||||
let codes = codes.to_dtype(DType::I64)?;
|
||||
codes.save_safetensors("codes", "out.safetensors")?;
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user