Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 03:54:56 +00:00)

Commit: Get the sampling to work.
```diff
@@ -161,7 +161,7 @@ fn main() -> Result<()> {
         }
     };
     let device = candle_examples::device(args.cpu)?;
-    let (_model, device) = {
+    let (mut model, device) = {
         let dtype = device.bf16_default_to_f32();
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         let model = Model::new(&config, vb)?;
```
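The binding change from `_model` to `mut model` is what the rest of the commit hangs on: the new branch in the next hunk actually calls `model.generate_frame(...)`, which needs a live, mutable model rather than a discarded one (presumably `&mut self`, since the backbone and decoder are stepped with a position index).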
```diff
@@ -176,8 +176,22 @@ fn main() -> Result<()> {
     };
 
     println!("loaded the model in {:?}", start.elapsed());
-    let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
-    println!("{prompt:?}");
+    if args.prompt.ends_with(".safetensors") {
+        let prompt = candle::safetensors::load(args.prompt, &device)?;
+        let tokens = prompt
+            .get("tokens")
+            .expect("no tokens in prompt")
+            .to_dtype(DType::U32)?;
+        let mask = prompt.get("mask").expect("no mask in prompt").clone();
+        println!("tokens:\n{tokens:?}");
+        println!("mask:\n{mask:?}");
+        let mut lp = candle_transformers::generation::LogitsProcessor::new(42, Some(0.8), None);
+        let frame = model.generate_frame(&tokens, &mask, 0, &mut lp)?;
+        println!("frame:\n{frame:?}");
+    } else {
+        let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
+        println!("{prompt:?}");
+    }
 
     Ok(())
 }
```
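The `.safetensors` branch expects a prompt file holding `tokens` and `mask` tensors. Below is a minimal sketch of writing such a file with candle's safetensors helper; the `(1, seq_len, 33)` shape (32 audio codebooks plus one text channel, inferred from the csm.rs hunk further down) and the zero/one contents are placeholder assumptions, not the real prompt format.

```rust
// Sketch: write a prompt.safetensors that the new branch can load.
// Shapes and values are placeholders; 33 = audio_num_codebooks (32) + 1
// text channel is an assumption based on the model code below.
use std::collections::HashMap;

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (b_sz, seq_len, channels) = (1, 4, 33);
    let tokens = Tensor::zeros((b_sz, seq_len, channels), DType::U32, &device)?;
    let mask = Tensor::ones((b_sz, seq_len, channels), DType::U32, &device)?;
    let tensors = HashMap::from([
        ("tokens".to_string(), tokens),
        ("mask".to_string(), mask),
    ]);
    candle::safetensors::save(&tensors, "prompt.safetensors")?;
    Ok(())
}
```

Invoking the example with `--prompt prompt.safetensors` then takes the new generation path instead of tokenizing the argument as text.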
```diff
@@ -1,4 +1,3 @@
-#![allow(unused)]
 //! Implementation of the Conversational Speech Model (CSM) from Sesame
 //!
 //! See: [CSM](Conversational Speech Model)
```
```diff
@@ -8,7 +7,6 @@
 /// smaller audio decoder that produces Mimi audio codes.
 ///
 use crate::generation::LogitsProcessor;
-use crate::models::encodec;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
 use std::sync::Arc;
```
```diff
@@ -30,6 +28,7 @@ pub struct Config {
     pub text_vocab_size: usize,
 }
 
+#[allow(unused)]
 #[derive(Debug, Clone)]
 pub struct LlamaConfig {
     vocab_size: usize,
```
```diff
@@ -421,10 +420,32 @@ impl Model {
         input_pos: usize,
         lp: &mut LogitsProcessor,
     ) -> Result<Vec<u32>> {
-        let h = tokens.clone(); // TODO
-        let h = self.backbone.forward(&h, input_pos)?;
+        let (b_sz, seq_len, _cb_plus_one) = tokens.dims3()?;
+        let audio_tokens = tokens.narrow(2, 0, self.config.audio_num_codebooks)?;
+        let text_tokens = tokens.narrow(2, self.config.audio_num_codebooks, 1)?;
+        let text_embeds = self.text_embeddings.forward(&text_tokens)?;
+        let arange = (Tensor::arange(
+            0u32,
+            self.config.audio_num_codebooks as u32,
+            &self.decoder.device,
+        )? * self.config.audio_vocab_size as f64)?;
+        let audio_tokens = audio_tokens.broadcast_add(&arange.reshape((1, 1, ()))?)?;
+        let audio_embeds = self.audio_embeddings.forward(&audio_tokens)?.reshape((
+            b_sz,
+            seq_len,
+            self.config.audio_num_codebooks,
+            (),
+        ))?;
+        let embeds = Tensor::cat(&[&audio_embeds, &text_embeds], D::Minus2)?;
+        let embeds = embeds.broadcast_mul(
+            &tokens_mask
+                .to_dtype(self.backbone.dtype)?
+                .unsqueeze(D::Minus1)?,
+        )?;
+        let embeds = embeds.sum(2)?;
+        let h = self.backbone.forward(&embeds, input_pos)?;
         let c0_logits = h.apply(&self.codebook0_head)?;
-        let c0_sample = lp.sample(&c0_logits)?;
+        let c0_sample = lp.sample(&c0_logits.i((0, 0))?)?;
         let mut all_samples = vec![c0_sample];
         let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
         let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
```
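The core of the rewrite is how per-codebook audio tokens index a single flat embedding table: token `t` of codebook `c` is remapped to `c * audio_vocab_size + t` by broadcast-adding a scaled `arange`. A toy standalone check of that offset trick, with made-up sizes:

```rust
// Toy check of the codebook-offset trick used in generate_frame: each
// codebook gets its own slice of one flat embedding table. Sizes are
// made up for illustration.
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (num_codebooks, vocab_size) = (4usize, 10usize);
    // (b=1, s=1, num_codebooks) audio token ids, one id per codebook.
    let audio_tokens = Tensor::new(&[[[3u32, 1, 4, 1]]], &device)?;
    let arange =
        (Tensor::arange(0u32, num_codebooks as u32, &device)? * vocab_size as f64)?;
    let offset = audio_tokens.broadcast_add(&arange.reshape((1, 1, ()))?)?;
    // Prints [[[3, 11, 24, 31]]]: codebook c now indexes rows
    // c * vocab_size .. (c + 1) * vocab_size of the shared table.
    println!("{offset}");
    Ok(())
}
```

After the offset, the audio and text embeddings are concatenated along the channel axis, zeroed out where `tokens_mask` is unset, and summed over that axis, so the backbone sees one embedding per position.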
```diff
@@ -434,8 +455,8 @@ impl Model {
         for i in 0..(self.config.audio_num_codebooks - 1) {
             let proj_h = curr_h.apply(&self.projection)?;
             let decoder_h = self.decoder.forward(&proj_h, i)?;
-            let ci_logits = decoder_h.matmul(&self.audio_head.get(i)?)?;
-            let ci_sample = lp.sample(&ci_logits)?;
+            let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i)?)?;
+            let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
             all_samples.push(ci_sample);
             let ci_sample = Tensor::from_slice(&[ci_sample], (1, 1), &self.decoder.device)?;
             let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
```
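Both sampling fixes address the same shape mismatch: the heads emit `(batch, seq, vocab)` logits while `LogitsProcessor::sample` wants the 1-D logits of a single position, hence `.i((0, 0))` before sampling (and `broadcast_matmul` so the 3-D decoder output can multiply the per-codebook head matrix). A minimal sketch with made-up sizes:

```rust
// Sketch of the sampling fix: reduce (b, s, vocab) logits to the 1-D
// logits of one position before handing them to LogitsProcessor.
use candle::{Device, IndexOp, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Pretend (b=1, s=1, vocab=8) logits coming out of a head.
    let logits = Tensor::rand(0f32, 1f32, (1, 1, 8), &device)?;
    // Same seed/temperature/top-p as the example: 42, 0.8, no top-p.
    let mut lp = LogitsProcessor::new(42, Some(0.8), None);
    let token = lp.sample(&logits.i((0, 0))?)?;
    println!("sampled token id: {token}");
    Ok(())
}
```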