Another fix for squeezing. (#1943)

Laurent Mazare
2024-03-26 17:05:26 +01:00
committed by GitHub
parent 4523ecfb2a
commit 66f0a4eeea
2 changed files with 4 additions and 4 deletions


@@ -171,7 +171,7 @@ impl Shape {
         }
         let mut acc = 1;
         for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
-            if stride != acc {
+            if dim > 1 && stride != acc {
                 return false;
             }
             acc *= dim;
@@ -186,7 +186,7 @@ impl Shape {
         }
         let mut acc = 1;
         for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
-            if stride != acc {
+            if dim > 1 && stride != acc {
                 return false;
             }
             acc *= dim;
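
The change above makes both contiguity checks skip dimensions of size 1: a size-1 dimension spans no extra memory, so its stride can be arbitrary (as it often is after a squeeze) without breaking the layout. A minimal standalone sketch of the row-major variant, using a free-function name chosen here for illustration rather than the methods on Shape shown in the diff:

// Illustrative sketch: a dimension of size 1 contributes nothing to the
// layout, so its stride is ignored by the check.
fn is_c_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    let mut acc = 1;
    for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
        if dim > 1 && stride != acc {
            return false;
        }
        acc *= dim;
    }
    true
}

fn main() {
    // A (2, 1, 3) view whose middle stride is stale (e.g. left over from a squeeze):
    // the old check rejected it, the new one accepts it since that dim is 1.
    assert!(is_c_contiguous(&[2, 1, 3], &[3, 7, 1]));
    assert!(!is_c_contiguous(&[2, 4, 3], &[3, 7, 1]));
}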


@@ -52,8 +52,8 @@ impl Module for Attention {
             .transpose(0, 1)? // 20134
             .transpose(2, 3)?; // 20314
         let q = (qkv.i(0)? * self.scale)?;
-        let k = qkv.i(1)?;
-        let v = qkv.i(2)?;
+        let k = qkv.i(1)?.contiguous()?;
+        let v = qkv.i(2)?.contiguous()?;
         let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
         let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
         self.proj.forward(&attn)
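
After the preceding transposes, `qkv.i(1)?` and `qkv.i(2)?` are strided views into the packed qkv tensor; adding `.contiguous()?` copies them into a dense layout before the matmuls (q already gets a fresh tensor from the scale multiplication). A minimal usage sketch under that reading, with illustrative shapes rather than the model's real ones:

// Hedged sketch: the dimensions below are made up; only the slice /
// contiguous / matmul pattern mirrors the attention code in the diff.
use candle_core::{Device, IndexOp, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for the packed tensor, reordered so the slices below are strided views.
    let qkv = Tensor::randn(0f32, 1f32, (1, 3, 2, 4, 8), &dev)?.transpose(0, 1)?;
    let q = (qkv.i(0)? * 0.125)?; // the multiplication yields a fresh, dense tensor
    let k = qkv.i(1)?.contiguous()?; // force dense copies of the strided views
    let v = qkv.i(2)?.contiguous()?;
    let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
    let _out = attn.matmul(&v)?;
    Ok(())
}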