Move the common quantized-nn code to a shared module. (#1063)

This commit is contained in:
Laurent Mazare
2023-10-09 06:22:22 +01:00
committed by GitHub
parent 59ab6d7832
commit 392fe02fba
7 changed files with 100 additions and 166 deletions

View File

@ -1,4 +1,4 @@
use crate::models::with_tracing::QMatMul;
use crate::quantized_nn::{layer_norm, linear, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Activation;
@ -9,12 +9,12 @@ const MAX_SEQ_LEN: usize = 4096;
#[derive(Debug)]
struct Embedding {
wte: super::quantized_t5::Embedding,
wte: crate::quantized_nn::Embedding,
}
impl Embedding {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let wte = super::quantized_t5::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
let wte = crate::quantized_nn::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
Ok(Self { wte })
}
}
@ -25,37 +25,6 @@ impl Module for Embedding {
}
}
#[derive(Debug)]
struct Linear {
weight: QMatMul,
bias: Option<Tensor>,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let x = x.apply(&self.weight)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
let weight = QMatMul::new(in_dim, out_dim, vb)?;
Ok(Linear {
weight,
bias: Some(bias),
})
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
Ok(candle_nn::LayerNorm::new(weight, bias, eps))
}
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))