From fc83d97b41ef1d3b35e42eb4cce1a4d165279d94 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 22 Jun 2023 21:39:37 +0100 Subject: [PATCH] Only support the contiguous case for cublas matmul. --- src/tensor.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/tensor.rs b/src/tensor.rs index f1e60efc..78359144 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -275,6 +275,15 @@ impl Tensor { op: "matmul", }); } + if let crate::DeviceLocation::Cuda { .. } = self.device().location() { + if !self.is_contiguous() || !rhs.is_contiguous() { + // It looks like the cublas implementation of XgemmStridedBatched only supports + // non-standard strides on the batch dimension. + return Err(Error::RequiresContiguous { + op: "matmul-cublas", + }); + } + } let m = a_dims[dim - 2]; let k = a_dims[dim - 1];