From 3b7984ccce12f7e44c313cd0f90ae3cd6f403627 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 20 Jun 2023 19:31:35 +0100 Subject: [PATCH] Add some functions to create variables. --- src/tensor.rs | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index 88f47a15..8a827783 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -81,7 +81,12 @@ macro_rules! binary_op { } impl Tensor { - pub fn zeros>(shape: S, dtype: DType, device: Device) -> Self { + fn zeros_impl>( + shape: S, + dtype: DType, + device: Device, + is_variable: bool, + ) -> Self { let shape = shape.into(); let storage = device.zeros(&shape, dtype); let stride = shape.stride_contiguous(); @@ -91,12 +96,24 @@ impl Tensor { shape, stride, op: None, - is_variable: false, + is_variable, }; Self(Arc::new(tensor_)) } - pub fn new(array: A, device: Device) -> Result { + pub fn zeros>(shape: S, dtype: DType, device: Device) -> Self { + Self::zeros_impl(shape, dtype, device, false) + } + + pub fn zeros_var>(shape: S, dtype: DType, device: Device) -> Self { + Self::zeros_impl(shape, dtype, device, true) + } + + pub fn new_impl( + array: A, + device: Device, + is_variable: bool, + ) -> Result { let shape = array.shape()?; let storage = device.tensor(array); let stride = shape.stride_contiguous(); @@ -106,11 +123,19 @@ impl Tensor { shape, stride, op: None, - is_variable: false, + is_variable, }; Ok(Self(Arc::new(tensor_))) } + pub fn new(array: A, device: Device) -> Result { + Self::new_impl(array, device, false) + } + + pub fn var(array: A, device: Device) -> Result { + Self::new_impl(array, device, false) + } + pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { let lhs = self.shape(); let rhs = rhs.shape(); @@ -235,7 +260,11 @@ impl Tensor { return (tg, nodes); } let mut track_grad = false; - let mut nodes = if let Some(op) = &node.op { + let mut nodes = if node.is_variable { + // Do not call recursively on the "leaf" nodes. + track_grad = true; + nodes + } else if let Some(op) = &node.op { match op { Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => { let (tg, nodes) = walk(lhs, nodes, already_seen);