mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Another fix for squeezing. (#1943)
This commit is contained in:
@@ -171,7 +171,7 @@ impl Shape {
|
|||||||
}
|
}
|
||||||
let mut acc = 1;
|
let mut acc = 1;
|
||||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
||||||
if stride != acc {
|
if dim > 1 && stride != acc {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
acc *= dim;
|
acc *= dim;
|
||||||
@@ -186,7 +186,7 @@ impl Shape {
|
|||||||
}
|
}
|
||||||
let mut acc = 1;
|
let mut acc = 1;
|
||||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
||||||
if stride != acc {
|
if dim > 1 && stride != acc {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
acc *= dim;
|
acc *= dim;
|
||||||
|
@@ -52,8 +52,8 @@ impl Module for Attention {
|
|||||||
.transpose(0, 1)? // 20134
|
.transpose(0, 1)? // 20134
|
||||||
.transpose(2, 3)?; // 20314
|
.transpose(2, 3)?; // 20314
|
||||||
let q = (qkv.i(0)? * self.scale)?;
|
let q = (qkv.i(0)? * self.scale)?;
|
||||||
let k = qkv.i(1)?;
|
let k = qkv.i(1)?.contiguous()?;
|
||||||
let v = qkv.i(2)?;
|
let v = qkv.i(2)?.contiguous()?;
|
||||||
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
|
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
|
||||||
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
|
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
|
||||||
self.proj.forward(&attn)
|
self.proj.forward(&attn)
|
||||||
|
Reference in New Issue
Block a user