Start adding the cublas based matmul.

This commit is contained in:
laurent
2023-06-22 18:45:10 +01:00
parent 683730c21d
commit cc78900922
3 changed files with 89 additions and 4 deletions

View File

@ -132,12 +132,13 @@ impl Storage {
self.same_device(rhs, "matmul")?;
self.same_dtype(rhs, "matmul")?;
match (self, rhs) {
(Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(_), Self::Cuda(_)) => {
todo!()
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),