mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add frame generation.
This commit is contained in:
@ -34,7 +34,7 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
use_flash_attn: bool,
|
use_flash_attn: bool,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long, default_value = "[0]Hey how are you doing?")]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
@ -76,6 +76,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
weights: Option<String>,
|
weights: Option<String>,
|
||||||
|
|
||||||
|
/// The mimi model weight file, in safetensor format.
|
||||||
|
#[arg(long)]
|
||||||
|
mimi_weights: Option<String>,
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
#[arg(long, default_value_t = 1.1)]
|
#[arg(long, default_value_t = 1.1)]
|
||||||
repeat_penalty: f32,
|
repeat_penalty: f32,
|
||||||
@ -139,9 +143,14 @@ fn main() -> Result<()> {
|
|||||||
.model("meta-llama/Llama-3.2-1B".to_string())
|
.model("meta-llama/Llama-3.2-1B".to_string())
|
||||||
.get("tokenizer.json")?,
|
.get("tokenizer.json")?,
|
||||||
};
|
};
|
||||||
|
let mimi_filename = match args.mimi_weights {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("kyutai/mimi".to_string())
|
||||||
|
.get("model.safetensors")?,
|
||||||
|
};
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let _tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config: Config = match args.config {
|
let config: Config = match args.config {
|
||||||
@ -152,14 +161,23 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (_model, _device) = {
|
let (_model, device) = {
|
||||||
let dtype = DType::F32;
|
let dtype = device.bf16_default_to_f32();
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let model = Model::new(&config, vb)?;
|
let model = Model::new(&config, vb)?;
|
||||||
(model, device)
|
(model, device)
|
||||||
};
|
};
|
||||||
|
let _mimi_model = {
|
||||||
|
use candle_transformers::models::mimi;
|
||||||
|
let vb =
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
|
||||||
|
let config = mimi::Config::v0_1(None);
|
||||||
|
mimi::Model::new(config, vb)?
|
||||||
|
};
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
|
||||||
|
println!("{prompt:?}");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
/// audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
|
/// audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
|
||||||
/// smaller audio decoder that produces Mimi audio codes.
|
/// smaller audio decoder that produces Mimi audio codes.
|
||||||
///
|
///
|
||||||
|
use crate::generation::LogitsProcessor;
|
||||||
use crate::models::encodec;
|
use crate::models::encodec;
|
||||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
|
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
|
||||||
@ -363,6 +364,7 @@ pub struct Model {
|
|||||||
text_embeddings: Embedding,
|
text_embeddings: Embedding,
|
||||||
projection: Linear,
|
projection: Linear,
|
||||||
audio_head: Tensor,
|
audio_head: Tensor,
|
||||||
|
config: Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
@ -403,6 +405,42 @@ impl Model {
|
|||||||
text_embeddings,
|
text_embeddings,
|
||||||
projection,
|
projection,
|
||||||
audio_head,
|
audio_head,
|
||||||
|
config: cfg.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.backbone.clear_kv_cache();
|
||||||
|
self.decoder.clear_kv_cache();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_frame(
|
||||||
|
&mut self,
|
||||||
|
tokens: &Tensor,
|
||||||
|
tokens_mask: &Tensor,
|
||||||
|
input_pos: usize,
|
||||||
|
lp: &mut LogitsProcessor,
|
||||||
|
) -> Result<Vec<u32>> {
|
||||||
|
let h = tokens.clone(); // TODO
|
||||||
|
let h = self.backbone.forward(&h, input_pos)?;
|
||||||
|
let c0_logits = h.apply(&self.codebook0_head)?;
|
||||||
|
let c0_sample = lp.sample(&c0_logits)?;
|
||||||
|
let mut all_samples = vec![c0_sample];
|
||||||
|
let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
|
||||||
|
let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
|
||||||
|
let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
|
||||||
|
|
||||||
|
self.decoder.clear_kv_cache();
|
||||||
|
for i in 0..(self.config.audio_num_codebooks - 1) {
|
||||||
|
let proj_h = curr_h.apply(&self.projection)?;
|
||||||
|
let decoder_h = self.decoder.forward(&proj_h, i)?;
|
||||||
|
let ci_logits = decoder_h.matmul(&self.audio_head.get(i)?)?;
|
||||||
|
let ci_sample = lp.sample(&ci_logits)?;
|
||||||
|
all_samples.push(ci_sample);
|
||||||
|
let ci_sample = Tensor::from_slice(&[ci_sample], (1, 1), &self.decoder.device)?;
|
||||||
|
let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
|
||||||
|
curr_h = ci_embed
|
||||||
|
}
|
||||||
|
Ok(all_samples)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user