Use the faster rms-norm kernel for llama. (#2107)

* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
Laurent Mazare
2024-04-22 18:52:00 +02:00
committed by GitHub
parent 618ecf5e23
commit b2e816752b
2 changed files with 18 additions and 4 deletions
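For context, here is a minimal sketch (not part of the commit) comparing the fused kernel exposed as candle_nn::ops::rms_norm with the op-by-op computation that the slow path roughly performs, x / sqrt(mean(x^2) + eps) * weight. The shapes, the eps value, and the main wrapper are illustrative assumptions:

use candle::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative shapes: (batch=2, seq=3, hidden=8) and a llama-style eps.
    let xs = Tensor::randn(0f32, 1f32, (2, 3, 8), &dev)?;
    let weight = Tensor::ones(8, DType::F32, &dev)?;
    let eps = 1e-5f64;

    // Fast path: the fused kernel, usable on contiguous tensors.
    let fast = candle_nn::ops::rms_norm(&xs, &weight, eps as f32)?;

    // Slow reference path: x / sqrt(mean(x^2) + eps) * weight, computed op by op.
    let norm_x = xs.sqr()?.mean_keepdim(D::Minus1)?;
    let slow = xs
        .broadcast_div(&(norm_x + eps)?.sqrt()?)?
        .broadcast_mul(&weight)?;

    // The two paths should agree up to floating-point error.
    let err = (fast - slow)?.abs()?.mean_all()?.to_scalar::<f32>()?;
    println!("mean abs diff: {err}");
    Ok(())
}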

candle-nn/src/layer_norm.rs

@@ -28,7 +28,7 @@
 //! ```
 //!
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct LayerNormConfig {
@@ -105,7 +105,7 @@ impl LayerNorm {
     }
 }
 
-impl crate::Module for LayerNorm {
+impl Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -162,11 +162,20 @@ impl RmsNorm {
     pub fn into_inner(self) -> LayerNorm {
         self.0
     }
+
+    /// Faster variant of the forward kernel, this can only be used on contiguous tensors though.
+    pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
+        self.0.forward(xs)
+    }
 }
 
-impl crate::Module for RmsNorm {
+impl Module for RmsNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.0.forward(xs)
+        if xs.is_contiguous() {
+            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
+        } else {
+            self.0.forward(xs)
+        }
     }
 }
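A hedged usage sketch of the new dispatch in candle_nn::RmsNorm shown in the hunk above: contiguous inputs should hit the fused kernel, non-contiguous views fall back to the op-by-op path, and forward_diff always takes the op-by-op path. The shapes and the transposed view are illustrative assumptions:

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::RmsNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let weight = Tensor::ones(16, DType::F32, &dev)?;
    let norm = RmsNorm::new(weight, 1e-5);

    // Contiguous input: `forward` now dispatches to the fused rms-norm kernel.
    let xs = Tensor::randn(0f32, 1f32, (4, 16), &dev)?;
    let _fast = norm.forward(&xs)?;

    // Non-contiguous input (a transposed view): `forward` falls back to the
    // op-by-op LayerNorm path, so callers do not need to care about layout.
    let xs_t = Tensor::randn(0f32, 1f32, (16, 4), &dev)?.t()?;
    assert!(!xs_t.is_contiguous());
    let _slow = norm.forward(&xs_t)?;

    // `forward_diff` always goes through the op-by-op path.
    let _diff = norm.forward_diff(&xs)?;
    Ok(())
}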

candle-transformers/src/models/llama.rs

@@ -180,6 +180,11 @@ impl RmsNorm {
         let inner = candle_nn::rms_norm(size, eps, vb)?;
         Ok(Self { inner, span })
     }
+
+    pub fn forward_diff(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward_diff(x)
+    }
 }
 
 impl Module for RmsNorm {
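A hedged sketch of building an RmsNorm through a VarBuilder, mirroring the load helper in the hunk above; the VarMap-backed builder, the "model.norm" prefix, and the shapes are illustrative assumptions rather than code from the llama model:

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // A fresh VarMap stands in for real weights; `candle_nn::rms_norm` pulls a
    // "weight" tensor of the requested size out of the builder.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let norm = candle_nn::rms_norm(64, 1e-5, vb.pp("model.norm"))?;

    let xs = Tensor::randn(0f32, 1f32, (1, 7, 64), &dev)?;
    let ys = norm.forward(&xs)?; // fused kernel when `xs` is contiguous
    println!("{:?}", ys.dims());
    Ok(())
}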