Use the faster rms-norm kernel for llama. (#2107)

* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
This commit is contained in:
Laurent Mazare
2024-04-22 18:52:00 +02:00
committed by GitHub
parent 618ecf5e23
commit b2e816752b
2 changed files with 18 additions and 4 deletions

View File

@ -180,6 +180,11 @@ impl RmsNorm {
let inner = candle_nn::rms_norm(size, eps, vb)?;
Ok(Self { inner, span })
}
pub fn forward_diff(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward_diff(x)
}
}
impl Module for RmsNorm {