Use the faster rms-norm kernel for llama. (#2107)

* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
This commit is contained in:
Laurent Mazare
2024-04-22 18:52:00 +02:00
committed by GitHub
parent 618ecf5e23
commit b2e816752b
2 changed files with 18 additions and 4 deletions

View File

@ -28,7 +28,7 @@
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Result, Tensor, D};
use candle::{DType, Module, Result, Tensor, D};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
@ -105,7 +105,7 @@ impl LayerNorm {
}
}
impl crate::Module for LayerNorm {
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
@ -162,11 +162,20 @@ impl RmsNorm {
pub fn into_inner(self) -> LayerNorm {
self.0
}
/// Faster variant of the forward kernel, this can only be used on contiguous tensors though.
pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
self.0.forward(xs)
}
}
impl crate::Module for RmsNorm {
impl Module for RmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.0.forward(xs)
if xs.is_contiguous() {
crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
} else {
self.0.forward(xs)
}
}
}