Add a flag to select the dtype used in metavoice. (#1805)

This commit is contained in:
Laurent Mazare
2024-03-05 12:16:00 +01:00
committed by GitHub
parent bd9ab9bc04
commit 8a99cf7dd2
3 changed files with 35 additions and 15 deletions

View File

@ -19,6 +19,13 @@ use rand::{distributions::Distribution, SeedableRng};
pub const ENCODEC_NTOKENS: u32 = 1024;
/// Floating-point precision selectable on the command line via `--dtype`
/// (parsed by clap's `ValueEnum` derive; defaults to `f32` per the arg attr).
#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]
enum ArgDType {
    /// Single-precision (32-bit) floats — the default.
    F32,
    /// Half-precision (16-bit) IEEE floats.
    F16,
    /// bfloat16 (16-bit, f32-range exponent).
    Bf16,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -67,6 +74,9 @@ struct Args {
#[arg(long)]
spk_emb: Option<String>,
#[arg(long, default_value = "f32")]
dtype: ArgDType,
}
fn main() -> Result<()> {
@ -120,15 +130,18 @@ fn main() -> Result<()> {
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let first_stage_vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[first_stage_weights], DType::F32, &device)?
let dtype = match args.dtype {
ArgDType::F32 => DType::F32,
ArgDType::F16 => DType::F16,
ArgDType::Bf16 => DType::BF16,
};
let first_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
let first_stage_config = transformer::Config::cfg1b_v0_1();
let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
let second_stage_vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[second_stage_weights], DType::F32, &device)?
};
let second_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
let second_stage_config = gpt::Config::cfg1b_v0_1();
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
@ -137,9 +150,8 @@ fn main() -> Result<()> {
} else {
&device
};
let encodec_vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, encodec_device)?
};
let encodec_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
let encodec_config = encodec::Config::default();
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
@ -154,7 +166,7 @@ fn main() -> Result<()> {
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
let spk_emb = match spk_emb.get("spk_emb") {
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
Some(spk_emb) => spk_emb.to_dtype(dtype)?,
};
let spk_emb = spk_emb.to_device(&device)?;
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
@ -228,7 +240,7 @@ fn main() -> Result<()> {
println!("audio_ids shape: {:?}", audio_ids.shape());
let pcm = encodec_model.decode(&audio_ids)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;

View File

@ -297,6 +297,7 @@ impl VectorQuantization {
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
layers: Vec<VectorQuantization>,
dtype: DType,
}
impl ResidualVectorQuantizer {
@ -305,7 +306,10 @@ impl ResidualVectorQuantizer {
let layers = (0..cfg.num_quantizers())
.map(|i| VectorQuantization::new(cfg, vb.pp(i)))
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
Ok(Self {
layers,
dtype: vb.dtype(),
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
@ -321,7 +325,7 @@ impl ResidualVectorQuantizer {
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
let mut quantized_out = Tensor::zeros((), self.dtype, codes.device())?;
let ncodes = codes.dim(0)?;
if ncodes > self.layers.len() {
candle::bail!(

View File

@ -562,6 +562,7 @@ pub mod gpt {
ln_f: Norm,
lm_heads: Vec<Linear>,
cfg: Config,
dtype: DType,
}
impl Model {
@ -596,6 +597,7 @@ pub mod gpt {
ln_f,
lm_heads,
cfg,
dtype: vb.dtype(),
})
}
@ -608,7 +610,7 @@ pub mod gpt {
let (b, _num_hierarchies, t) = idx.dims3()?;
let pos = Tensor::arange(0u32, t as u32, device)?;
let pos_emb = pos.apply(&self.wpe)?;
let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), DType::F32, device)?;
let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), self.dtype, device)?;
for (wte_idx, wte) in self.wtes.iter().enumerate() {
let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
tok_emb = (tok_emb + emb)?;
@ -847,10 +849,11 @@ pub mod transformer {
}
let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
let dtype = vb.dtype();
let spk_cond_mask = Tensor::cat(
&[
Tensor::ones((1, 1, cfg.dim), DType::F32, vb.device())?,
Tensor::zeros((1, 1, cfg.dim), DType::F32, vb.device())?,
Tensor::ones((1, 1, cfg.dim), dtype, vb.device())?,
Tensor::zeros((1, 1, cfg.dim), dtype, vb.device())?,
],
0,
)?;
@ -887,6 +890,7 @@ pub mod transformer {
.apply(&self.speaker_cond_pos)?
.broadcast_mul(&self.spk_cond_mask)?,
)?;
let mask = mask.to_dtype(xs.dtype())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, pos, &mask)?
}