Get the cpu backend to compile.

This commit is contained in:
laurent
2023-06-28 14:12:38 +01:00
parent 54a6c40f27
commit 14449ff80c
5 changed files with 44 additions and 59 deletions

View File

@ -79,6 +79,7 @@ impl Storage {
}
}
// This assumes a contiguous layout and no offset.
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
@ -196,22 +197,22 @@ impl Storage {
}
}
pub(crate) fn matmul_impl(
pub(crate) fn matmul(
&self,
rhs: &Self,
bmnk: (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
lhs_layout: &Layout,
rhs_layout: &Layout,
) -> Result<Self> {
self.same_device(rhs, "matmul")?;
self.same_dtype(rhs, "matmul")?;
match (self, rhs) {
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {