mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Autoregressive generation.
This commit is contained in:
@ -9,7 +9,7 @@ use clap::Parser;
|
||||
|
||||
use candle_transformers::models::csm::{Config, Model};
|
||||
|
||||
use candle::DType;
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
@ -174,20 +174,31 @@ fn main() -> Result<()> {
|
||||
let config = mimi::Config::v0_1(None);
|
||||
mimi::Model::new(config, vb)?
|
||||
};
|
||||
let cb = config.audio_num_codebooks;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
if args.prompt.ends_with(".safetensors") {
|
||||
let prompt = candle::safetensors::load(args.prompt, &device)?;
|
||||
let tokens = prompt
|
||||
let mut tokens = prompt
|
||||
.get("tokens")
|
||||
.expect("no tokens in prompt")
|
||||
.to_dtype(DType::U32)?;
|
||||
let mask = prompt.get("mask").expect("no mask in prompt").clone();
|
||||
let mut mask = prompt.get("mask").expect("no mask in prompt").clone();
|
||||
println!("tokens:\n{tokens:?}");
|
||||
println!("mask:\n{mask:?}");
|
||||
let mut lp = candle_transformers::generation::LogitsProcessor::new(42, Some(0.8), None);
|
||||
let frame = model.generate_frame(&tokens, &mask, 0, &mut lp)?;
|
||||
println!("frame:\n{frame:?}");
|
||||
let mut const_mask = vec![1u8; cb];
|
||||
const_mask.push(0);
|
||||
let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?;
|
||||
let mut pos = 0;
|
||||
for i in 0..100 {
|
||||
let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||
pos += tokens.dim(1)?;
|
||||
frame.push(0);
|
||||
println!("frame {i} {pos}:\n{frame:?}");
|
||||
tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?;
|
||||
mask = const_mask.clone();
|
||||
}
|
||||
} else {
|
||||
let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
|
||||
println!("{prompt:?}");
|
||||
|
Reference in New Issue
Block a user