diff --git a/candle-examples/examples/csm/main.rs b/candle-examples/examples/csm/main.rs
index 1148a2e5..3f2950d5 100644
--- a/candle-examples/examples/csm/main.rs
+++ b/candle-examples/examples/csm/main.rs
@@ -34,7 +34,7 @@ struct Args {
     #[arg(long)]
     use_flash_attn: bool,
 
-    #[arg(long)]
+    #[arg(long, default_value = "[0]Hey how are you doing?")]
     prompt: String,
 
     /// The temperature used to generate samples.
@@ -76,6 +76,10 @@ struct Args {
     #[arg(long)]
    weights: Option<String>,
 
+    /// The mimi model weight file, in safetensor format.
+    #[arg(long)]
+    mimi_weights: Option<String>,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@@ -139,9 +143,14 @@ fn main() -> Result<()> {
             .model("meta-llama/Llama-3.2-1B".to_string())
             .get("tokenizer.json")?,
     };
-
+    let mimi_filename = match args.mimi_weights {
+        Some(model) => std::path::PathBuf::from(model),
+        None => Api::new()?
+            .model("kyutai/mimi".to_string())
+            .get("model.safetensors")?,
+    };
     println!("retrieved the files in {:?}", start.elapsed());
-    let _tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
     let config: Config = match args.config {
@@ -152,14 +161,23 @@ fn main() -> Result<()> {
         }
     };
     let device = candle_examples::device(args.cpu)?;
-    let (_model, _device) = {
-        let dtype = DType::F32;
+    let (_model, device) = {
+        let dtype = device.bf16_default_to_f32();
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         let model = Model::new(&config, vb)?;
         (model, device)
     };
+    let _mimi_model = {
+        use candle_transformers::models::mimi;
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
+        let config = mimi::Config::v0_1(None);
+        mimi::Model::new(config, vb)?
+    };
 
     println!("loaded the model in {:?}", start.elapsed());
 
+    let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
+    println!("{prompt:?}");
     Ok(())
 }
diff --git a/candle-transformers/src/models/csm.rs b/candle-transformers/src/models/csm.rs
index e40fc637..1cf4e5e6 100644
--- a/candle-transformers/src/models/csm.rs
+++ b/candle-transformers/src/models/csm.rs
@@ -7,6 +7,7 @@
 /// audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
 /// smaller audio decoder that produces Mimi audio codes.
 ///
+use crate::generation::LogitsProcessor;
 use crate::models::encodec;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
@@ -363,6 +364,7 @@ pub struct Model {
     text_embeddings: Embedding,
     projection: Linear,
     audio_head: Tensor,
+    config: Config,
 }
 
 impl Model {
@@ -403,6 +405,42 @@ impl Model {
             text_embeddings,
             projection,
             audio_head,
+            config: cfg.clone(),
         })
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.backbone.clear_kv_cache();
+        self.decoder.clear_kv_cache();
+    }
+
+    pub fn generate_frame(
+        &mut self,
+        tokens: &Tensor,
+        tokens_mask: &Tensor,
+        input_pos: usize,
+        lp: &mut LogitsProcessor,
+    ) -> Result<Vec<u32>> {
+        let h = tokens.clone(); // TODO
+        let h = self.backbone.forward(&h, input_pos)?;
+        let c0_logits = h.apply(&self.codebook0_head)?;
+        let c0_sample = lp.sample(&c0_logits)?;
+        let mut all_samples = vec![c0_sample];
+        let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
+        let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
+        let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
+
+        self.decoder.clear_kv_cache();
+        for i in 0..(self.config.audio_num_codebooks - 1) {
+            let proj_h = curr_h.apply(&self.projection)?;
+            let decoder_h = self.decoder.forward(&proj_h, i)?;
+            let ci_logits = decoder_h.matmul(&self.audio_head.get(i)?)?;
+            let ci_sample = lp.sample(&ci_logits)?;
+            all_samples.push(ci_sample);
+            let ci_sample = Tensor::from_slice(&[ci_sample], (1, 1), &self.decoder.device)?;
+            let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
+            curr_h = ci_embed
+        }
+        Ok(all_samples)
+    }
 }
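
Note on usage: the sketch below shows one way the new generate_frame entry point might be driven once the embedding path marked // TODO above is filled in. It is not part of this diff: the (1, 1, audio_num_codebooks + 1) token layout, the zero-initialized first frame, the public visibility of the config field, and the sampler seed/temperature are all assumptions.

use candle::{DType, Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::csm::{Config, Model};

// Hypothetical driver: sample one frame of audio codes per step, then feed
// the sampled codebooks (plus an empty text slot) back as the next tokens.
fn sample_frames(
    model: &mut Model,
    cfg: &Config,
    device: &Device,
    steps: usize,
) -> Result<Vec<Vec<u32>>> {
    let mut lp = LogitsProcessor::new(299792458, Some(0.8), None);
    model.clear_kv_cache();
    let width = cfg.audio_num_codebooks + 1;
    let mut tokens = Tensor::zeros((1, 1, width), DType::U32, device)?;
    let tokens_mask = tokens.ones_like()?;
    let mut frames = Vec::new();
    for pos in 0..steps {
        let samples = model.generate_frame(&tokens, &tokens_mask, pos, &mut lp)?;
        let mut next = samples.clone();
        next.push(0); // zero text token alongside the sampled audio codebooks
        tokens = Tensor::from_vec(next, (1, 1, width), device)?;
        frames.push(samples);
    }
    Ok(frames)
}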
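
Since main.rs now also loads the Mimi codec weights, the sampled codes would eventually be decoded back into a waveform. A rough sketch, assuming mimi::Model exposes the same decode method the existing mimi example uses, taking codes shaped (batch, codebooks, time):

use candle::{Device, IndexOp, Result, Tensor};
use candle_transformers::models::mimi;

// Hypothetical post-processing: stack the per-frame codebook samples into a
// (batch, codebooks, time) tensor and run the Mimi decoder to recover PCM.
fn frames_to_pcm(
    mimi: &mut mimi::Model,
    frames: &[Vec<u32>],
    device: &Device,
) -> Result<Vec<f32>> {
    let num_codebooks = frames[0].len();
    // Transpose frame-major samples into codebook-major order.
    let flat: Vec<u32> = (0..num_codebooks)
        .flat_map(|cb| frames.iter().map(move |frame| frame[cb]))
        .collect();
    let codes = Tensor::from_vec(flat, (1, num_codebooks, frames.len()), device)?;
    let pcm = mimi.decode(&codes)?; // expected (1, 1, samples) at 24kHz
    pcm.i((0, 0))?.to_vec1::<f32>()
}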