diff --git a/candle-examples/examples/csm/main.rs b/candle-examples/examples/csm/main.rs
index 3f2950d5..b8f1768b 100644
--- a/candle-examples/examples/csm/main.rs
+++ b/candle-examples/examples/csm/main.rs
@@ -161,7 +161,7 @@ fn main() -> Result<()> {
         }
     };
     let device = candle_examples::device(args.cpu)?;
-    let (_model, device) = {
+    let (mut model, device) = {
         let dtype = device.bf16_default_to_f32();
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         let model = Model::new(&config, vb)?;
@@ -176,8 +176,22 @@ fn main() -> Result<()> {
     };
     println!("loaded the model in {:?}", start.elapsed());
 
-    let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
-    println!("{prompt:?}");
+    if args.prompt.ends_with(".safetensors") {
+        let prompt = candle::safetensors::load(args.prompt, &device)?;
+        let tokens = prompt
+            .get("tokens")
+            .expect("no tokens in prompt")
+            .to_dtype(DType::U32)?;
+        let mask = prompt.get("mask").expect("no mask in prompt").clone();
+        println!("tokens:\n{tokens:?}");
+        println!("mask:\n{mask:?}");
+        let mut lp = candle_transformers::generation::LogitsProcessor::new(42, Some(0.8), None);
+        let frame = model.generate_frame(&tokens, &mask, 0, &mut lp)?;
+        println!("frame:\n{frame:?}");
+    } else {
+        let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
+        println!("{prompt:?}");
+    }
     Ok(())
 }
diff --git a/candle-transformers/src/models/csm.rs b/candle-transformers/src/models/csm.rs
index 1cf4e5e6..b8d3c2f1 100644
--- a/candle-transformers/src/models/csm.rs
+++ b/candle-transformers/src/models/csm.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 //! Implementation of the Conversational Speech Model (CSM) from Sesame
 //!
 //! See: [CSM](Conversational Speech Model)
@@ -8,7 +7,6 @@
 /// smaller audio decoder that produces Mimi audio codes.
 ///
 use crate::generation::LogitsProcessor;
-use crate::models::encodec;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
 use std::sync::Arc;
@@ -30,6 +28,7 @@ pub struct Config {
     pub text_vocab_size: usize,
 }
 
+#[allow(unused)]
 #[derive(Debug, Clone)]
 pub struct LlamaConfig {
     vocab_size: usize,
@@ -421,10 +420,32 @@ impl Model {
         input_pos: usize,
         lp: &mut LogitsProcessor,
     ) -> Result<Vec<u32>> {
-        let h = tokens.clone(); // TODO
-        let h = self.backbone.forward(&h, input_pos)?;
+        let (b_sz, seq_len, _cb_plus_one) = tokens.dims3()?;
+        let audio_tokens = tokens.narrow(2, 0, self.config.audio_num_codebooks)?;
+        let text_tokens = tokens.narrow(2, self.config.audio_num_codebooks, 1)?;
+        let text_embeds = self.text_embeddings.forward(&text_tokens)?;
+        let arange = (Tensor::arange(
+            0u32,
+            self.config.audio_num_codebooks as u32,
+            &self.decoder.device,
+        )? * self.config.audio_vocab_size as f64)?;
+        let audio_tokens = audio_tokens.broadcast_add(&arange.reshape((1, 1, ()))?)?;
+        let audio_embeds = self.audio_embeddings.forward(&audio_tokens)?.reshape((
+            b_sz,
+            seq_len,
+            self.config.audio_num_codebooks,
+            (),
+        ))?;
+        let embeds = Tensor::cat(&[&audio_embeds, &text_embeds], D::Minus2)?;
+        let embeds = embeds.broadcast_mul(
+            &tokens_mask
+                .to_dtype(self.backbone.dtype)?
+                .unsqueeze(D::Minus1)?,
+        )?;
+        let embeds = embeds.sum(2)?;
+        let h = self.backbone.forward(&embeds, input_pos)?;
         let c0_logits = h.apply(&self.codebook0_head)?;
-        let c0_sample = lp.sample(&c0_logits)?;
+        let c0_sample = lp.sample(&c0_logits.i((0, 0))?)?;
         let mut all_samples = vec![c0_sample];
         let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
         let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
@@ -434,8 +455,8 @@ impl Model {
         for i in 0..(self.config.audio_num_codebooks - 1) {
             let proj_h = curr_h.apply(&self.projection)?;
             let decoder_h = self.decoder.forward(&proj_h, i)?;
-            let ci_logits = decoder_h.matmul(&self.audio_head.get(i)?)?;
-            let ci_sample = lp.sample(&ci_logits)?;
+            let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i)?)?;
+            let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
             all_samples.push(ci_sample);
             let ci_sample = Tensor::from_slice(&[ci_sample], (1, 1), &self.decoder.device)?;
             let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
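
The arange/broadcast_add step in generate_frame packs every codebook into one shared embedding table: token t from codebook i is looked up at row t + i * audio_vocab_size, which is why audio_embeddings.forward is called once and then reshaped to (b_sz, seq_len, audio_num_codebooks, dim). A standalone toy sketch of just that offset computation; the two-codebook setup and vocab size of 4 are made-up numbers (the real sizes come from Config):

use candle::{Device, IndexOp, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // (batch = 1, seq_len = 1, codebooks = 2): token 1 from codebook 0, token 3 from codebook 1.
    let audio_tokens = Tensor::new(&[[[1u32, 3]]], &dev)?;
    // Per-codebook offsets [0, 4], i.e. the codebook index times the toy vocab size of 4.
    let offsets = (Tensor::arange(0u32, 2, &dev)? * 4f64)?;
    let flat = audio_tokens.broadcast_add(&offsets.reshape((1, 1, ()))?)?;
    // flat is [[[1, 7]]]: row indices into a single shared 2 * 4 embedding table.
    println!("{}", flat.i((0, 0))?);
    Ok(())
}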
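
As for driving the new .safetensors branch in main.rs, the expected file layout can be read off generate_frame: a tokens tensor of shape (batch, seq_len, audio_num_codebooks + 1) holding the audio codes plus one text token per frame, and a mask of the same shape that zeroes out unused slots before the sum over the codebook dimension. Below is a minimal sketch that writes such a placeholder file; the helper name write_prompt, the 32-codebook count, and the all-zero token values are illustrative assumptions, not values taken from the patch:

use candle::{DType, Device, Tensor};
use std::collections::HashMap;

fn write_prompt() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Assumed sizes: 16 frames of context, 32 audio codebooks (plus 1 text slot per frame).
    let (seq_len, num_codebooks) = (16usize, 32usize);
    // Placeholder ids; a real prompt would hold Mimi audio codes and text token ids.
    let tokens = Tensor::zeros((1, seq_len, num_codebooks + 1), DType::U32, &dev)?;
    // All-ones mask: every slot participates in the embedding sum.
    let mask = Tensor::ones((1, seq_len, num_codebooks + 1), DType::U32, &dev)?;
    let tensors = HashMap::from([("tokens".to_string(), tokens), ("mask".to_string(), mask)]);
    candle::safetensors::save(&tensors, "prompt.safetensors")
}

The example would then presumably be run with something like `cargo run --example csm -- --prompt prompt.safetensors`, assuming the clap flag backing args.prompt is named --prompt.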