Use the fast RmsNorm in the quantized model. (#1904)

This commit is contained in:
Laurent Mazare
2024-03-21 18:49:35 +01:00
committed by GitHub
parent 9563a5fee4
commit c0bdd9c7a6
3 changed files with 21 additions and 35 deletions

View File

@ -1,5 +1,6 @@
use crate::models::with_tracing::QMatMul;
use crate::quantized_var_builder::VarBuilder;
use candle::quantized::QTensor;
use candle::{Module, Result, Tensor};
#[derive(Debug, Clone)]
@ -35,10 +36,7 @@ pub struct Linear {
}
impl Linear {
pub fn from_arc(
weight: std::sync::Arc<candle::quantized::QTensor>,
bias: Option<Tensor>,
) -> Result<Self> {
pub fn from_arc(weight: std::sync::Arc<QTensor>, bias: Option<Tensor>) -> Result<Self> {
let weight = QMatMul::from_weights(weight)?;
Ok(Self { weight, bias })
}
@ -95,7 +93,8 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<L
#[derive(Debug, Clone)]
pub struct RmsNorm {
inner: candle_nn::RmsNorm,
weight: Tensor,
eps: f64,
span: tracing::Span,
}
@ -103,14 +102,19 @@ impl RmsNorm {
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
let inner = candle_nn::RmsNorm::new(weight, eps);
Ok(Self { inner, span })
Ok(Self { weight, eps, span })
}
pub fn from_qtensor(weight: QTensor, eps: f64) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let weight = weight.dequantize(&weight.device())?;
Ok(Self { weight, eps, span })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
candle_nn::ops::rms_norm(x, &self.weight, self.eps as f32)
}
}