Add some functions to create variables.

This commit is contained in:
laurent
2023-06-20 19:31:35 +01:00
parent 9ff8d2076a
commit 3b7984ccce

View File

@ -81,7 +81,12 @@ macro_rules! binary_op {
}
impl Tensor {
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
fn zeros_impl<S: Into<Shape>>(
shape: S,
dtype: DType,
device: Device,
is_variable: bool,
) -> Self {
let shape = shape.into();
let storage = device.zeros(&shape, dtype);
let stride = shape.stride_contiguous();
@ -91,12 +96,24 @@ impl Tensor {
shape,
stride,
op: None,
is_variable: false,
is_variable,
};
Self(Arc::new(tensor_))
}
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
Self::zeros_impl(shape, dtype, device, false)
}
pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
Self::zeros_impl(shape, dtype, device, true)
}
pub fn new_impl<A: crate::device::NdArray>(
array: A,
device: Device,
is_variable: bool,
) -> Result<Self> {
let shape = array.shape()?;
let storage = device.tensor(array);
let stride = shape.stride_contiguous();
@ -106,11 +123,19 @@ impl Tensor {
shape,
stride,
op: None,
is_variable: false,
is_variable,
};
Ok(Self(Arc::new(tensor_)))
}
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
Self::new_impl(array, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
Self::new_impl(array, device, false)
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
let lhs = self.shape();
let rhs = rhs.shape();
@ -235,7 +260,11 @@ impl Tensor {
return (tg, nodes);
}
let mut track_grad = false;
let mut nodes = if let Some(op) = &node.op {
let mut nodes = if node.is_variable {
// Do not call recursively on the "leaf" nodes.
track_grad = true;
nodes
} else if let Some(op) = &node.op {
match op {
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);