Adding matmul?

This commit is contained in:
Nicolas Patry
2023-06-21 16:52:35 +02:00
parent 87a37b3bf3
commit ce977b489e
7 changed files with 243 additions and 3 deletions

View File

@ -241,4 +241,22 @@ impl Storage {
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqrt>(shape, stride)
}
pub(crate) fn matmul_impl(
&self,
rhs: &Self,
bmnk: (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.same_device(rhs, "matmul")?;
self.same_dtype(rhs, "matmul")?;
match (self, rhs) {
(Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
Ok(Self::Cpu(storage))
}
_ => todo!(),
}
}
}