mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add frame generation.
This commit is contained in:
@ -7,6 +7,7 @@
|
||||
/// audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
|
||||
/// smaller audio decoder that produces Mimi audio codes.
|
||||
///
|
||||
use crate::generation::LogitsProcessor;
|
||||
use crate::models::encodec;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
|
||||
@ -363,6 +364,7 @@ pub struct Model {
|
||||
text_embeddings: Embedding,
|
||||
projection: Linear,
|
||||
audio_head: Tensor,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
@ -403,6 +405,42 @@ impl Model {
|
||||
text_embeddings,
|
||||
projection,
|
||||
audio_head,
|
||||
config: cfg.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.backbone.clear_kv_cache();
|
||||
self.decoder.clear_kv_cache();
|
||||
}
|
||||
|
||||
pub fn generate_frame(
|
||||
&mut self,
|
||||
tokens: &Tensor,
|
||||
tokens_mask: &Tensor,
|
||||
input_pos: usize,
|
||||
lp: &mut LogitsProcessor,
|
||||
) -> Result<Vec<u32>> {
|
||||
let h = tokens.clone(); // TODO
|
||||
let h = self.backbone.forward(&h, input_pos)?;
|
||||
let c0_logits = h.apply(&self.codebook0_head)?;
|
||||
let c0_sample = lp.sample(&c0_logits)?;
|
||||
let mut all_samples = vec![c0_sample];
|
||||
let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
|
||||
let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
|
||||
let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
|
||||
|
||||
self.decoder.clear_kv_cache();
|
||||
for i in 0..(self.config.audio_num_codebooks - 1) {
|
||||
let proj_h = curr_h.apply(&self.projection)?;
|
||||
let decoder_h = self.decoder.forward(&proj_h, i)?;
|
||||
let ci_logits = decoder_h.matmul(&self.audio_head.get(i)?)?;
|
||||
let ci_sample = lp.sample(&ci_logits)?;
|
||||
all_samples.push(ci_sample);
|
||||
let ci_sample = Tensor::from_slice(&[ci_sample], (1, 1), &self.decoder.device)?;
|
||||
let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
|
||||
curr_h = ci_embed
|
||||
}
|
||||
Ok(all_samples)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user