Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Detach all grads during backprop. (#1243)
* Detach all grads during backprop.
* Add an environment variable to select the backprop behavior.
* Update the comment.
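For orientation, here is a minimal user-level sketch of the behavior this commit toggles (the device, values, and variable names below are illustrative, not part of the commit): by default, gradients are now detached as they are propagated, so the backward pass does not record its own computation graph; setting CANDLE_GRAD_DO_NOT_DETACH before the first backward pass on a thread opts back into the previous behavior.

use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    // To opt out of the new detaching behavior (e.g. to experiment with
    // differentiating through the backward pass), set the variable before the
    // first backward pass, for instance from the shell:
    //   CANDLE_GRAD_DO_NOT_DETACH=1 cargo run
    let dev = Device::Cpu;
    let x = Var::from_tensor(&Tensor::new(&[3f32], &dev)?)?;
    let y = x.as_tensor().mul(x.as_tensor())?; // y = x * x
    let grads = y.backward()?;
    let dy_dx = grads.get(&x).expect("gradient for x");
    println!("dy/dx = {dy_dx}"); // 2 * x = 6
    Ok(())
}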
@@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
     }
 }
 
+thread_local! {
+    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
+        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
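One detail worth noting (my reading of the hunk above, not stated in the commit message): because the flag lives in a lazily initialized thread_local!, the environment variable is read once per thread on first use, so it must be set before the first backward pass on that thread. A standalone sketch of the same parsing rule, for reference:

// Standalone sketch of the rule above (not part of the patch): the opt-out is
// active only when the variable is set to a non-empty value other than "0".
fn is_enabled(raw: Option<&str>) -> bool {
    match raw {
        Some(s) => !s.is_empty() && s != "0",
        None => false,
    }
}

fn main() {
    assert!(!is_enabled(None));        // unset      -> detach (default)
    assert!(!is_enabled(Some("")));    // empty      -> detach
    assert!(!is_enabled(Some("0")));   // "0"        -> detach
    assert!(is_enabled(Some("1")));    // "1"        -> do not detach
    assert!(is_enabled(Some("yes")));  // other text -> do not detach
}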
@@ -155,10 +166,16 @@ impl Tensor {
             if node.is_variable() {
                 continue;
             }
-            let grad = grads.remove(node).unwrap();
-            // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph). The only drawback would be if we wanted to support grad of grad but
-            // this is out of scope.
+            let grad = grads
+                .remove(node)
+                .expect("candle internal error - grad not populated");
+            // https://github.com/huggingface/candle/issues/1241
+            // Ideally, we would make these operations in place where possible to ensure that we
+            // do not have to allocate too often. Here we just call `.detach` to avoid computing
+            // the backprop graph of the backprop itself. This would be an issue for second order
+            // derivatives but these are out of scope at the moment.
+            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
+            let grad = if do_not_detach { grad } else { grad.detach()? };
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {