Sketch the candle-nn crate. (#115)

* Sketch the candle-nn crate.

* Tweak the cuda dependencies.

* More cuda tweaks.
This commit is contained in:
Laurent Mazare
2023-07-10 08:50:09 +01:00
committed by GitHub
parent bc3be6f9b0
commit 9ce0f1c010
13 changed files with 230 additions and 315 deletions

View File

@ -0,0 +1,34 @@
use candle::{DType, Result, Tensor};
// This layer norm version handles both weight and bias so removes the mean.
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
Self { weight, bias, eps }
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
}