mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the topological sort for backprop.
This commit is contained in:
5
.cargo/config.toml
Normal file
5
.cargo/config.toml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
[target.x86_64-unknown-linux-gnu]
|
||||||
|
rustflags = ["-C", "target-cpu=native"]
|
||||||
|
|
||||||
|
[target.aarch64-apple-darwin]
|
||||||
|
rustflags = ["-C", "target-cpu=native"]
|
@@ -131,7 +131,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn stride() {
|
fn stride() {
|
||||||
let shape = Shape::from(());
|
let shape = Shape::from(());
|
||||||
assert_eq!(shape.stride_contiguous(), []);
|
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
|
||||||
let shape = Shape::from(42);
|
let shape = Shape::from(42);
|
||||||
assert_eq!(shape.stride_contiguous(), [1]);
|
assert_eq!(shape.stride_contiguous(), [1]);
|
||||||
let shape = Shape::from((42, 1337));
|
let shape = Shape::from((42, 1337));
|
||||||
|
@@ -22,6 +22,7 @@ pub struct Tensor_ {
|
|||||||
// The strides are given in number of elements and not in bytes.
|
// The strides are given in number of elements and not in bytes.
|
||||||
stride: Vec<usize>,
|
stride: Vec<usize>,
|
||||||
op: Option<Op>,
|
op: Option<Op>,
|
||||||
|
is_variable: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -52,6 +53,7 @@ macro_rules! unary_op {
|
|||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
stride: shape.stride_contiguous(),
|
||||||
op: Some(Op::$op_name(self.clone())),
|
op: Some(Op::$op_name(self.clone())),
|
||||||
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
@@ -71,6 +73,7 @@ macro_rules! binary_op {
|
|||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
stride: shape.stride_contiguous(),
|
stride: shape.stride_contiguous(),
|
||||||
op: Some(Op::$op_name(self.clone(), rhs.clone())),
|
op: Some(Op::$op_name(self.clone(), rhs.clone())),
|
||||||
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
@@ -88,6 +91,7 @@ impl Tensor {
|
|||||||
shape,
|
shape,
|
||||||
stride,
|
stride,
|
||||||
op: None,
|
op: None,
|
||||||
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Self(Arc::new(tensor_))
|
Self(Arc::new(tensor_))
|
||||||
}
|
}
|
||||||
@@ -102,6 +106,7 @@ impl Tensor {
|
|||||||
shape,
|
shape,
|
||||||
stride,
|
stride,
|
||||||
op: None,
|
op: None,
|
||||||
|
is_variable: false,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
@@ -211,4 +216,51 @@ impl Tensor {
|
|||||||
pub fn id(&self) -> TensorId {
|
pub fn id(&self) -> TensorId {
|
||||||
self.id
|
self.id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||||
|
/// elements having dependencies on the later ones, e.g. the first element, if any, is the
|
||||||
|
/// argument.
|
||||||
|
/// This assumes that the op graph is a DAG.
|
||||||
|
pub fn sorted_nodes(&self) -> Vec<&Tensor> {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
// The vec of sorted nodes is passed as an owned value rather than a mutable reference
|
||||||
|
// to get around some lifetime limitations.
|
||||||
|
fn walk<'a>(
|
||||||
|
node: &'a Tensor,
|
||||||
|
nodes: Vec<&'a Tensor>,
|
||||||
|
already_seen: &mut HashMap<TensorId, bool>,
|
||||||
|
) -> (bool, Vec<&'a Tensor>) {
|
||||||
|
if let Some(&tg) = already_seen.get(&node.id) {
|
||||||
|
return (tg, nodes);
|
||||||
|
}
|
||||||
|
let mut track_grad = false;
|
||||||
|
let mut nodes = if let Some(op) = &node.op {
|
||||||
|
match op {
|
||||||
|
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
|
||||||
|
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||||
|
track_grad |= tg;
|
||||||
|
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
||||||
|
track_grad |= tg;
|
||||||
|
nodes
|
||||||
|
}
|
||||||
|
Op::Sqr(node) | Op::Sqrt(node) => {
|
||||||
|
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||||
|
track_grad |= tg;
|
||||||
|
nodes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nodes
|
||||||
|
};
|
||||||
|
already_seen.insert(node.id, track_grad);
|
||||||
|
if track_grad {
|
||||||
|
nodes.push(node);
|
||||||
|
}
|
||||||
|
(track_grad, nodes)
|
||||||
|
}
|
||||||
|
let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
|
||||||
|
nodes.reverse();
|
||||||
|
nodes
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user