diff --git a/candle-examples/examples/csm/main.rs b/candle-examples/examples/csm/main.rs
index 21ed9e54..fe660689 100644
--- a/candle-examples/examples/csm/main.rs
+++ b/candle-examples/examples/csm/main.rs
@@ -9,7 +9,7 @@
 use clap::Parser;
 
 use candle_transformers::models::csm::{Config, Model};
-use candle::{DType, Tensor};
+use candle::{DType, IndexOp, Tensor};
 use candle_nn::VarBuilder;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
@@ -162,16 +162,16 @@ fn main() -> Result<()> {
     };
     let device = candle_examples::device(args.cpu)?;
     let (mut model, device) = {
-        let dtype = device.bf16_default_to_f32();
+        let dtype = DType::F32;
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         let model = Model::new(&config, vb)?;
         (model, device)
     };
-    let _mimi_model = {
+    let mut mimi_model = {
         use candle_transformers::models::mimi;
         let vb =
             unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
-        let config = mimi::Config::v0_1(None);
+        let config = mimi::Config::v0_1(Some(32));
         mimi::Model::new(config, vb)?
     };
     let cb = config.audio_num_codebooks;
@@ -186,19 +186,32 @@ fn main() -> Result<()> {
         let mut mask = prompt.get("mask").expect("no mask in prompt").clone();
         println!("tokens:\n{tokens:?}");
         println!("mask:\n{mask:?}");
-        let mut lp = candle_transformers::generation::LogitsProcessor::new(42, Some(0.8), None);
+        let mut lp = candle_transformers::generation::LogitsProcessor::new(42, None, None);
         let mut const_mask = vec![1u8; cb];
         const_mask.push(0);
         let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?;
         let mut pos = 0;
-        for i in 0..100 {
+        let mut all_tokens = vec![];
+        for i in 0.. {
             let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
             pos += tokens.dim(1)?;
             frame.push(0);
+            if frame.iter().all(|&x| x == 0) {
+                break;
+            }
             println!("frame {i} {pos}:\n{frame:?}");
             tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?;
+            all_tokens.push(tokens.clone());
             mask = const_mask.clone();
         }
+        let all_tokens = Tensor::cat(&all_tokens, 1)?.narrow(2, 0, cb)?.t()?;
+        println!("all_tokens:\n{all_tokens:?}");
+        let pcm = mimi_model.decode(&all_tokens)?;
+        let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
+        let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
+        let pcm = pcm.to_vec1::<f32>()?;
+        let mut output = std::fs::File::create("out.wav")?;
+        candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
     } else {
         let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
         println!("{prompt:?}");
diff --git a/candle-transformers/src/models/csm.rs b/candle-transformers/src/models/csm.rs
index 02e95f99..d51fb4bf 100644
--- a/candle-transformers/src/models/csm.rs
+++ b/candle-transformers/src/models/csm.rs
@@ -480,13 +480,19 @@ impl Model {
         let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
 
         self.decoder.clear_kv_cache();
-        for i in 0..(self.config.audio_num_codebooks - 1) {
+        let mut decoder_pos = 0;
+        for i in 1..self.config.audio_num_codebooks {
             let proj_h = curr_h.apply(&self.projection)?;
-            let decoder_h = self.decoder.forward(&proj_h, i)?;
-            let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i)?)?;
+            let decoder_h = self.decoder.forward(&proj_h, decoder_pos)?;
+            decoder_pos += curr_h.dim(1)?;
+            let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i - 1)?)?;
             let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
             all_samples.push(ci_sample);
-            let ci_sample = Tensor::from_slice(&[ci_sample], (1, 1), &self.decoder.device)?;
+            let ci_sample = Tensor::from_slice(
+                &[ci_sample + (i * self.config.audio_vocab_size) as u32],
+                (1, 1),
+                &self.decoder.device,
+            )?;
             let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
             curr_h = ci_embed
         }