diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 5d1bc2cd..3ea03b0b 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -756,4 +756,9 @@ impl GradStore { }; Ok(grad) } + + /// Get the tensor ids of the stored gradient tensors + pub fn get_ids(&self) -> impl Iterator { + self.0.keys() + } }