mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 19:58:35 +00:00)

Move the common quantized-nn code to a shared module. (#1063)
candle-transformers/src/lib.rs
@@ -2,5 +2,6 @@ pub mod generation;
 pub mod models;
 pub mod object_detection;
 pub mod pipelines;
+pub mod quantized_nn;
 pub mod quantized_var_builder;
 pub mod utils;
candle-transformers/src/models/quantized_mistral.rs
@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{linear_no_bias, Embedding, Linear, RmsNorm};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::Activation;
@@ -7,44 +6,6 @@ use std::sync::Arc;

 pub use crate::models::mistral::Config;

-#[derive(Debug)]
-struct Linear {
-    weight: QMatMul,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        x.apply(&self.weight)
-    }
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = QMatMul::new(in_dim, out_dim, vb)?;
-    Ok(Linear { weight })
-}
-
-#[derive(Debug)]
-struct RmsNorm {
-    inner: candle_nn::RmsNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
-        let inner = candle_nn::RmsNorm::new(weight, eps);
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for RmsNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Debug)]
 struct RotaryEmbedding {
     sin: Tensor,
candle-transformers/src/models/quantized_mixformer.rs
@@ -1,4 +1,4 @@
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::Activation;
@@ -9,12 +9,12 @@ const MAX_SEQ_LEN: usize = 4096;

 #[derive(Debug)]
 struct Embedding {
-    wte: super::quantized_t5::Embedding,
+    wte: crate::quantized_nn::Embedding,
 }

 impl Embedding {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wte = super::quantized_t5::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
+        let wte = crate::quantized_nn::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
         Ok(Self { wte })
     }
 }
@@ -25,37 +25,6 @@ impl Module for Embedding {
     }
 }

-#[derive(Debug)]
-struct Linear {
-    weight: QMatMul,
-    bias: Option<Tensor>,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let x = x.apply(&self.weight)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
-    let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
-    let weight = QMatMul::new(in_dim, out_dim, vb)?;
-    Ok(Linear {
-        weight,
-        bias: Some(bias),
-    })
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
-    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
-    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
-    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
-}
-
 fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
         .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
candle-transformers/src/models/quantized_stable_lm.rs
@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
@@ -8,28 +7,6 @@ use std::sync::Arc;
 pub use crate::models::stable_lm::Config;
 use crate::models::stable_lm::RotaryEmbedding;

-#[derive(Debug)]
-struct Linear {
-    weight: QMatMul,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        x.apply(&self.weight)
-    }
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = QMatMul::new(in_dim, out_dim, vb)?;
-    Ok(Linear { weight })
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
-    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
-    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
-    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
-}
-
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
candle-transformers/src/models/quantized_t5.rs
@@ -2,38 +2,13 @@
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

 use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::Embedding;
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::Activation;
 use serde::Deserialize;
 use std::sync::Arc;

-#[derive(Debug)]
-pub struct Embedding {
-    inner: candle_nn::Embedding,
-    span: tracing::Span,
-}
-
-impl Embedding {
-    pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
-        let inner = candle_nn::Embedding::new(embeddings, d2);
-        let span = tracing::span!(tracing::Level::TRACE, "embedding");
-        Ok(Self { inner, span })
-    }
-
-    pub fn embeddings(&self) -> &Tensor {
-        self.inner.embeddings()
-    }
-}
-
-impl Module for Embedding {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(xs)
-    }
-}
-
 fn default_relative_attention_max_distance() -> usize {
     128
 }
candle-transformers/src/models/quantized_whisper.rs
@@ -1,39 +1,9 @@
 use super::Config;
-use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul};
+use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};

-#[derive(Debug)]
-struct Linear {
-    weight: QMatMul,
-    bias: Option<Tensor>,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let x = x.apply(&self.weight)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
-    let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
-    let weight = QMatMul::new(in_dim, out_dim, vb)?;
-    Ok(Linear {
-        weight,
-        bias: Some(bias),
-    })
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = QMatMul::new(in_dim, out_dim, vb)?;
-    Ok(Linear { weight, bias: None })
-}
-
 fn conv1d(
     in_channels: usize,
     out_channels: usize,
@@ -48,12 +18,6 @@ fn conv1d(
     Ok(Conv1d::new(weight, Some(bias), config))
 }

-fn layer_norm(size: usize, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
-    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
-    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
-    Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5))
-}
-
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
 struct MultiHeadAttention {
     query: Linear,
@@ -178,10 +142,10 @@ impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
         let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
-        let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
+        let attn_ln = layer_norm(n_state, 1e-5, vb.pp("self_attn_layer_norm"))?;
         let cross_attn = if ca {
             let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
-            let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
+            let cross_attn_ln = layer_norm(n_state, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
             Some((cross_attn, cross_attn_ln))
         } else {
             None
@@ -189,7 +153,7 @@ impl ResidualAttentionBlock {
         let n_mlp = n_state * 4;
         let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
         let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
-        let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
+        let mlp_ln = layer_norm(n_state, 1e-5, vb.pp("final_layer_norm"))?;
         Ok(Self {
             attn,
             attn_ln,
@@ -281,7 +245,7 @@ impl AudioEncoder {
                 ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
+        let ln_post = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
         Ok(Self {
             conv1,
             conv2,
@@ -343,7 +307,7 @@ impl TextDecoder {
                 ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
+        let ln = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
         let mask: Vec<_> = (0..n_ctx)
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
|
87
candle-transformers/src/quantized_nn.rs
Normal file
87
candle-transformers/src/quantized_nn.rs
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
use crate::models::with_tracing::QMatMul;
|
||||||
|
use crate::quantized_var_builder::VarBuilder;
|
||||||
|
use candle::{Module, Result, Tensor};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Embedding {
|
||||||
|
inner: candle_nn::Embedding,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embedding {
|
||||||
|
pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
|
||||||
|
let inner = candle_nn::Embedding::new(embeddings, d2);
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "embedding");
|
||||||
|
Ok(Self { inner, span })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embeddings(&self) -> &Tensor {
|
||||||
|
self.inner.embeddings()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Embedding {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
self.inner.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Linear {
|
||||||
|
weight: QMatMul,
|
||||||
|
bias: Option<Tensor>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Linear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
let x = x.apply(&self.weight)?;
|
||||||
|
match &self.bias {
|
||||||
|
None => Ok(x),
|
||||||
|
Some(bias) => x.broadcast_add(bias),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
|
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
|
||||||
|
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||||
|
Ok(Linear {
|
||||||
|
weight,
|
||||||
|
bias: Some(bias),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
|
||||||
|
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
||||||
|
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
|
||||||
|
Ok(candle_nn::LayerNorm::new(weight, bias, eps))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
|
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||||
|
Ok(Linear { weight, bias: None })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RmsNorm {
|
||||||
|
inner: candle_nn::RmsNorm,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RmsNorm {
|
||||||
|
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||||
|
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
||||||
|
let inner = candle_nn::RmsNorm::new(weight, eps);
|
||||||
|
Ok(Self { inner, span })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for RmsNorm {
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
self.inner.forward(x)
|
||||||
|
}
|
||||||
|
}