mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add some backprop test.
This commit is contained in:
@ -169,7 +169,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
||||
Self::new_impl(array, device, false)
|
||||
Self::new_impl(array, device, true)
|
||||
}
|
||||
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||
@ -329,6 +329,7 @@ impl Tensor {
|
||||
|
||||
pub fn backward(&self) -> Result<HashMap<TensorId, Tensor>> {
|
||||
let sorted_nodes = self.sorted_nodes();
|
||||
println!("{}", sorted_nodes.len());
|
||||
let mut grads = HashMap::new();
|
||||
grads.insert(self.id, self.ones_like());
|
||||
for node in sorted_nodes.iter() {
|
||||
|
Reference in New Issue
Block a user