mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add some functions to create variables.
This commit is contained in:
@ -81,7 +81,12 @@ macro_rules! binary_op {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
fn zeros_impl<S: Into<Shape>>(
|
||||||
|
shape: S,
|
||||||
|
dtype: DType,
|
||||||
|
device: Device,
|
||||||
|
is_variable: bool,
|
||||||
|
) -> Self {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let storage = device.zeros(&shape, dtype);
|
let storage = device.zeros(&shape, dtype);
|
||||||
let stride = shape.stride_contiguous();
|
let stride = shape.stride_contiguous();
|
||||||
@ -91,12 +96,24 @@ impl Tensor {
|
|||||||
shape,
|
shape,
|
||||||
stride,
|
stride,
|
||||||
op: None,
|
op: None,
|
||||||
is_variable: false,
|
is_variable,
|
||||||
};
|
};
|
||||||
Self(Arc::new(tensor_))
|
Self(Arc::new(tensor_))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||||
|
Self::zeros_impl(shape, dtype, device, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||||
|
Self::zeros_impl(shape, dtype, device, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_impl<A: crate::device::NdArray>(
|
||||||
|
array: A,
|
||||||
|
device: Device,
|
||||||
|
is_variable: bool,
|
||||||
|
) -> Result<Self> {
|
||||||
let shape = array.shape()?;
|
let shape = array.shape()?;
|
||||||
let storage = device.tensor(array);
|
let storage = device.tensor(array);
|
||||||
let stride = shape.stride_contiguous();
|
let stride = shape.stride_contiguous();
|
||||||
@ -106,11 +123,19 @@ impl Tensor {
|
|||||||
shape,
|
shape,
|
||||||
stride,
|
stride,
|
||||||
op: None,
|
op: None,
|
||||||
is_variable: false,
|
is_variable,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
||||||
|
Self::new_impl(array, device, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
||||||
|
Self::new_impl(array, device, false)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||||
let lhs = self.shape();
|
let lhs = self.shape();
|
||||||
let rhs = rhs.shape();
|
let rhs = rhs.shape();
|
||||||
@ -235,7 +260,11 @@ impl Tensor {
|
|||||||
return (tg, nodes);
|
return (tg, nodes);
|
||||||
}
|
}
|
||||||
let mut track_grad = false;
|
let mut track_grad = false;
|
||||||
let mut nodes = if let Some(op) = &node.op {
|
let mut nodes = if node.is_variable {
|
||||||
|
// Do not call recursively on the "leaf" nodes.
|
||||||
|
track_grad = true;
|
||||||
|
nodes
|
||||||
|
} else if let Some(op) = &node.op {
|
||||||
match op {
|
match op {
|
||||||
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
|
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
|
||||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||||
|
Reference in New Issue
Block a user