Generate an audio file.

This commit is contained in:
laurent
2025-04-03 19:16:49 +02:00
parent 3be12b8b50
commit 8977c31b6d
2 changed files with 29 additions and 10 deletions

View File

@@ -9,7 +9,7 @@ use clap::Parser;
use candle_transformers::models::csm::{Config, Model}; use candle_transformers::models::csm::{Config, Model};
use candle::{DType, Tensor}; use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
@@ -162,16 +162,16 @@ fn main() -> Result<()> {
}; };
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let (mut model, device) = { let (mut model, device) = {
let dtype = device.bf16_default_to_f32(); let dtype = DType::F32;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?; let model = Model::new(&config, vb)?;
(model, device) (model, device)
}; };
let _mimi_model = { let mut mimi_model = {
use candle_transformers::models::mimi; use candle_transformers::models::mimi;
let vb = let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? }; unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
let config = mimi::Config::v0_1(None); let config = mimi::Config::v0_1(Some(32));
mimi::Model::new(config, vb)? mimi::Model::new(config, vb)?
}; };
let cb = config.audio_num_codebooks; let cb = config.audio_num_codebooks;
@@ -186,19 +186,32 @@ fn main() -> Result<()> {
let mut mask = prompt.get("mask").expect("no mask in prompt").clone(); let mut mask = prompt.get("mask").expect("no mask in prompt").clone();
println!("tokens:\n{tokens:?}"); println!("tokens:\n{tokens:?}");
println!("mask:\n{mask:?}"); println!("mask:\n{mask:?}");
let mut lp = candle_transformers::generation::LogitsProcessor::new(42, Some(0.8), None); let mut lp = candle_transformers::generation::LogitsProcessor::new(42, None, None);
let mut const_mask = vec![1u8; cb]; let mut const_mask = vec![1u8; cb];
const_mask.push(0); const_mask.push(0);
let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?; let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?;
let mut pos = 0; let mut pos = 0;
for i in 0..100 { let mut all_tokens = vec![];
for i in 0.. {
let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?; let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?; pos += tokens.dim(1)?;
frame.push(0); frame.push(0);
if frame.iter().all(|&x| x == 0) {
break;
}
println!("frame {i} {pos}:\n{frame:?}"); println!("frame {i} {pos}:\n{frame:?}");
tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?; tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?;
all_tokens.push(tokens.clone());
mask = const_mask.clone(); mask = const_mask.clone();
} }
let all_tokens = Tensor::cat(&all_tokens, 1)?.narrow(2, 0, cb)?.t()?;
println!("all_tokens:\n{all_tokens:?}");
let pcm = mimi_model.decode(&all_tokens)?;
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create("out.wav")?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
} else { } else {
let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?; let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
println!("{prompt:?}"); println!("{prompt:?}");

View File

@@ -480,13 +480,19 @@ impl Model {
let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?; let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
self.decoder.clear_kv_cache(); self.decoder.clear_kv_cache();
for i in 0..(self.config.audio_num_codebooks - 1) { let mut decoder_pos = 0;
for i in 1..self.config.audio_num_codebooks {
let proj_h = curr_h.apply(&self.projection)?; let proj_h = curr_h.apply(&self.projection)?;
let decoder_h = self.decoder.forward(&proj_h, i)?; let decoder_h = self.decoder.forward(&proj_h, decoder_pos)?;
let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i)?)?; decoder_pos += curr_h.dim(1)?;
let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i - 1)?)?;
let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?; let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
all_samples.push(ci_sample); all_samples.push(ci_sample);
let ci_sample = Tensor::from_slice(&[ci_sample], (1, 1), &self.decoder.device)?; let ci_sample = Tensor::from_slice(
&[ci_sample + (i * self.config.audio_vocab_size) as u32],
(1, 1),
&self.decoder.device,
)?;
let ci_embed = self.audio_embeddings.forward(&ci_sample)?; let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
curr_h = ci_embed curr_h = ci_embed
} }