mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
1 Commits
0.8.3
...
linear-tra
Author | SHA1 | Date | |
---|---|---|---|
a35a935118 |
@ -27,13 +27,14 @@ pub struct Linear {
|
||||
|
||||
impl Linear {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
let weight = weight.t().unwrap().contiguous().unwrap();
|
||||
Self { weight, bias }
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?,
|
||||
_ => self.weight.clone(),
|
||||
};
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
|
Reference in New Issue
Block a user