diff --git a/src/tensor.rs b/src/tensor.rs index f1e60efc..78359144 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -275,6 +275,15 @@ impl Tensor { op: "matmul", }); } + if let crate::DeviceLocation::Cuda { .. } = self.device().location() { + if !self.is_contiguous() || !rhs.is_contiguous() { + // It looks like the cublas implementation of XgemmStridedBatched only supports + // non-standard strides on the batch dimension. + return Err(Error::RequiresContiguous { + op: "matmul-cublas", + }); + } + } let m = a_dims[dim - 2]; let k = a_dims[dim - 1];