mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Move the common quantized-nn code to a shared module. (#1063)
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::models::with_tracing::QMatMul;
|
||||
use crate::quantized_nn::{layer_norm, linear, Linear};
|
||||
pub use crate::quantized_var_builder::VarBuilder;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::Activation;
|
||||
@ -9,12 +9,12 @@ const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Embedding {
|
||||
wte: super::quantized_t5::Embedding,
|
||||
wte: crate::quantized_nn::Embedding,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let wte = super::quantized_t5::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
|
||||
let wte = crate::quantized_nn::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
|
||||
Ok(Self { wte })
|
||||
}
|
||||
}
|
||||
@ -25,37 +25,6 @@ impl Module for Embedding {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Linear {
|
||||
weight: QMatMul,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl Module for Linear {
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let x = x.apply(&self.weight)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
|
||||
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||
Ok(Linear {
|
||||
weight,
|
||||
bias: Some(bias),
|
||||
})
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
||||
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
|
||||
Ok(candle_nn::LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
|
Reference in New Issue
Block a user