Use the fast RmsNorm in the quantized model. (#1904)

Laurent Mazare
2024-03-21 18:49:35 +01:00
committed by GitHub
parent 9563a5fee4
commit c0bdd9c7a6
3 changed files with 21 additions and 35 deletions
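
In short: the quantized llama model previously carried its own RmsNorm wrapper that dequantized the scale tensor and went through candle_nn::LayerNorm::rms_norm; this commit deletes that wrapper and reuses the shared crate::quantized_nn::RmsNorm, which calls the dedicated candle_nn::ops::rms_norm op directly. A minimal sketch of the two code paths, assuming the candle and candle-nn crates as dependencies; the shapes and eps value are illustrative, not taken from the diff:

use candle::{DType, Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative activation of shape (batch, seq, hidden) and a unit scale.
    let x = Tensor::randn(0f32, 1f32, (1, 4, 64), &dev)?;
    let weight = Tensor::ones((64,), DType::F32, &dev)?;

    // Old path: RMSNorm expressed through the LayerNorm module.
    let old = candle_nn::LayerNorm::rms_norm(weight.clone(), 1e-5);
    let y_old = old.forward(&x)?;

    // New path: the dedicated op that quantized_nn::RmsNorm now calls.
    let y_new = candle_nn::ops::rms_norm(&x, &weight, 1e-5f32)?;

    // The two results should agree up to numerical precision.
    println!("{:?} {:?}", y_old.shape(), y_new.shape());
    Ok(())
}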


@@ -1,5 +1,6 @@
 use std::collections::HashMap;
 
+use crate::quantized_nn::RmsNorm;
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
@@ -7,26 +8,6 @@ use candle_nn::{Embedding, Module};
 
 pub const MAX_SEQ_LEN: usize = 4096;
 
-#[derive(Debug, Clone)]
-struct RmsNorm {
-    inner: candle_nn::LayerNorm,
-    span: tracing::Span,
-}
-
-impl RmsNorm {
-    fn new(scale: QTensor, eps: f32) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = scale.dequantize(&scale.device())?;
-        let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
-        Ok(Self { inner, span })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 // QMatMul wrapper adding some tracing.
 #[derive(Debug, Clone)]
 struct QMatMul {
@@ -301,7 +282,7 @@ impl ModelWeights {
         let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
-        let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
+        let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
         let output = ct.remove("output.weight")?;
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
         for layer_idx in 0..ct.hparams.n_layer {
@@ -330,9 +311,9 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk)?,
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
+                attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
                 mlp_or_moe,
-                ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
+                ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
                 n_head: ct.hparams.n_head as usize,
                 n_kv_head: ct.hparams.n_head as usize / gqa,
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
@@ -381,7 +362,7 @@ impl ModelWeights {
         let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
         let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
         // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
-        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
+        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f64()?;
         let rope_freq_base = md_get("llama.rope.freq_base")
             .and_then(|m| m.to_f32())
@@ -391,7 +372,7 @@ impl ModelWeights {
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let norm = RmsNorm::new(
+        let norm = RmsNorm::from_qtensor(
             ct.tensor(reader, "output_norm.weight", device)?,
             rms_norm_eps,
         )?;
@@ -450,9 +431,9 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk)?,
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
+                attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
                 mlp_or_moe,
-                ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
+                ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
                 n_head: head_count,
                 n_kv_head: head_count_kv,
                 head_dim: embedding_length / head_count,
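
That is the whole of the first file's diff: every norm in the GGML and GGUF loaders is now built with RmsNorm::from_qtensor, and rms_norm_eps is read from the GGUF metadata as an f64 so it can be handed to that constructor unchanged. For reference, the op these norms end up calling computes the usual RMSNorm over the hidden dimension; below is a plain-tensor sketch of that computation (the real op runs it as a single fused operation, so this is for clarity only):

use candle::{Result, Tensor, D};

// Reference RMSNorm over the last (hidden) dimension.
fn rms_norm_reference(x: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    // Mean of squares over the hidden dimension.
    let mean_sq = x.sqr()?.mean_keepdim(D::Minus1)?;
    // Divide by the root mean square, then apply the learned scale.
    let x_normed = x.broadcast_div(&(mean_sq + eps)?.sqrt()?)?;
    x_normed.broadcast_mul(weight)
}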


@@ -327,6 +327,7 @@ impl Model {
             xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
         }
         xs.narrow(1, seq_len - 1, 1)?
+            .contiguous()?
             .apply(&self.norm)?
             .apply(&self.lm_head)
     }
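
The second file only touches the final projection path: the last position's hidden state is made contiguous before the norm is applied. narrow returns a strided view into the full sequence, and the fast rms_norm path appears to expect a dense layout, hence the extra call. A small sketch of that step, with illustrative shapes:

use candle::{Result, Tensor};

// Keep only the last position before running the final norm and lm_head.
fn last_token_hidden(xs: &Tensor, seq_len: usize) -> Result<Tensor> {
    xs.narrow(1, seq_len - 1, 1)? // strided view of shape (batch, 1, hidden)
        .contiguous() // materialize it so the norm kernel sees a dense layout
}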


@@ -1,5 +1,6 @@
 use crate::models::with_tracing::QMatMul;
 use crate::quantized_var_builder::VarBuilder;
+use candle::quantized::QTensor;
 use candle::{Module, Result, Tensor};
 
 #[derive(Debug, Clone)]
@@ -35,10 +36,7 @@ pub struct Linear {
 }
 
 impl Linear {
-    pub fn from_arc(
-        weight: std::sync::Arc<candle::quantized::QTensor>,
-        bias: Option<Tensor>,
-    ) -> Result<Self> {
+    pub fn from_arc(weight: std::sync::Arc<QTensor>, bias: Option<Tensor>) -> Result<Self> {
         let weight = QMatMul::from_weights(weight)?;
         Ok(Self { weight, bias })
     }
@@ -95,7 +93,8 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<L
 #[derive(Debug, Clone)]
 pub struct RmsNorm {
-    inner: candle_nn::RmsNorm,
+    weight: Tensor,
+    eps: f64,
     span: tracing::Span,
 }
@@ -103,14 +102,19 @@ impl RmsNorm {
     pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
-        let inner = candle_nn::RmsNorm::new(weight, eps);
-        Ok(Self { inner, span })
+        Ok(Self { weight, eps, span })
     }
+
+    pub fn from_qtensor(weight: QTensor, eps: f64) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+        let weight = weight.dequantize(&weight.device())?;
+        Ok(Self { weight, eps, span })
+    }
 }
 
 impl Module for RmsNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        self.inner.forward(x)
+        candle_nn::ops::rms_norm(x, &self.weight, self.eps as f32)
     }
 }
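
Finally, the shared quantized_nn::RmsNorm now stores the dequantized weight and eps directly and calls candle_nn::ops::rms_norm in forward, rather than going through a candle_nn::RmsNorm module; eps stays an f64 on the struct and is cast to f32 only at the call site, since that is what the op takes. A hedged usage sketch of the new from_qtensor constructor, assuming the public candle_transformers::quantized_nn module path and an illustrative GGUF file and eps:

use candle::quantized::gguf_file;
use candle::{Device, Result};
use candle_transformers::quantized_nn::RmsNorm;

fn load_output_norm(path: &str, device: &Device) -> Result<RmsNorm> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // In the real loader the eps comes from the
    // "llama.attention.layer_norm_rms_epsilon" metadata; 1e-6 is illustrative.
    let weight = content.tensor(&mut file, "output_norm.weight", device)?;
    RmsNorm::from_qtensor(weight, 1e-6)
}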