Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Quantized version of the metavoice model. (#1824)

* Quantized version of the metavoice model.
* Integrate the quantized version of metavoice.
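In short: the metavoice example gains a `--quantized` flag. When it is set, the first-stage transformer is loaded from `first_stage_q4k.gguf` through the quantized GGUF `VarBuilder` instead of memory-mapping `first_stage.safetensors`, and a small `Transformer` enum dispatches the forward pass between the full-precision and quantized model types. Assuming the usual candle example layout (the exact command line is not part of this diff), the quantized path would be exercised with something like `cargo run --example metavoice --release -- --quantized --prompt "This is a test."`.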
candle-examples/examples/metavoice/main.rs

@@ -11,6 +11,7 @@ use std::io::Write;
 use candle_transformers::generation::LogitsProcessor;
 use candle_transformers::models::encodec;
 use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
+use candle_transformers::models::quantized_metavoice::transformer as qtransformer;
 use candle::{DType, IndexOp, Tensor};
 use candle_nn::VarBuilder;
@@ -26,6 +27,11 @@ enum ArgDType {
     Bf16,
 }
 
+enum Transformer {
+    Normal(transformer::Model),
+    Quantized(qtransformer::Model),
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -40,6 +46,10 @@ struct Args {
     #[arg(long)]
     prompt: String,
 
+    /// Use the quantized version of the model.
+    #[arg(long)]
+    quantized: bool,
+
     /// The guidance scale.
     #[arg(long, default_value_t = 3.0)]
     guidance_scale: f64,
@@ -116,10 +126,6 @@ fn main() -> Result<()> {
     };
     let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
 
-    let first_stage_weights = match &args.first_stage_weights {
-        Some(w) => std::path::PathBuf::from(w),
-        None => repo.get("first_stage.safetensors")?,
-    };
     let second_stage_weights = match &args.second_stage_weights {
         Some(w) => std::path::PathBuf::from(w),
         None => repo.get("second_stage.safetensors")?,
@@ -135,10 +141,27 @@ fn main() -> Result<()> {
         ArgDType::F16 => DType::F16,
         ArgDType::Bf16 => DType::BF16,
     };
-    let first_stage_vb =
-        unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
     let first_stage_config = transformer::Config::cfg1b_v0_1();
-    let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
+    let mut first_stage_model = if args.quantized {
+        let filename = match &args.first_stage_weights {
+            Some(w) => std::path::PathBuf::from(w),
+            None => repo.get("first_stage_q4k.gguf")?,
+        };
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
+        let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
+        Transformer::Quantized(first_stage_model)
+    } else {
+        let first_stage_weights = match &args.first_stage_weights {
+            Some(w) => std::path::PathBuf::from(w),
+            None => repo.get("first_stage.safetensors")?,
+        };
+        let first_stage_vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
+        let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
+        Transformer::Normal(first_stage_model)
+    };
 
     let second_stage_vb =
         unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
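A note on the quantized branch above: `first_stage_q4k.gguf` follows GGUF naming, where q4k presumably denotes 4-bit k-quantized weights. It is loaded via `quantized_var_builder::VarBuilder::from_gguf`, so no dtype choice or unsafe memory-mapping is involved on that path, in contrast to the mmap'd safetensors load in the `else` branch.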
@@ -178,7 +201,12 @@ fn main() -> Result<()> {
         let ctxt = &tokens[start_pos..];
         let input = Tensor::new(ctxt, &device)?;
         let input = Tensor::stack(&[&input, &input], 0)?;
-        let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
+        let logits = match &mut first_stage_model {
+            Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?,
+            Transformer::Quantized(m) => {
+                m.forward(&input, &spk_emb, tokens.len() - context_size)?
+            }
+        };
         let logits0 = logits.i((0, 0))?;
         let logits1 = logits.i((1, 0))?;
         let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
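The two stacked copies of the input correspond to the speaker-conditioned and unconditioned passes (selected by `spk_cond_mask` inside the model), and the last line blends their logits with classifier-free guidance. A minimal sketch of that blend on plain f64 values rather than Tensors (illustrative only, not part of the diff):

// Classifier-free guidance blend, as in the last line of the hunk above.
fn cfg_blend(cond: f64, uncond: f64, guidance_scale: f64) -> f64 {
    guidance_scale * cond + (1. - guidance_scale) * uncond
}

fn main() {
    // With the default guidance_scale = 3.0 this is 3 * cond - 2 * uncond.
    assert_eq!(cfg_blend(1.0, 0.5, 3.0), 3.0 * 1.0 - 2.0 * 0.5);
}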
candle-transformers/src/models/metavoice.rs

@@ -2,7 +2,7 @@ use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 
 // Equivalent to torch.repeat_interleave
-fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
+pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
     let img = img.unsqueeze(dim + 1)?;
     let mut dims = img.dims().to_vec();
     dims[dim + 1] = repeats;
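The visibility change exists so the new quantized module can reuse this helper. To make its semantics concrete, here is a sketch of the same operation on a plain slice (the Tensor version above does this along an arbitrary dim via unsqueeze + broadcast + reshape):

// repeat_interleave on a slice: [a, b] with repeats = 2 -> [a, a, b, b].
fn repeat_interleave_slice<T: Clone>(xs: &[T], repeats: usize) -> Vec<T> {
    xs.iter()
        .flat_map(|x| std::iter::repeat(x.clone()).take(repeats))
        .collect()
}

fn main() {
    assert_eq!(repeat_interleave_slice(&[1, 2], 2), vec![1, 1, 2, 2]);
}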
@@ -664,15 +664,15 @@ pub mod transformer {
             }
         }
 
-        fn n_local_heads(&self) -> usize {
+        pub(crate) fn n_local_heads(&self) -> usize {
             self.n_local_heads.unwrap_or(self.n_head)
         }
 
-        fn head_dim(&self) -> usize {
+        pub(crate) fn head_dim(&self) -> usize {
             self.dim / self.n_head
         }
 
-        fn intermediate_size(&self) -> usize {
+        pub(crate) fn intermediate_size(&self) -> usize {
             match self.intermediate_size {
                 Some(intermediate_size) => intermediate_size,
                 None => {
candle-transformers/src/models/mod.rs

@@ -30,6 +30,7 @@ pub mod quantized_blip;
 pub mod quantized_blip_text;
 pub mod quantized_llama;
 pub mod quantized_llama2_c;
+pub mod quantized_metavoice;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
 pub mod quantized_mpt;
candle-transformers/src/models/quantized_metavoice.rs (new file, 226 lines)

@@ -0,0 +1,226 @@
+use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm};
+pub use crate::quantized_var_builder::VarBuilder;
+
+use crate::models::metavoice::repeat_interleave;
+use candle::{Module, Result, Tensor, D};
+
+pub mod transformer {
+    use super::*;
+
+    type Config = crate::models::metavoice::transformer::Config;
+
+    #[derive(Debug, Clone)]
+    struct FeedForward {
+        w1: Linear,
+        w2: Linear,
+        w3: Linear,
+    }
+
+    impl FeedForward {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let i_size = cfg.intermediate_size();
+            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
+            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
+            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
+            Ok(Self { w1, w2, w3 })
+        }
+    }
+
+    impl Module for FeedForward {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
+            swiglu.apply(&self.w2)
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct Attention {
+        wqkv: Linear,
+        wo: Linear,
+        dim: usize,
+        kv_size: usize,
+        n_local_heads: usize,
+        head_dim: usize,
+        n_head: usize,
+        kv_cache: Option<(Tensor, Tensor)>,
+    }
+
+    impl Attention {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let n_local_heads = cfg.n_local_heads();
+            let head_dim = cfg.head_dim();
+            let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
+            let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
+            let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
+            Ok(Self {
+                wqkv,
+                wo,
+                dim: cfg.dim,
+                kv_size: n_local_heads * head_dim,
+                n_local_heads,
+                head_dim,
+                n_head: cfg.n_head,
+                kv_cache: None,
+            })
+        }
+
+        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let (b_sz, seqlen, _) = xs.dims3()?;
+
+            let qkv = xs.apply(&self.wqkv)?;
+            let q = qkv.narrow(D::Minus1, 0, self.dim)?;
+            let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
+            let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
+            let q = q
+                .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
+                .transpose(1, 2)?
+                .contiguous()?;
+            let k = k
+                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
+                .transpose(1, 2)?;
+            let v = v
+                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
+                .transpose(1, 2)?;
+
+            let (k, v) = match &self.kv_cache {
+                None => (k, v),
+                Some((prev_k, prev_v)) => {
+                    let k = Tensor::cat(&[prev_k, &k], 2)?;
+                    let v = Tensor::cat(&[prev_v, &v], 2)?;
+                    (k, v)
+                }
+            };
+            self.kv_cache = Some((k.clone(), v.clone()));
+
+            let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
+            let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
+
+            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+            let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
+
+            let attn_weights = attn_weights.broadcast_add(mask)?;
+            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+            let attn_output = attn_weights.matmul(&v)?;
+            attn_output
+                .transpose(1, 2)?
+                .reshape((b_sz, seqlen, self.dim))?
+                .apply(&self.wo)
+        }
+
+        fn clear_kv_cache(&mut self) {
+            self.kv_cache = None
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct Block {
+        attention: Attention,
+        feed_forward: FeedForward,
+        ffn_norm: RmsNorm,
+        attention_norm: RmsNorm,
+    }
+
+    impl Block {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let attention = Attention::new(cfg, vb.pp("attention"))?;
+            let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
+            let ffn_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
+            let attention_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
+            Ok(Self {
+                attention,
+                feed_forward,
+                ffn_norm,
+                attention_norm,
+            })
+        }
+
+        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let hs = xs.apply(&self.attention_norm)?;
+            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
+            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
+        }
+
+        fn clear_kv_cache(&mut self) {
+            self.attention.clear_kv_cache()
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    pub struct Model {
+        tok_embeddings: Embedding,
+        pos_embeddings: Embedding,
+        speaker_cond_pos: Linear,
+        layers: Vec<Block>,
+        norm: RmsNorm,
+        output: Linear,
+        spk_cond_mask: Tensor,
+    }
+
+    impl Model {
+        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let tok_embeddings = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
+            let pos_embeddings = Embedding::new(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
+            let speaker_cond_pos = linear_b(
+                cfg.speaker_emb_dim,
+                cfg.dim,
+                false,
+                vb.pp("speaker_cond_pos"),
+            )?;
+            let mut layers = Vec::with_capacity(cfg.n_layer);
+            let vb_l = vb.pp("layers");
+            for layer_idx in 0..cfg.n_layer {
+                let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
+                layers.push(layer)
+            }
+            let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
+            let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
+            let spk_cond_mask = Tensor::cat(
+                &[
+                    Tensor::ones((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
+                    Tensor::zeros((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
+                ],
+                0,
+            )?;
+            Ok(Self {
+                tok_embeddings,
+                pos_embeddings,
+                speaker_cond_pos,
+                layers,
+                norm,
+                output,
+                spk_cond_mask,
+            })
+        }
+
+        pub fn clear_kv_cache(&mut self) {
+            for layer in self.layers.iter_mut() {
+                layer.clear_kv_cache()
+            }
+        }
+
+        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+            let (_b_sz, seqlen) = xs.dims2()?;
+            let mask: Vec<_> = (0..seqlen)
+                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+                .collect();
+            let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
+            let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
+            let tok_embeddings = xs.apply(&self.tok_embeddings)?;
+            let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
+            let mut xs = tok_embeddings
+                .broadcast_add(&pos_embeddings)?
+                .broadcast_add(
+                    &spk_emb
+                        .apply(&self.speaker_cond_pos)?
+                        .broadcast_mul(&self.spk_cond_mask)?,
+                )?;
+            let mask = mask.to_dtype(xs.dtype())?;
+            for layer in self.layers.iter_mut() {
+                xs = layer.forward(&xs, pos, &mask)?
+            }
+            xs.narrow(1, seqlen - 1, 1)?
+                .apply(&self.norm)?
+                .apply(&self.output)
+        }
+    }
+}
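Two details of the new file worth spelling out: `k` and `v` are produced with `n_local_heads` heads and then expanded with `repeat_interleave` to match the `n_head` query heads (grouped-query attention), and `Model::forward` builds a causal mask in which position i may only attend to positions j <= i. A standalone sketch of that mask construction (not part of the diff):

// Causal mask as built in Model::forward above: future positions (j > i)
// get -inf so the softmax zeroes them out; past/current positions get 0.
fn causal_mask(seqlen: usize) -> Vec<f32> {
    (0..seqlen)
        .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
        .collect()
}

fn main() {
    // seqlen = 3, row-major:
    //   0 -inf -inf
    //   0    0 -inf
    //   0    0    0
    let m = causal_mask(3);
    assert_eq!(m[1], f32::NEG_INFINITY); // row 0, col 1 is masked
    assert_eq!(m[3], 0.);                // row 1, col 0 is visible
}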
candle-transformers/src/quantized_nn.rs

@@ -50,6 +50,16 @@ impl Module for Linear {
     }
 }
 
+pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+    let bias = if bias {
+        Some(vb.get(out_dim, "bias")?.dequantize(vb.device())?)
+    } else {
+        None
+    };
+    let weight = QMatMul::new(in_dim, out_dim, vb)?;
+    Ok(Linear { weight, bias })
+}
+
 pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
     let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
     let weight = QMatMul::new(in_dim, out_dim, vb)?;
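Design note on `linear_b`: the weight stays in its quantized representation inside `QMatMul` (so the matmul runs on quantized blocks), while the optional bias is dequantized once at load time; a per-token float bias add is cheap enough that this costs little. A hedged usage sketch, assuming these items are publicly exported from candle-transformers as they appear here, and a quantized `VarBuilder` whose GGUF file holds a `swiglu.w1.weight` tensor:

// Illustrative helper (hypothetical, not part of the diff) mirroring how
// the quantized metavoice FeedForward calls the new function.
use candle::Result;
use candle_transformers::quantized_nn::{linear_b, Linear};
use candle_transformers::quantized_var_builder::VarBuilder;

fn w1_projection(vb: VarBuilder, dim: usize, i_size: usize) -> Result<Linear> {
    // bias = false: no `bias` tensor is looked up or dequantized.
    linear_b(dim, i_size, false, vb.pp("swiglu.w1"))
}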