mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
1 Commits
0.9.1
...
linear-tra
Author | SHA1 | Date | |
---|---|---|---|
a35a935118 |
@ -27,13 +27,14 @@ pub struct Linear {
|
|||||||
|
|
||||||
impl Linear {
|
impl Linear {
|
||||||
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||||
|
let weight = weight.t().unwrap().contiguous().unwrap();
|
||||||
Self { weight, bias }
|
Self { weight, bias }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
let w = match x.dims() {
|
let w = match x.dims() {
|
||||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
&[bsize, _, _] => self.weight.broadcast_left(bsize)?,
|
||||||
_ => self.weight.t()?,
|
_ => self.weight.clone(),
|
||||||
};
|
};
|
||||||
let x = x.matmul(&w)?;
|
let x = x.matmul(&w)?;
|
||||||
match &self.bias {
|
match &self.bias {
|
||||||
|
Reference in New Issue
Block a user