Add some functions to create variables.

This commit is contained in:
laurent
2023-06-20 19:31:35 +01:00
parent 9ff8d2076a
commit 3b7984ccce

View File

@ -81,7 +81,12 @@ macro_rules! binary_op {
} }
impl Tensor { impl Tensor {
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self { fn zeros_impl<S: Into<Shape>>(
shape: S,
dtype: DType,
device: Device,
is_variable: bool,
) -> Self {
let shape = shape.into(); let shape = shape.into();
let storage = device.zeros(&shape, dtype); let storage = device.zeros(&shape, dtype);
let stride = shape.stride_contiguous(); let stride = shape.stride_contiguous();
@ -91,12 +96,24 @@ impl Tensor {
shape, shape,
stride, stride,
op: None, op: None,
is_variable: false, is_variable,
}; };
Self(Arc::new(tensor_)) Self(Arc::new(tensor_))
} }
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> { pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
Self::zeros_impl(shape, dtype, device, false)
}
pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
Self::zeros_impl(shape, dtype, device, true)
}
pub fn new_impl<A: crate::device::NdArray>(
array: A,
device: Device,
is_variable: bool,
) -> Result<Self> {
let shape = array.shape()?; let shape = array.shape()?;
let storage = device.tensor(array); let storage = device.tensor(array);
let stride = shape.stride_contiguous(); let stride = shape.stride_contiguous();
@ -106,11 +123,19 @@ impl Tensor {
shape, shape,
stride, stride,
op: None, op: None,
is_variable: false, is_variable,
}; };
Ok(Self(Arc::new(tensor_))) Ok(Self(Arc::new(tensor_)))
} }
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
Self::new_impl(array, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
Self::new_impl(array, device, false)
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
let lhs = self.shape(); let lhs = self.shape();
let rhs = rhs.shape(); let rhs = rhs.shape();
@ -235,7 +260,11 @@ impl Tensor {
return (tg, nodes); return (tg, nodes);
} }
let mut track_grad = false; let mut track_grad = false;
let mut nodes = if let Some(op) = &node.op { let mut nodes = if node.is_variable {
// Do not call recursively on the "leaf" nodes.
track_grad = true;
nodes
} else if let Some(op) = &node.op {
match op { match op {
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => { Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen); let (tg, nodes) = walk(lhs, nodes, already_seen);