mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama. * Use the fast variant by default.
This commit is contained in:
@ -180,6 +180,11 @@ impl RmsNorm {
|
||||
let inner = candle_nn::rms_norm(size, eps, vb)?;
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
pub fn forward_diff(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward_diff(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RmsNorm {
|
||||
|
Reference in New Issue
Block a user