mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Sketch the candle-nn crate. (#115)
* Sketch the candle-nn crate. * Tweak the cuda dependencies. * More cuda tweaks.
This commit is contained in:
25
candle-nn/src/linear.rs
Normal file
25
candle-nn/src/linear.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use candle::Tensor;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
Self { weight, bias }
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
};
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user