Sketch the candle-nn crate. (#115)

* Sketch the candle-nn crate.

* Tweak the cuda dependencies.

* More cuda tweaks.
This commit is contained in:
Laurent Mazare
2023-07-10 08:50:09 +01:00
committed by GitHub
parent bc3be6f9b0
commit 9ce0f1c010
13 changed files with 230 additions and 315 deletions

25
candle-nn/src/linear.rs Normal file
View File

@ -0,0 +1,25 @@
use candle::Tensor;
#[derive(Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}