Add some more developed training examples. (#199)

* Use contiguous tensors for variables.

* Sketch the mnist example.

* Start adding the reduce ops.

* Renaming.

* Refactor the reduce operations.

* Bugfix for the broadcasting vectorization.
This commit is contained in:
Laurent Mazare
2023-07-19 16:37:52 +02:00
committed by GitHub
parent 67e20c3792
commit cb687b4897
10 changed files with 232 additions and 65 deletions

View File

@ -80,14 +80,19 @@ impl Storage {
}
}
pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
pub(crate) fn reduce_op(
&self,
op: crate::op::ReduceOp,
layout: &Layout,
s: &[usize],
) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.sum(layout, s)?;
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.sum(layout, s)?;
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage))
}
}