mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some more developed training examples. (#199)
* Use contiguous tensors for variables. * Sketch the mnist example. * Start adding the reduce ops. * Renaming. * Refactor the reduce operations. * Bugfix for the broadcasting vectorization.
This commit is contained in:
@ -80,14 +80,19 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
|
||||
pub(crate) fn reduce_op(
|
||||
&self,
|
||||
op: crate::op::ReduceOp,
|
||||
layout: &Layout,
|
||||
s: &[usize],
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.sum(layout, s)?;
|
||||
let storage = storage.reduce_op(op, layout, s)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.sum(layout, s)?;
|
||||
let storage = storage.reduce_op(op, layout, s)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user