From 3be12b8b506dd344a3c47880afd7effb51bb89bd Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 3 Apr 2025 18:38:00 +0200 Subject: [PATCH] Autoregressive generation. --- candle-examples/examples/csm/main.rs | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/candle-examples/examples/csm/main.rs b/candle-examples/examples/csm/main.rs index b8f1768b..21ed9e54 100644 --- a/candle-examples/examples/csm/main.rs +++ b/candle-examples/examples/csm/main.rs @@ -9,7 +9,7 @@ use clap::Parser; use candle_transformers::models::csm::{Config, Model}; -use candle::DType; +use candle::{DType, Tensor}; use candle_nn::VarBuilder; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; @@ -174,20 +174,31 @@ fn main() -> Result<()> { let config = mimi::Config::v0_1(None); mimi::Model::new(config, vb)? }; + let cb = config.audio_num_codebooks; println!("loaded the model in {:?}", start.elapsed()); if args.prompt.ends_with(".safetensors") { let prompt = candle::safetensors::load(args.prompt, &device)?; - let tokens = prompt + let mut tokens = prompt .get("tokens") .expect("no tokens in prompt") .to_dtype(DType::U32)?; - let mask = prompt.get("mask").expect("no mask in prompt").clone(); + let mut mask = prompt.get("mask").expect("no mask in prompt").clone(); println!("tokens:\n{tokens:?}"); println!("mask:\n{mask:?}"); let mut lp = candle_transformers::generation::LogitsProcessor::new(42, Some(0.8), None); - let frame = model.generate_frame(&tokens, &mask, 0, &mut lp)?; - println!("frame:\n{frame:?}"); + let mut const_mask = vec![1u8; cb]; + const_mask.push(0); + let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?; + let mut pos = 0; + for i in 0..100 { + let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?; + pos += tokens.dim(1)?; + frame.push(0); + println!("frame {i} {pos}:\n{frame:?}"); + tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?; + mask = const_mask.clone(); + } } else { let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?; println!("{prompt:?}");