Use the faster rms-norm kernel for llama. (#2107)

* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
This commit is contained in:
Laurent Mazare
2024-04-22 18:52:00 +02:00
committed by GitHub
parent 618ecf5e23
commit b2e816752b
2 changed files with 18 additions and 4 deletions

View File

@ -28,7 +28,7 @@
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Result, Tensor, D};
use candle::{DType, Module, Result, Tensor, D};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
@ -105,7 +105,7 @@ impl LayerNorm {
}
}
impl crate::Module for LayerNorm {
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
@ -162,13 +162,22 @@ impl RmsNorm {
pub fn into_inner(self) -> LayerNorm {
self.0
}
/// Faster variant of the forward kernel, this can only be used on contiguous tensors though.
pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
self.0.forward(xs)
}
}
impl crate::Module for RmsNorm {
impl Module for RmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
if xs.is_contiguous() {
crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
} else {
self.0.forward(xs)
}
}
}
pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
let config = LayerNormConfig {

View File

@ -180,6 +180,11 @@ impl RmsNorm {
let inner = candle_nn::rms_norm(size, eps, vb)?;
Ok(Self { inner, span })
}
pub fn forward_diff(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward_diff(x)
}
}
impl Module for RmsNorm {