diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 16b9cfd9..dfad5f62 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -36,6 +36,8 @@ impl Tensor {
             // Do not call recursively on the "leaf" nodes.
             track_grad = true;
             nodes
+        } else if node.dtype().is_int() {
+            nodes
         } else if let Some(op) = node.op() {
             match op {
                 Op::IndexAdd(t1, t2, t3, _)
@@ -103,7 +105,6 @@ impl Tensor {
                 | Op::Broadcast(node)
                 | Op::Cmp(node, _)
                 | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
-                | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
                 | Op::Permute(node, _)
@@ -116,6 +117,15 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
+                Op::ToDType(node) => {
+                    if node.dtype().is_float() {
+                        let (tg, nodes) = walk(node, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    } else {
+                        nodes
+                    }
+                }
                 Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
             }
         } else {
@@ -374,7 +384,7 @@ impl Tensor {
                 }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
+                    *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                 }
                 Op::Copy(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
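
In short, this patch makes the graph walk stop at integer-dtype tensors (they carry no gradient), only recurses through `Op::ToDType` when the source tensor is floating point, and fixes the `ToDType` backward pass to cast the incoming gradient back to the *source* dtype (`arg.dtype()`) rather than the output dtype (`node.dtype()`), so it accumulates into `sum_grad` at the right dtype. The following is a minimal sketch, not part of the patch, illustrating the fixed behavior through candle's public `Var`/`backward` API; the specific values and the `sqr`/`sum_all` pipeline are chosen for illustration only:

```rust
use candle_core::{DType, Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // An f32 variable whose graph includes an upcast to f64.
    let x = Var::new(&[1f32, 2., 3.], &dev)?;
    let y = x.to_dtype(DType::F64)?.sqr()?.sum_all()?;
    let grads = y.backward()?;
    // With the fix, the gradient flowing back through ToDType is cast
    // to the source dtype (f32) before accumulation, instead of being
    // added at the output dtype (f64).
    let gx = grads.get(&x).expect("no gradient recorded for x");
    assert_eq!(gx.dtype(), DType::F32);
    Ok(())
}
```

An integer-valued tensor in the same graph (e.g. the output of an `ArgMax` reduce or an index tensor) would now simply terminate the walk, so no gradient slot is ever allocated for it.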