Add some very basic backprop.

laurent
2023-06-20 20:33:44 +01:00
parent 3b7984ccce
commit c4c303b6f1
4 changed files with 112 additions and 5 deletions

View File

@@ -56,6 +56,25 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
 }
 
 impl Device {
+    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Storage {
+        match self {
+            Device::Cpu => {
+                let elem_count = shape.elem_count();
+                let storage = match dtype {
+                    DType::F32 => {
+                        let data = vec![1f32; elem_count];
+                        CpuStorage::F32(data)
+                    }
+                    DType::F64 => {
+                        let data = vec![1f64; elem_count];
+                        CpuStorage::F64(data)
+                    }
+                };
+                Storage::Cpu(storage)
+            }
+        }
+    }
+
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
         match self {
             Device::Cpu => {
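
Note: `ones` mirrors the existing `zeros` just below, dispatching on `DType` to build a buffer of ones in the matching precision. It is the primitive behind `Tensor::ones` and `ones_like` further down, which `backward()` uses to seed the output gradient with d(out)/d(out) = 1. A quick sketch of what the CPU path produces, assuming `elem_count()` is the product of the dimensions:

    // For a 2x3 shape, elem_count() == 6, so the storage ends up as
    // Storage::Cpu(CpuStorage::F32(vec![1f32; 6])).
    let t = Tensor::ones(&[2usize, 3][..], DType::F32, Device::Cpu);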

View File

@@ -1,6 +1,5 @@
 use crate::Tensor;
 
-#[allow(dead_code)]
 pub(crate) enum Op {
     Add(Tensor, Tensor),
     Mul(Tensor, Tensor),
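
Dropping `#[allow(dead_code)]` reflects that `Op` is no longer dead: each tensor produced by a tracked operation stores the `Op` that created it, operands included, and those edges form the autograd graph that `backward()` walks. A hedged sketch (assuming `add` is the `binary_op!`-generated method that records `Op::Add` on its result):

    fn graph_sketch() -> Result<()> {
        let a = Tensor::ones(&[2usize, 2][..], DType::F32, Device::Cpu);
        let b = Tensor::zeros(&[2usize, 2][..], DType::F32, Device::Cpu);
        // _c is assumed to carry Some(Op::Add(a, b)); since Tensor wraps an
        // Arc, storing the operands is a reference-count bump, not a deep copy.
        let _c = a.add(&b)?;
        Ok(())
    }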

View File

@@ -33,6 +33,12 @@ impl From<&[usize]> for Shape {
     }
 }
 
+impl From<&Shape> for Shape {
+    fn from(shape: &Shape) -> Self {
+        Self(shape.0.to_vec())
+    }
+}
+
 impl From<()> for Shape {
     fn from(_: ()) -> Self {
         Self(vec![])
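
The new `From<&Shape>` impl is what lets `ones_like` and `zeros_like` (added in the tensor file below) pass `self.shape()`, a `&Shape`, straight into the `S: Into<Shape>` constructors; the conversion simply clones the underlying `Vec` of dimensions:

    let s = Shape::from(&[2usize, 3][..]); // existing From<&[usize]> impl
    let t: Shape = (&s).into();            // new impl: clones the inner Vec<usize>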

View File

@@ -1,4 +1,5 @@
 use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
+use std::collections::HashMap;
 use std::sync::Arc;
 
 /// Unique identifier for tensors.
@@ -14,7 +15,6 @@ impl TensorId {
     }
 }
 
-#[allow(dead_code)]
 pub struct Tensor_ {
     id: TensorId,
     storage: Storage,
@@ -81,6 +81,38 @@ macro_rules! binary_op {
 }
 
 impl Tensor {
+    fn ones_impl<S: Into<Shape>>(
+        shape: S,
+        dtype: DType,
+        device: Device,
+        is_variable: bool,
+    ) -> Self {
+        let shape = shape.into();
+        let storage = device.ones(&shape, dtype);
+        let stride = shape.stride_contiguous();
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape,
+            stride,
+            op: None,
+            is_variable,
+        };
+        Self(Arc::new(tensor_))
+    }
+
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+        Self::ones_impl(shape, dtype, device, false)
+    }
+
+    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+        Self::ones_impl(shape, dtype, device, true)
+    }
+
+    pub fn ones_like(&self) -> Self {
+        Tensor::ones(self.shape(), self.dtype(), self.device())
+    }
+
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
@@ -109,6 +141,10 @@ impl Tensor {
         Self::zeros_impl(shape, dtype, device, true)
     }
 
+    pub fn zeros_like(&self) -> Self {
+        Tensor::zeros(self.shape(), self.dtype(), self.device())
+    }
+
     pub fn new_impl<A: crate::device::NdArray>(
         array: A,
         device: Device,
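
The `_var` constructors differ from the plain ones only in the `is_variable` flag, but that flag decides which gradients `backward()` (below) returns: the loop skips variables, so their accumulated gradients stay in the result map, while every non-variable node has its gradient removed and redistributed to its operands. A hedged sketch:

    // Leaf you want gradients for: create it with a _var constructor.
    let x = Tensor::ones_var(&[2usize][..], DType::F32, Device::Cpu);
    // Constant-like tensor: its gradient is dropped during the backward walk.
    let c = Tensor::ones(&[2usize][..], DType::F32, Device::Cpu);
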
@@ -246,9 +282,7 @@ impl Tensor {
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
     /// argument.
     /// This assumes that the op graph is a DAG.
-    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
-        use std::collections::HashMap;
-
+    fn sorted_nodes(&self) -> Vec<&Tensor> {
        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
        // to get around some lifetime limitations.
        fn walk<'a>(
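
`sorted_nodes` also goes from `pub` to private; with `backward` as the public entry point there is no reason to expose the traversal. Its body (mostly elided by this hunk) is a depth-first post-order walk followed by a reverse. A minimal free-standing sketch of the same idea over a hypothetical index-based graph, independent of the crate's types:

    // deps[n] lists the operands of node n.
    fn walk(node: usize, deps: &[Vec<usize>], seen: &mut [bool], out: &mut Vec<usize>) {
        if seen[node] {
            return;
        }
        seen[node] = true;
        for &d in &deps[node] {
            walk(d, deps, seen, out); // operands first (post-order)
        }
        out.push(node);
    }
    // Reversing `out`, as sorted_nodes does with nodes.reverse(), yields every
    // node before its operands, which is the order the backward pass consumes.
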
@@ -292,4 +326,53 @@ impl Tensor {
         nodes.reverse();
         nodes
     }
+
+    pub fn backward(&self) -> Result<HashMap<TensorId, Tensor>> {
+        let sorted_nodes = self.sorted_nodes();
+        let mut grads = HashMap::new();
+        grads.insert(self.id, self.ones_like());
+        for node in sorted_nodes.iter() {
+            if node.is_variable {
+                continue;
+            }
+            let grad = grads.remove(&node.id).unwrap();
+            // TODO: We should perform all these operations in place (or at least not track the
+            // whole graph).
+            // The only drawback would be if we wanted to support grad of grad but this is out of
+            // scope.
+            if let Some(op) = &node.op {
+                match op {
+                    Op::Add(lhs, rhs) => {
+                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
+                    }
+                    Op::Mul(lhs, rhs) => {
+                        let lhs_grad = grad.mul(rhs)?;
+                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        let rhs_grad = grad.mul(lhs)?;
+                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                    }
+                    Op::Sqr(_arg) => {
+                        todo!()
+                        // TODO: Add scaling by a constant to enable the following.
+                        // let arg_grad = 2 * arg * grad;
+                        // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        // *sum_grad = sum_grad.add(arg_grad)?
+                    }
+                    Op::Sqrt(_arg) => {
+                        todo!()
+                        // TODO: Add scaling by a constant and divide to enable the following.
+                        // let arg_grad = grad / (2 * arg)
+                        // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        // *sum_grad = sum_grad.add(arg_grad)?
+                    }
+                };
+            }
+        }
+        Ok(grads)
+    }
 }
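
Tying it together, a hedged end-to-end sketch of the new API (the `id` field is accessed directly, as `backward` itself does; this diff shows no public accessor, so outside the crate one would be needed):

    fn backprop_demo() -> Result<()> {
        let x = Tensor::ones_var(&[2usize][..], DType::F32, Device::Cpu);
        // y = x * x + x, so dy/dx = 2x + 1 = 3.0 everywhere, since x == 1.
        let y = x.mul(&x)?.add(&x)?;
        let grads = y.backward()?;
        // x.mul(&x) records Op::Mul(x, x): the Mul arm accumulates grad * x
        // into x's slot twice (the 2x term) and the Add arm adds the final 1.
        let _dx = grads.get(&x.id).unwrap(); // every element is 3.0
        Ok(())
    }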