From 66f0a4eeea02f069838903a18dd6402821e43271 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 26 Mar 2024 17:05:26 +0100
Subject: [PATCH] Another fix for squeezing. (#1943)

---
 candle-core/src/shape.rs                 | 4 ++--
 candle-transformers/src/models/dinov2.rs | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 32ebb23f..567a711b 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -171,7 +171,7 @@ impl Shape {
         }
         let mut acc = 1;
         for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
-            if stride != acc {
+            if dim > 1 && stride != acc {
                 return false;
             }
             acc *= dim;
@@ -186,7 +186,7 @@ impl Shape {
         }
         let mut acc = 1;
         for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
-            if stride != acc {
+            if dim > 1 && stride != acc {
                 return false;
             }
             acc *= dim;
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
index 0edc8494..757aa88a 100644
--- a/candle-transformers/src/models/dinov2.rs
+++ b/candle-transformers/src/models/dinov2.rs
@@ -52,8 +52,8 @@ impl Module for Attention {
             .transpose(0, 1)? // 20134
             .transpose(2, 3)?; // 20314
         let q = (qkv.i(0)? * self.scale)?;
-        let k = qkv.i(1)?;
-        let v = qkv.i(2)?;
+        let k = qkv.i(1)?.contiguous()?;
+        let v = qkv.i(2)?.contiguous()?;
         let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
         let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
         self.proj.forward(&attn)
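
For context: the relaxed check works because a dimension of extent 1 is never actually strided over, so after a squeeze/unsqueeze round-trip its stride can be arbitrary without the buffer ceasing to be row-major. Below is a minimal standalone sketch of the same check; is_c_contiguous is a hypothetical free function mirroring the patched Shape::is_contiguous, not candle's actual API.

// Hypothetical standalone mirror of the patched Shape::is_contiguous logic.
// Strides on size-1 dimensions are ignored: such a dimension contributes a
// single element, so its stride never influences memory traversal.
fn is_c_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    if dims.len() != strides.len() {
        return false;
    }
    let mut acc = 1;
    for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
        if dim > 1 && stride != acc {
            return false;
        }
        acc *= dim;
    }
    true
}

fn main() {
    // A (2, 1, 3) view can carry an arbitrary stride on its size-1 axis;
    // the pre-patch check rejected this layout, the patched check accepts it.
    assert!(is_c_contiguous(&[2, 1, 3], &[3, 99, 1]));
    // A genuinely transposed (non-contiguous) layout is still rejected.
    assert!(!is_c_contiguous(&[2, 3], &[1, 2]));
}

On the dinov2 side, k and v are views obtained by indexing a transposed tensor, so they are generally not contiguous; the explicit .contiguous()? calls materialize them into packed buffers before the matmul operations.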