Move the common quantized-nn code to a shared module. (#1063)

Laurent Mazare
2023-10-09 06:22:22 +01:00
committed by GitHub
parent 59ab6d7832
commit 392fe02fba
7 changed files with 100 additions and 166 deletions


@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
@@ -8,28 +7,6 @@ use std::sync::Arc;
 pub use crate::models::stable_lm::Config;
 use crate::models::stable_lm::RotaryEmbedding;
 
-#[derive(Debug)]
-struct Linear {
-    weight: QMatMul,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        x.apply(&self.weight)
-    }
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = QMatMul::new(in_dim, out_dim, vb)?;
-    Ok(Linear { weight })
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
-    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
-    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
-    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
-}
-
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
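
For reference, a minimal sketch of what the new shared crate::quantized_nn module plausibly contains, reassembled from the definitions deleted in this hunk. The bias-carrying Linear variant and the Embedding export are assumptions: the new import list names them, but their bodies are not shown in this diff.

// Sketch of the shared module, under the assumptions above. The bodies of
// `linear_no_bias` and `layer_norm` are taken verbatim from the code removed
// in this commit; the `bias` field is hypothetical, inferred from the fact
// that the import list exposes both `Linear` and `linear_no_bias`.
use crate::models::with_tracing::QMatMul;
use crate::quantized_var_builder::VarBuilder;
use candle::{Module, Result, Tensor};

#[derive(Debug)]
pub struct Linear {
    weight: QMatMul,
    bias: Option<Tensor>, // assumed: allows a biased variant alongside linear_no_bias
}

impl Module for Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Quantized matmul, then an optional broadcast bias add.
        let x = x.apply(&self.weight)?;
        match &self.bias {
            Some(b) => x.broadcast_add(b),
            None => Ok(x),
        }
    }
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear { weight, bias: None })
}

pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
    // LayerNorm weights are stored quantized; dequantize them once at load time.
    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
}

With the helpers centralized like this, each quantized model file (this diff touches 7 of them) can drop its private copies and pull the same definitions from one place, which is where the net reduction of 66 lines comes from.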