Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Add the topological sort for backprop.
.cargo/config.toml (new file, 5 lines added)
@@ -0,0 +1,5 @@
+[target.x86_64-unknown-linux-gnu]
+rustflags = ["-C", "target-cpu=native"]
+
+[target.aarch64-apple-darwin]
+rustflags = ["-C", "target-cpu=native"]
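This config makes cargo pass -C target-cpu=native for both Linux x86-64 and Apple Silicon builds, so rustc may use every instruction-set extension of the build machine's CPU (e.g. AVX on x86-64, the full NEON set on aarch64) rather than the portable baseline. Binaries built this way are tuned to the host and may not run on older CPUs.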
@@ -131,7 +131,7 @@ mod tests {
     #[test]
     fn stride() {
         let shape = Shape::from(());
-        assert_eq!(shape.stride_contiguous(), []);
+        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
         let shape = Shape::from(42);
         assert_eq!(shape.stride_contiguous(), [1]);
         let shape = Shape::from((42, 1337));
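The expected value changes from the untyped empty literal [] to Vec::<usize>::new(), spelling out the element type of the empty result explicitly rather than leaving it to inference.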
@@ -22,6 +22,7 @@ pub struct Tensor_ {
     // The strides are given in number of elements and not in bytes.
     stride: Vec<usize>,
     op: Option<Op>,
+    is_variable: bool,
 }
 
 #[derive(Clone)]
@@ -52,6 +53,7 @@ macro_rules! unary_op {
             shape: shape.clone(),
             stride: shape.stride_contiguous(),
             op: Some(Op::$op_name(self.clone())),
+            is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
     }
@@ -71,6 +73,7 @@ macro_rules! binary_op {
             shape: shape.clone(),
             stride: shape.stride_contiguous(),
             op: Some(Op::$op_name(self.clone(), rhs.clone())),
+            is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
     }
@@ -88,6 +91,7 @@ impl Tensor {
             shape,
             stride,
             op: None,
+            is_variable: false,
         };
         Self(Arc::new(tensor_))
     }
@@ -102,6 +106,7 @@ impl Tensor {
             shape,
             stride,
             op: None,
+            is_variable: false,
         };
         Ok(Self(Arc::new(tensor_)))
     }
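Every construction site touched above initializes is_variable: false, so the new Tensor_ field is pure plumbing in this commit: nothing reads it yet and behavior is unchanged. Presumably a follow-up flags trainable leaf tensors with is_variable: true so that the gradient walk below has somewhere to start.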
@@ -211,4 +216,51 @@ impl Tensor {
     pub fn id(&self) -> TensorId {
         self.id
     }
+
+    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
+    /// elements having dependencies on the latter ones, e.g. the first element if any is the
+    /// argument.
+    /// This assumes that the op graph is a DAG.
+    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
+        use std::collections::HashMap;
+
+        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
+        // to get around some lifetime limitations.
+        fn walk<'a>(
+            node: &'a Tensor,
+            nodes: Vec<&'a Tensor>,
+            already_seen: &mut HashMap<TensorId, bool>,
+        ) -> (bool, Vec<&'a Tensor>) {
+            if let Some(&tg) = already_seen.get(&node.id) {
+                return (tg, nodes);
+            }
+            let mut track_grad = false;
+            let mut nodes = if let Some(op) = &node.op {
+                match op {
+                    Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
+                        let (tg, nodes) = walk(lhs, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(rhs, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
+                    Op::Sqr(node) | Op::Sqrt(node) => {
+                        let (tg, nodes) = walk(node, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
+                }
+            } else {
+                nodes
+            };
+            already_seen.insert(node.id, track_grad);
+            if track_grad {
+                nodes.push(node);
+            }
+            (track_grad, nodes)
+        }
+        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
+        nodes.reverse();
+        nodes
+    }
 }
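A subtlety in the hunk above: track_grad starts out false and is only ever OR-ed with the values returned for child nodes, and is_variable is never consulted, so as committed no node is ever marked for tracking and sorted_nodes returns an empty vec. The flag only pays off once leaf variables set it. The following is a self-contained sketch, not candle code: Node, NodeId, and the miniature Op enum are made-up stand-ins, and the is_variable check inside the walk is an assumption about where the flag is headed. It shows the same memoized post-order walk with that check wired in:

use std::collections::HashMap;
use std::rc::Rc;

type NodeId = usize;

enum Op {
    Add(Rc<Node>, Rc<Node>),
    Sqr(Rc<Node>),
}

struct Node {
    id: NodeId,
    op: Option<Op>,
    is_variable: bool,
}

// Memoized post-order DFS: children are pushed before their parents, and a
// node is kept only if gradient tracking is needed somewhere below it.
fn sorted_nodes(root: &Rc<Node>) -> Vec<Rc<Node>> {
    fn walk(node: &Rc<Node>, nodes: &mut Vec<Rc<Node>>, seen: &mut HashMap<NodeId, bool>) -> bool {
        if let Some(&tg) = seen.get(&node.id) {
            return tg; // already visited: reuse the memoized answer
        }
        // Assumption: leaf variables are where tracking starts; the commit
        // above stores is_variable but does not consult it yet.
        let mut track_grad = node.is_variable;
        match &node.op {
            Some(Op::Add(lhs, rhs)) => {
                track_grad |= walk(lhs, nodes, seen);
                track_grad |= walk(rhs, nodes, seen);
            }
            Some(Op::Sqr(arg)) => track_grad |= walk(arg, nodes, seen),
            None => {}
        }
        seen.insert(node.id, track_grad);
        if track_grad {
            nodes.push(node.clone()); // post-order: children first
        }
        track_grad
    }
    let mut nodes = Vec::new();
    walk(root, &mut nodes, &mut HashMap::new());
    nodes.reverse(); // root first, leaves last
    nodes
}

fn main() {
    // x is a variable; y = sqr(x + x). Expected order: y, x + x, x.
    let x = Rc::new(Node { id: 0, op: None, is_variable: true });
    let s = Rc::new(Node { id: 1, op: Some(Op::Add(x.clone(), x.clone())), is_variable: false });
    let y = Rc::new(Node { id: 2, op: Some(Op::Sqr(s)), is_variable: false });
    let ids: Vec<NodeId> = sorted_nodes(&y).iter().map(|n| n.id).collect();
    assert_eq!(ids, vec![2, 1, 0]);
}

Post-order pushes children before their parents, so after the reverse the root comes first and each node precedes everything it depends on; a backward pass can then run down the list knowing a node's gradient is complete before it is propagated to that node's inputs. The candle version threads the vec through walk by value rather than as a mutable reference, which its in-source comment attributes to lifetime limitations: the vec holds &'a Tensor borrows, while the Rc-based sketch sidesteps that by owning its nodes.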