Add some very basic backprop.

This commit is contained in:
laurent
2023-06-20 20:33:44 +01:00
parent 3b7984ccce
commit c4c303b6f1
4 changed files with 112 additions and 5 deletions

View File

@ -56,6 +56,25 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
}
impl Device {
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Storage {
match self {
Device::Cpu => {
let elem_count = shape.elem_count();
let storage = match dtype {
DType::F32 => {
let data = vec![1f32; elem_count];
CpuStorage::F32(data)
}
DType::F64 => {
let data = vec![1f64; elem_count];
CpuStorage::F64(data)
}
};
Storage::Cpu(storage)
}
}
}
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
match self {
Device::Cpu => {