Only support the contiguous case for cublas matmul.

This commit is contained in:
laurent
2023-06-22 21:39:37 +01:00
parent 2231c717d5
commit fc83d97b41

View File

@ -275,6 +275,15 @@ impl Tensor {
op: "matmul",
});
}
if let crate::DeviceLocation::Cuda { .. } = self.device().location() {
if !self.is_contiguous() || !rhs.is_contiguous() {
// It looks like the cublas implementation of XgemmStridedBatched only supports
// non-standard strides on the batch dimension.
return Err(Error::RequiresContiguous {
op: "matmul-cublas",
});
}
}
let m = a_dims[dim - 2];
let k = a_dims[dim - 1];