Add the topological sort for backprop.

laurent
2023-06-20 19:15:39 +01:00
parent 671bcf060e
commit 9ff8d2076a
3 changed files with 58 additions and 1 deletion

.cargo/config.toml (new file)

@@ -0,0 +1,5 @@
+[target.x86_64-unknown-linux-gnu]
+rustflags = ["-C", "target-cpu=native"]
+
+[target.aarch64-apple-darwin]
+rustflags = ["-C", "target-cpu=native"]
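
(A hedged note on the new Cargo config, since the commit does not explain it: each [target.<triple>] section applies its rustflags only when compiling for that target, and -C target-cpu=native lets rustc assume every instruction-set extension of the build machine, e.g. AVX and beyond on recent x86_64 hosts. The flip side is that the resulting binaries may not run on older CPUs.)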


@@ -131,7 +131,7 @@ mod tests {
     #[test]
     fn stride() {
         let shape = Shape::from(());
-        assert_eq!(shape.stride_contiguous(), []);
+        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
         let shape = Shape::from(42);
         assert_eq!(shape.stride_contiguous(), [1]);
         let shape = Shape::from((42, 1337));
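
An aside on why the assertion changed (my reading, not stated in the commit): comparing a Vec with a bare [] literal leaves the array's element type for inference, and with nothing to pin it down the compiler rejects the comparison with a "type annotations needed" error, while Vec::<usize>::new() names the type explicitly. A standalone sketch of the failure mode:

    fn main() {
        let v: Vec<usize> = Vec::new();
        // assert_eq!(v, []); // fails to compile: type annotations needed
        assert_eq!(v, Vec::<usize>::new()); // fine: the element type is spelled out
    }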


@@ -22,6 +22,7 @@ pub struct Tensor_ {
     // The strides are given in number of elements and not in bytes.
     stride: Vec<usize>,
     op: Option<Op>,
+    is_variable: bool,
 }
 
 #[derive(Clone)]
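
(A hedged aside, since the commit itself does not say: is_variable reads like a marker for trainable leaf tensors, letting a backward pass tell which subgraphs actually need gradients. Every constructor in this diff sets it to false, and it relates to the track_grad flag threaded through sorted_nodes in the last hunk below.)
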
@@ -52,6 +53,7 @@ macro_rules! unary_op {
             shape: shape.clone(),
             stride: shape.stride_contiguous(),
             op: Some(Op::$op_name(self.clone())),
+            is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
     }

@@ -71,6 +73,7 @@ macro_rules! binary_op {
             shape: shape.clone(),
             stride: shape.stride_contiguous(),
             op: Some(Op::$op_name(self.clone(), rhs.clone())),
+            is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
     }

@@ -88,6 +91,7 @@ impl Tensor {
             shape,
             stride,
             op: None,
+            is_variable: false,
         };
         Self(Arc::new(tensor_))
     }

@@ -102,6 +106,7 @@ impl Tensor {
             shape,
             stride,
             op: None,
+            is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
     }

@@ -211,4 +216,51 @@ impl Tensor {
     pub fn id(&self) -> TensorId {
         self.id
     }
+
+    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
+    /// elements having dependencies on the latter ones, e.g. the first element if any is the
+    /// argument.
+    /// This assumes that the op graph is a DAG.
+    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
+        use std::collections::HashMap;
+        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
+        // to get around some lifetime limitations.
+        fn walk<'a>(
+            node: &'a Tensor,
+            nodes: Vec<&'a Tensor>,
+            already_seen: &mut HashMap<TensorId, bool>,
+        ) -> (bool, Vec<&'a Tensor>) {
+            if let Some(&tg) = already_seen.get(&node.id) {
+                return (tg, nodes);
+            }
+            let mut track_grad = false;
+            let mut nodes = if let Some(op) = &node.op {
+                match op {
+                    Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
+                        let (tg, nodes) = walk(lhs, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(rhs, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
+                    Op::Sqr(node) | Op::Sqrt(node) => {
+                        let (tg, nodes) = walk(node, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
+                }
+            } else {
+                nodes
+            };
+            already_seen.insert(node.id, track_grad);
+            if track_grad {
+                nodes.push(node);
+            }
+            (track_grad, nodes)
+        }
+        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
+        nodes.reverse();
+        nodes
+    }
 }
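
To see the traversal on its own, here is a minimal, self-contained sketch of the same post-order walk. The Node type, the ids, and the HashSet-based seen-set are simplified stand-ins of mine, not the crate's definitions, and the track_grad/is_variable bookkeeping is left out, so every reachable node gets sorted: inputs are pushed before the nodes that consume them, and the final reverse puts the output first, as the doc comment above describes.

    use std::collections::HashSet;

    // Simplified stand-in for Tensor/Op; the usize ids play the role of TensorId.
    enum Node {
        Leaf(usize),
        Add(usize, Box<Node>, Box<Node>),
        Sqr(usize, Box<Node>),
    }

    impl Node {
        fn id(&self) -> usize {
            match self {
                Node::Leaf(id) | Node::Add(id, ..) | Node::Sqr(id, ..) => *id,
            }
        }
    }

    fn sorted_nodes(root: &Node) -> Vec<&Node> {
        // Post-order DFS: visit the inputs first, then push the node itself.
        fn walk<'a>(node: &'a Node, out: &mut Vec<&'a Node>, seen: &mut HashSet<usize>) {
            if !seen.insert(node.id()) {
                return; // already visited: a shared node appears only once (DAG, not tree)
            }
            match node {
                Node::Leaf(_) => {}
                Node::Add(_, lhs, rhs) => {
                    walk(lhs, out, seen);
                    walk(rhs, out, seen);
                }
                Node::Sqr(_, arg) => walk(arg, out, seen),
            }
            out.push(node);
        }
        let mut out = Vec::new();
        walk(root, &mut out, &mut HashSet::new());
        out.reverse(); // dependents before dependencies: the output comes first
        out
    }

    fn main() {
        // (a + a)^2 with a shared leaf: id 0 must show up exactly once.
        let a = || Box::new(Node::Leaf(0));
        let graph = Node::Sqr(2, Box::new(Node::Add(1, a(), a())));
        let ids: Vec<usize> = sorted_nodes(&graph).iter().map(|n| n.id()).collect();
        assert_eq!(ids, [2, 1, 0]);
    }

One design note the commit's own comment points at: the crate's walk threads the nodes vec by value and returns it "to get around some lifetime limitations"; with the simplified types above a plain &mut Vec suffices, so the sketch uses that instead.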