diff --git a/src/backprop.rs b/src/backprop.rs
new file mode 100644
index 00000000..ca463cdb
--- /dev/null
+++ b/src/backprop.rs
@@ -0,0 +1,290 @@
+use crate::{op::Op, Error, Result, Tensor, TensorId};
+use std::collections::HashMap;
+
+impl Tensor {
+    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
+    /// elements having dependencies on the latter ones, e.g. the first element if any is the
+    /// argument.
+    /// This assumes that the op graph is a DAG.
+    fn sorted_nodes(&self) -> Vec<&Tensor> {
+        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
+        // to get around some lifetime limitations.
+        fn walk<'a>(
+            node: &'a Tensor,
+            nodes: Vec<&'a Tensor>,
+            already_seen: &mut HashMap<TensorId, bool>,
+        ) -> (bool, Vec<&'a Tensor>) {
+            if let Some(&tg) = already_seen.get(&node.id()) {
+                return (tg, nodes);
+            }
+            let mut track_grad = false;
+            let mut nodes = if node.is_variable() {
+                // Do not call recursively on the "leaf" nodes.
+                track_grad = true;
+                nodes
+            } else if let Some(op) = node.op() {
+                match op {
+                    Op::Add(lhs, rhs)
+                    | Op::Mul(lhs, rhs)
+                    | Op::Sub(lhs, rhs)
+                    | Op::Div(lhs, rhs)
+                    | Op::BroadcastAdd(lhs, rhs)
+                    | Op::BroadcastMul(lhs, rhs)
+                    | Op::BroadcastSub(lhs, rhs)
+                    | Op::BroadcastDiv(lhs, rhs)
+                    | Op::Embedding(lhs, rhs)
+                    | Op::Matmul(lhs, rhs) => {
+                        let (tg, nodes) = walk(lhs, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(rhs, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
+                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
+                        let (tg, nodes) = walk(arg, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }),
+                    Op::Affine { arg, mul, .. } => {
+                        if *mul == 0. {
+                            nodes
+                        } else {
+                            let (tg, nodes) = walk(arg, nodes, already_seen);
+                            track_grad |= tg;
+                            nodes
+                        }
+                    }
+                    Op::Reshape(node)
+                    | Op::Broadcast(node)
+                    | Op::ToDType(node)
+                    | Op::ToDevice(node)
+                    | Op::Transpose(node, _, _)
+                    | Op::Narrow(node, _, _, _)
+                    | Op::Softmax(node, _)
+                    | Op::Sqr(node)
+                    | Op::Sqrt(node)
+                    | Op::Gelu(node)
+                    | Op::Exp(node)
+                    | Op::Log(node)
+                    | Op::Sin(node)
+                    | Op::Cos(node)
+                    | Op::Abs(node)
+                    | Op::Neg(node) => {
+                        let (tg, nodes) = walk(node, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
+                }
+            } else {
+                nodes
+            };
+            already_seen.insert(node.id(), track_grad);
+            if track_grad {
+                nodes.push(node);
+            }
+            (track_grad, nodes)
+        }
+        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
+        nodes.reverse();
+        nodes
+    }
+
+    pub fn backward(&self) -> Result<GradStore> {
+        let sorted_nodes = self.sorted_nodes();
+        let mut grads = GradStore::new();
+        grads.insert(self, self.ones_like()?);
+        for node in sorted_nodes.iter() {
+            if node.is_variable() {
+                continue;
+            }
+            let grad = grads.remove(node).unwrap();
+            // TODO: We should perform all these operations in place (or at least not track the
+            // whole graph).
+            // The only drawback would be if we wanted to support grad of grad but this is out of
+            // scope.
+            if let Some(op) = node.op() {
+                match op {
+                    Op::Add(lhs, rhs) => {
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
+                    }
+                    Op::Sub(lhs, rhs) => {
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
+                    }
+                    Op::Mul(lhs, rhs) => {
+                        let lhs_grad = grad.mul(rhs)?;
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        let rhs_grad = grad.mul(lhs)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                    }
+                    Op::Div(lhs, rhs) => {
+                        let lhs_grad = grad.div(rhs)?;
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                    }
+                    Op::BroadcastAdd(lhs, rhs) => {
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
+                    }
+                    Op::BroadcastSub(lhs, rhs) => {
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
+                    }
+                    Op::BroadcastMul(lhs, rhs) => {
+                        let lhs_grad = grad.broadcast_mul(rhs)?;
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
+                        let rhs_grad = grad.broadcast_mul(lhs)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
+                    }
+                    Op::BroadcastDiv(lhs, rhs) => {
+                        let lhs_grad = grad.broadcast_div(rhs)?;
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
+                        let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
+                    }
+                    Op::Embedding(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported { op: "embedding" })
+                    }
+                    Op::Matmul(lhs, rhs) => {
+                        // Skipping checks, the op went ok, we can skip
+                        // the matmul size checks for now.
+
+                        let lhs_grad = grad.matmul(&rhs.t()?)?;
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+                        let rhs_grad = lhs.t()?.matmul(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                    }
+                    Op::Cat(args, dim) => {
+                        let mut start_idx = 0;
+                        for arg in args {
+                            let len = arg.dims()[*dim];
+                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
+                            let sum_grad = grads.or_insert(arg)?;
+                            *sum_grad = sum_grad.add(&arg_grad)?;
+                            start_idx += len;
+                        }
+                    }
+                    Op::Broadcast(_arg) => {
+                        return Err(Error::BackwardNotSupported { op: "broadcast" })
+                    }
+                    Op::ToDType(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
+                    }
+                    Op::Affine { arg, mul, .. } => {
+                        let arg_grad = grad.affine(*mul, 0.)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
+                    Op::Log(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
+                    }
+                    Op::Sin(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
+                    }
+                    Op::Cos(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
+                    }
+                    Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
+                    Op::Exp(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&(&grad / arg)?)?
+                    }
+                    Op::Neg(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.sub(&grad)?
+                    }
+                    Op::Narrow(_arg, _, _, _) => {
+                        return Err(Error::BackwardNotSupported { op: "narrow" })
+                    }
+                    Op::Softmax(_arg, _) => {
+                        return Err(Error::BackwardNotSupported { op: "softmax" })
+                    }
+                    Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
+                    Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
+                    Op::Sqr(arg) => {
+                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
+                    Op::Sqrt(arg) => {
+                        let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
+                    Op::ToDevice(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let arg_grad = grad.to_device(&sum_grad.device())?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
+                    Op::Transpose(arg, dim1, dim2) => {
+                        let arg_grad = grad.transpose(*dim1, *dim2)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
+                };
+            }
+        }
+        Ok(grads)
+    }
+}
+
+pub struct GradStore(HashMap<TensorId, Tensor>);
+
+impl GradStore {
+    fn new() -> Self {
+        GradStore(HashMap::new())
+    }
+
+    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
+        self.0.get(&id)
+    }
+
+    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
+        self.0.get(&tensor.id())
+    }
+
+    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
+        self.0.remove(&tensor.id())
+    }
+
+    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
+        self.0.insert(tensor.id(), grad)
+    }
+
+    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
+        use std::collections::hash_map::Entry;
+        let grad = match self.0.entry(tensor.id()) {
+            Entry::Occupied(entry) => entry.into_mut(),
+            Entry::Vacant(entry) => {
+                let grad = tensor.zeros_like()?;
+                entry.insert(grad)
+            }
+        };
+        Ok(grad)
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 5032de6b..d34c5983 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+mod backprop;
 mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;
diff --git a/src/tensor.rs b/src/tensor.rs
index 17a9d0ae..5d883329 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -1,5 +1,4 @@
 use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
-use std::collections::HashMap;
 use std::sync::Arc;
 
 /// Unique identifier for tensors.
@@ -541,6 +540,14 @@ impl Tensor {
         self.id
     }
 
+    pub fn is_variable(&self) -> bool {
+        self.is_variable
+    }
+
+    pub(crate) fn op(&self) -> &Option<Op> {
+        &self.op
+    }
+
     /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
     /// input are swapped.
     pub fn t(&self) -> Result<Tensor> {
@@ -819,256 +826,6 @@ impl Tensor {
         }
         Ok(from_storage(storage, shape, op, false))
     }
-
-    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
-    /// elements having dependencies on the latter ones, e.g. the first element if any is the
-    /// argument.
-    /// This assumes that the op graph is a DAG.
-    fn sorted_nodes(&self) -> Vec<&Tensor> {
-        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
-        // to get around some lifetime limitations.
-        fn walk<'a>(
-            node: &'a Tensor,
-            nodes: Vec<&'a Tensor>,
-            already_seen: &mut HashMap<TensorId, bool>,
-        ) -> (bool, Vec<&'a Tensor>) {
-            if let Some(&tg) = already_seen.get(&node.id) {
-                return (tg, nodes);
-            }
-            let mut track_grad = false;
-            let mut nodes = if node.is_variable {
-                // Do not call recursively on the "leaf" nodes.
-                track_grad = true;
-                nodes
-            } else if let Some(op) = &node.op {
-                match op {
-                    Op::Add(lhs, rhs)
-                    | Op::Mul(lhs, rhs)
-                    | Op::Sub(lhs, rhs)
-                    | Op::Div(lhs, rhs)
-                    | Op::BroadcastAdd(lhs, rhs)
-                    | Op::BroadcastMul(lhs, rhs)
-                    | Op::BroadcastSub(lhs, rhs)
-                    | Op::BroadcastDiv(lhs, rhs)
-                    | Op::Embedding(lhs, rhs)
-                    | Op::Matmul(lhs, rhs) => {
-                        let (tg, nodes) = walk(lhs, nodes, already_seen);
-                        track_grad |= tg;
-                        let (tg, nodes) = walk(rhs, nodes, already_seen);
-                        track_grad |= tg;
-                        nodes
-                    }
-                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
-                        let (tg, nodes) = walk(arg, nodes, already_seen);
-                        track_grad |= tg;
-                        nodes
-                    }),
-                    Op::Affine { arg, mul, .. } => {
-                        if *mul == 0. {
-                            nodes
-                        } else {
-                            let (tg, nodes) = walk(arg, nodes, already_seen);
-                            track_grad |= tg;
-                            nodes
-                        }
-                    }
-                    Op::Reshape(node)
-                    | Op::Broadcast(node)
-                    | Op::ToDType(node)
-                    | Op::ToDevice(node)
-                    | Op::Transpose(node, _, _)
-                    | Op::Narrow(node, _, _, _)
-                    | Op::Softmax(node, _)
-                    | Op::Sqr(node)
-                    | Op::Sqrt(node)
-                    | Op::Gelu(node)
-                    | Op::Exp(node)
-                    | Op::Log(node)
-                    | Op::Sin(node)
-                    | Op::Cos(node)
-                    | Op::Abs(node)
-                    | Op::Neg(node) => {
-                        let (tg, nodes) = walk(node, nodes, already_seen);
-                        track_grad |= tg;
-                        nodes
-                    }
-                }
-            } else {
-                nodes
-            };
-            already_seen.insert(node.id, track_grad);
-            if track_grad {
-                nodes.push(node);
-            }
-            (track_grad, nodes)
-        }
-        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
-        nodes.reverse();
-        nodes
-    }
-
-    pub fn backward(&self) -> Result<GradStore> {
-        let sorted_nodes = self.sorted_nodes();
-        let mut grads = GradStore::new();
-        grads.insert(self, self.ones_like()?);
-        for node in sorted_nodes.iter() {
-            if node.is_variable {
-                continue;
-            }
-            let grad = grads.remove(node).unwrap();
-            // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph).
-            // The only drawback would be if we wanted to support grad of grad but this is out of
-            // scope.
-            if let Some(op) = &node.op {
-                match op {
-                    Op::Add(lhs, rhs) => {
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
-                    }
-                    Op::Sub(lhs, rhs) => {
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
-                    }
-                    Op::Mul(lhs, rhs) => {
-                        let lhs_grad = grad.mul(rhs)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
-                        let rhs_grad = grad.mul(lhs)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
-                    }
-                    Op::Div(lhs, rhs) => {
-                        let lhs_grad = grad.div(rhs)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
-                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
-                    }
-                    Op::BroadcastAdd(lhs, rhs) => {
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
-                    }
-                    Op::BroadcastSub(lhs, rhs) => {
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
-                    }
-                    Op::BroadcastMul(lhs, rhs) => {
-                        let lhs_grad = grad.broadcast_mul(rhs)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                        let rhs_grad = grad.broadcast_mul(lhs)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                    }
-                    Op::BroadcastDiv(lhs, rhs) => {
-                        let lhs_grad = grad.broadcast_div(rhs)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                        let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                    }
-                    Op::Embedding(_lhs, _rhs) => {
-                        return Err(Error::BackwardNotSupported { op: "embedding" })
-                    }
-                    Op::Matmul(lhs, rhs) => {
-                        // Skipping checks, the op went ok, we can skip
-                        // the matmul size checks for now.
-
-                        let lhs_grad = grad.matmul(&rhs.t()?)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
-
-                        let rhs_grad = lhs.t()?.matmul(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
-                    }
-                    Op::Cat(args, dim) => {
-                        let mut start_idx = 0;
-                        for arg in args {
-                            let len = arg.dims()[*dim];
-                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
-                            let sum_grad = grads.or_insert(arg)?;
-                            *sum_grad = sum_grad.add(&arg_grad)?;
-                            start_idx += len;
-                        }
-                    }
-                    Op::Broadcast(_arg) => {
-                        return Err(Error::BackwardNotSupported { op: "broadcast" })
-                    }
-                    Op::ToDType(arg) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
-                    }
-                    Op::Affine { arg, mul, .. } => {
-                        let arg_grad = grad.affine(*mul, 0.)?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&arg_grad)?
-                    }
-                    Op::Log(arg) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
-                    }
-                    Op::Sin(arg) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
-                    }
-                    Op::Cos(arg) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
-                    }
-                    Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
-                    Op::Exp(arg) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&(&grad / arg)?)?
-                    }
-                    Op::Neg(arg) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.sub(&grad)?
-                    }
-                    Op::Narrow(_arg, _, _, _) => {
-                        return Err(Error::BackwardNotSupported { op: "narrow" })
-                    }
-                    Op::Softmax(_arg, _) => {
-                        return Err(Error::BackwardNotSupported { op: "softmax" })
-                    }
-                    Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
-                    Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
-                    Op::Sqr(arg) => {
-                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&arg_grad)?
-                    }
-                    Op::Sqrt(arg) => {
-                        let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&arg_grad)?
-                    }
-                    Op::ToDevice(arg) => {
-                        let sum_grad = grads.or_insert(arg)?;
-                        let arg_grad = grad.to_device(&sum_grad.device())?;
-                        *sum_grad = sum_grad.add(&arg_grad)?
-                    }
-                    Op::Transpose(arg, dim1, dim2) => {
-                        let arg_grad = grad.transpose(*dim1, *dim2)?;
-                        let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&arg_grad)?
-                    }
-                };
-            }
-        }
-        Ok(grads)
-    }
 }
 
 macro_rules! bin_trait {
@@ -1127,39 +884,3 @@ bin_trait!(Add, add, |_| 1., |v| v);
 bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
 bin_trait!(Mul, mul, |v| v, |_| 0.);
 bin_trait!(Div, div, |v| 1. / v, |_| 0.);
-
-pub struct GradStore(HashMap<TensorId, Tensor>);
-
-impl GradStore {
-    fn new() -> Self {
-        GradStore(HashMap::new())
-    }
-
-    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
-        self.0.get(&id)
-    }
-
-    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
-        self.0.get(&tensor.id)
-    }
-
-    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
-        self.0.remove(&tensor.id)
-    }
-
-    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
-        self.0.insert(tensor.id, grad)
-    }
-
-    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
-        use std::collections::hash_map::Entry;
-        let grad = match self.0.entry(tensor.id) {
-            Entry::Occupied(entry) => entry.into_mut(),
-            Entry::Vacant(entry) => {
-                let grad = tensor.zeros_like()?;
-                entry.insert(grad)
-            }
-        };
-        Ok(grad)
-    }
-}
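Reviewer note: a minimal, hypothetical usage sketch of the API this patch moves into `src/backprop.rs` (`Tensor::backward` returning a `GradStore`). The `candle` import path, the `Tensor::var` constructor, `Device::Cpu`, and Debug-printing a tensor are assumptions taken from the rest of the crate, not from this diff; only `mul`, `add`, `backward`, and `GradStore::get` are the calls shown above.

```rust
use candle::{Device, Result, Tensor};

fn grad_demo() -> Result<()> {
    // Only variables are tracked: `sorted_nodes` stops descending at
    // `is_variable()` leaves, so `x` must be created as a variable
    // (assumed `Tensor::var` constructor).
    let x = Tensor::var(&[3f32], &Device::Cpu)?;

    // y = x * x + x, built from the ops handled in `backward`.
    let y = x.mul(&x)?.add(&x)?;

    // Walks the graph in reverse topological order, seeds dy/dy with
    // `ones_like`, and accumulates one gradient per tensor id.
    let grads = y.backward()?;

    // Per the Mul and Add rules above, dy/dx = 2x + 1, i.e. 7.0 at x = 3.0.
    let dx = grads.get(&x).expect("no gradient recorded for x");
    println!("dy/dx = {:?}", dx);
    Ok(())
}
```

Keying the store by `TensorId` is what makes this work when a tensor appears several times in the graph: each use goes through `or_insert` and adds its contribution to the same entry.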