mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the quantized mixformer model. (#953)
* Add the quantized mixformer model. * Add the quantized option in the phi example.
This commit is contained in:
@ -7,7 +7,8 @@ extern crate accelerate_src;
|
|||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as Model};
|
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||||
|
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
@ -15,6 +16,11 @@ use candle_transformers::generation::LogitsProcessor;
|
|||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
enum Model {
|
||||||
|
MixFormer(MixFormer),
|
||||||
|
Quantized(QMixFormer),
|
||||||
|
}
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
model: Model,
|
model: Model,
|
||||||
device: Device,
|
device: Device,
|
||||||
@ -58,7 +64,10 @@ impl TextGeneration {
|
|||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
let logits = self.model.forward(&input)?;
|
let logits = match &mut self.model {
|
||||||
|
Model::MixFormer(m) => m.forward(&input)?,
|
||||||
|
Model::Quantized(m) => m.forward(&input)?,
|
||||||
|
};
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
@ -115,6 +124,9 @@ struct Args {
|
|||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
weight_file: Option<String>,
|
weight_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
quantized: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -150,10 +162,18 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
|
let (model, device) = if args.quantized {
|
||||||
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
|
||||||
|
let config = Config::v1_5();
|
||||||
|
let model = QMixFormer::new(&config, vb)?;
|
||||||
|
(Model::Quantized(model), Device::Cpu)
|
||||||
|
} else {
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
let config = Config::v1_5();
|
let config = Config::v1_5();
|
||||||
let model = Model::new(&config, vb)?;
|
let model = MixFormer::new(&config, vb)?;
|
||||||
|
(Model::MixFormer(model), device)
|
||||||
|
};
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
let mut pipeline = TextGeneration::new(
|
||||||
|
@ -10,17 +10,17 @@ const MAX_SEQ_LEN: usize = 4096;
|
|||||||
// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
|
// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
vocab_size: usize,
|
pub(crate) vocab_size: usize,
|
||||||
n_positions: usize,
|
pub(crate) n_positions: usize,
|
||||||
n_embd: usize,
|
pub(crate) n_embd: usize,
|
||||||
n_layer: usize,
|
pub(crate) n_layer: usize,
|
||||||
n_inner: Option<usize>,
|
pub(crate) n_inner: Option<usize>,
|
||||||
n_head: usize,
|
pub(crate) n_head: usize,
|
||||||
rotary_dim: usize,
|
pub(crate) rotary_dim: usize,
|
||||||
activation_function: Activation,
|
pub(crate) activation_function: Activation,
|
||||||
layer_norm_epsilon: f64,
|
pub(crate) layer_norm_epsilon: f64,
|
||||||
tie_word_embeddings: bool,
|
pub(crate) tie_word_embeddings: bool,
|
||||||
pad_vocab_size_multiple: usize,
|
pub(crate) pad_vocab_size_multiple: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
@ -6,6 +6,7 @@ pub mod falcon;
|
|||||||
pub mod llama;
|
pub mod llama;
|
||||||
pub mod mixformer;
|
pub mod mixformer;
|
||||||
pub mod quantized_llama;
|
pub mod quantized_llama;
|
||||||
|
pub mod quantized_mixformer;
|
||||||
pub mod quantized_t5;
|
pub mod quantized_t5;
|
||||||
pub mod segment_anything;
|
pub mod segment_anything;
|
||||||
pub mod stable_diffusion;
|
pub mod stable_diffusion;
|
||||||
|
344
candle-transformers/src/models/quantized_mixformer.rs
Normal file
344
candle-transformers/src/models/quantized_mixformer.rs
Normal file
@ -0,0 +1,344 @@
|
|||||||
|
use crate::models::with_tracing::QMatMul;
|
||||||
|
pub use crate::quantized_var_builder::VarBuilder;
|
||||||
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::Activation;
|
||||||
|
|
||||||
|
pub use crate::models::mixformer::Config;
|
||||||
|
|
||||||
|
const MAX_SEQ_LEN: usize = 4096;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Embedding {
|
||||||
|
wte: super::quantized_t5::Embedding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embedding {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let wte = super::quantized_t5::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
|
||||||
|
Ok(Self { wte })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Embedding {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
self.wte.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Linear {
|
||||||
|
weight: QMatMul,
|
||||||
|
bias: Option<Tensor>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Linear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
let x = x.apply(&self.weight)?;
|
||||||
|
match &self.bias {
|
||||||
|
None => Ok(x),
|
||||||
|
Some(bias) => x.broadcast_add(bias),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
|
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
|
||||||
|
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||||
|
Ok(Linear {
|
||||||
|
weight,
|
||||||
|
bias: Some(bias),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
|
||||||
|
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
||||||
|
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
|
||||||
|
Ok(candle_nn::LayerNorm::new(weight, bias, eps))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..size)
|
||||||
|
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (size, size), device)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||||
|
let shape = mask.shape();
|
||||||
|
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||||
|
let m = mask.where_cond(&on_true, on_false)?;
|
||||||
|
Ok(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?,
|
||||||
|
cos: freqs.cos()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
qkv: &Tensor,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<(Tensor, Tensor, Tensor)> {
|
||||||
|
let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
|
||||||
|
if three != 3 {
|
||||||
|
candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
|
||||||
|
}
|
||||||
|
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
|
||||||
|
let rotary_dim = rotary_dim * 2;
|
||||||
|
let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
|
||||||
|
let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
|
||||||
|
let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
|
||||||
|
let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
|
||||||
|
let q12 = q_rot.chunk(2, D::Minus1)?;
|
||||||
|
let k12 = k_rot.chunk(2, D::Minus1)?;
|
||||||
|
let (q1, q2) = (&q12[0], &q12[1]);
|
||||||
|
let (k1, k2) = (&k12[0], &k12[1]);
|
||||||
|
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
||||||
|
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
||||||
|
let q_rot = Tensor::cat(
|
||||||
|
&[
|
||||||
|
(q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
|
||||||
|
(q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
|
||||||
|
],
|
||||||
|
D::Minus1,
|
||||||
|
)?;
|
||||||
|
let k_rot = Tensor::cat(
|
||||||
|
&[
|
||||||
|
(k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
|
||||||
|
(k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
|
||||||
|
],
|
||||||
|
D::Minus1,
|
||||||
|
)?;
|
||||||
|
let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
|
||||||
|
let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
|
||||||
|
let v = qkv.i((.., .., 2))?;
|
||||||
|
Ok((q, k, v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MLP {
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
act: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MLP {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
|
||||||
|
let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
|
||||||
|
let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
|
||||||
|
Ok(Self {
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
act: cfg.activation_function,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MLP {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct CausalLMHead {
|
||||||
|
ln: candle_nn::LayerNorm,
|
||||||
|
linear: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CausalLMHead {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let ln = layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
|
||||||
|
let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
|
||||||
|
Ok(Self { ln, linear })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for CausalLMHead {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.ln)?
|
||||||
|
.apply(&self.linear)?
|
||||||
|
.to_dtype(DType::F32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MHA {
|
||||||
|
wqkv: Linear,
|
||||||
|
out_proj: Linear,
|
||||||
|
rotary_emb: RotaryEmbedding,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
head_dim: usize,
|
||||||
|
n_head: usize,
|
||||||
|
softmax_scale: f64,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MHA {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let head_dim = cfg.n_embd / cfg.n_head;
|
||||||
|
let op_size = cfg.n_embd;
|
||||||
|
let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
|
||||||
|
let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
|
||||||
|
let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
|
||||||
|
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||||
|
Ok(Self {
|
||||||
|
wqkv,
|
||||||
|
out_proj,
|
||||||
|
head_dim,
|
||||||
|
n_head: cfg.n_head,
|
||||||
|
kv_cache: None,
|
||||||
|
rotary_emb,
|
||||||
|
softmax_scale,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "mha"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||||
|
let qkv = self
|
||||||
|
.wqkv
|
||||||
|
.forward(xs)?
|
||||||
|
.reshape((b_size, seq_len, 3, (), self.head_dim))?;
|
||||||
|
let seqlen_offset = match &self.kv_cache {
|
||||||
|
None => 0,
|
||||||
|
Some((prev_k, _)) => prev_k.dim(1)?,
|
||||||
|
};
|
||||||
|
// In the python implementation, a single tensor is returned with the third axis of size 3.
|
||||||
|
let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
|
||||||
|
let (k, v) = match &self.kv_cache {
|
||||||
|
None => (k, v),
|
||||||
|
Some((prev_k, prev_v)) => {
|
||||||
|
let k = Tensor::cat(&[prev_k, &k], 1)?;
|
||||||
|
let v = Tensor::cat(&[prev_v, &v], 1)?;
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((k.clone(), v.clone()));
|
||||||
|
// scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
|
||||||
|
let q = q.transpose(1, 2)?.flatten_to(1)?; // b*h, t, d
|
||||||
|
let k = k.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
|
||||||
|
let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
|
||||||
|
let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
|
||||||
|
|
||||||
|
// causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
|
||||||
|
// scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||||
|
let attn_weights = match mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => masked_fill(
|
||||||
|
&attn_weights,
|
||||||
|
&mask.broadcast_left(b_size * self.n_head)?,
|
||||||
|
f32::NEG_INFINITY,
|
||||||
|
)?,
|
||||||
|
};
|
||||||
|
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||||
|
|
||||||
|
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
||||||
|
// attn_weights: b*h,t,s, v: b*h,s,d
|
||||||
|
let attn_output = attn_weights.matmul(&v)?;
|
||||||
|
// b*h,t,d
|
||||||
|
let attn_output = attn_output
|
||||||
|
.reshape((b_size, (), seq_len, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.flatten_from(D::Minus2)?;
|
||||||
|
attn_output.apply(&self.out_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ParallelBlock {
|
||||||
|
ln: candle_nn::LayerNorm,
|
||||||
|
mixer: MHA,
|
||||||
|
mlp: MLP,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ParallelBlock {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let ln = layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
|
||||||
|
let mixer = MHA::new(cfg, vb.pp("mixer"))?;
|
||||||
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
Ok(Self {
|
||||||
|
ln,
|
||||||
|
mixer,
|
||||||
|
mlp,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "block"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let residual = xs;
|
||||||
|
let xs = xs.apply(&self.ln)?;
|
||||||
|
let attn_outputs = self.mixer.forward(&xs, mask)?;
|
||||||
|
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||||
|
attn_outputs + feed_forward_hidden_states + residual
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MixFormerSequentialForCausalLM {
|
||||||
|
embedding: Embedding,
|
||||||
|
blocks: Vec<ParallelBlock>,
|
||||||
|
head: CausalLMHead,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MixFormerSequentialForCausalLM {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb = vb.pp("layers");
|
||||||
|
let embedding = Embedding::new(cfg, vb.pp(0))?;
|
||||||
|
let mut blocks = Vec::new();
|
||||||
|
for i in 0..cfg.n_layer {
|
||||||
|
let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
|
||||||
|
blocks.push(block)
|
||||||
|
}
|
||||||
|
let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
|
||||||
|
Ok(Self {
|
||||||
|
embedding,
|
||||||
|
blocks,
|
||||||
|
head,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "mixformer"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let (_b_size, seq_len) = xs.dims2()?;
|
||||||
|
let mut xs = xs.apply(&self.embedding)?;
|
||||||
|
let mask = if seq_len <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(get_mask(seq_len, xs.device())?)
|
||||||
|
};
|
||||||
|
for block in self.blocks.iter_mut() {
|
||||||
|
xs = block.forward(&xs, mask.as_ref())?
|
||||||
|
}
|
||||||
|
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
// T5 Text Model, quantized version
|
// T5 Text Model, quantized version
|
||||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||||
|
|
||||||
|
use crate::models::with_tracing::QMatMul;
|
||||||
pub use crate::quantized_var_builder::VarBuilder;
|
pub use crate::quantized_var_builder::VarBuilder;
|
||||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||||
use candle_nn::Activation;
|
use candle_nn::Activation;
|
||||||
@ -8,20 +9,20 @@ use serde::Deserialize;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Embedding {
|
pub struct Embedding {
|
||||||
inner: candle_nn::Embedding,
|
inner: candle_nn::Embedding,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Embedding {
|
impl Embedding {
|
||||||
fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
|
pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
|
let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
|
||||||
let inner = candle_nn::Embedding::new(embeddings, d2);
|
let inner = candle_nn::Embedding::new(embeddings, d2);
|
||||||
let span = tracing::span!(tracing::Level::TRACE, "embedding");
|
let span = tracing::span!(tracing::Level::TRACE, "embedding");
|
||||||
Ok(Self { inner, span })
|
Ok(Self { inner, span })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn embeddings(&self) -> &Tensor {
|
pub fn embeddings(&self) -> &Tensor {
|
||||||
self.inner.embeddings()
|
self.inner.embeddings()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -33,34 +34,6 @@ impl Module for Embedding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// QMatMul wrapper adding some tracing.
|
|
||||||
struct QMatMul {
|
|
||||||
inner: candle::quantized::QMatMul,
|
|
||||||
span: tracing::Span,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QMatMul {
|
|
||||||
fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let ws = vb.get((in_dim, out_dim), "weight")?;
|
|
||||||
let inner = candle::quantized::QMatMul::from_arc(ws);
|
|
||||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
|
||||||
Ok(Self { inner, span })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for QMatMul {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let _enter = self.span.enter();
|
|
||||||
self.inner.forward(xs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for QMatMul {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "QMatMul")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_relative_attention_max_distance() -> usize {
|
fn default_relative_attention_max_distance() -> usize {
|
||||||
128
|
128
|
||||||
}
|
}
|
||||||
|
@ -76,3 +76,35 @@ pub fn conv2d(
|
|||||||
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
|
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
|
||||||
Ok(Conv2d { inner, span })
|
Ok(Conv2d { inner, span })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QMatMul wrapper adding some tracing.
|
||||||
|
pub struct QMatMul {
|
||||||
|
inner: candle::quantized::QMatMul,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QMatMul {
|
||||||
|
pub fn new(
|
||||||
|
out_dim: usize,
|
||||||
|
in_dim: usize,
|
||||||
|
vb: crate::quantized_var_builder::VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let ws = vb.get((in_dim, out_dim), "weight")?;
|
||||||
|
let inner = candle::quantized::QMatMul::from_arc(ws);
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||||
|
Ok(Self { inner, span })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for QMatMul {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
self.inner.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for QMatMul {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "QMatMul")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user