mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add some functions to create variables.
@@ -81,7 +81,12 @@ macro_rules! binary_op {
 }
 
 impl Tensor {
-    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    fn zeros_impl<S: Into<Shape>>(
+        shape: S,
+        dtype: DType,
+        device: Device,
+        is_variable: bool,
+    ) -> Self {
         let shape = shape.into();
         let storage = device.zeros(&shape, dtype);
         let stride = shape.stride_contiguous();
@@ -91,12 +96,24 @@ impl Tensor {
             shape,
             stride,
             op: None,
-            is_variable: false,
+            is_variable,
         };
         Self(Arc::new(tensor_))
     }
 
-    pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+        Self::zeros_impl(shape, dtype, device, false)
+    }
+
+    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+        Self::zeros_impl(shape, dtype, device, true)
+    }
+
+    pub fn new_impl<A: crate::device::NdArray>(
+        array: A,
+        device: Device,
+        is_variable: bool,
+    ) -> Result<Self> {
         let shape = array.shape()?;
         let storage = device.tensor(array);
         let stride = shape.stride_contiguous();
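The two public wrappers differ only in the `is_variable` flag they forward to `zeros_impl`. A minimal usage sketch, assuming `Device::Cpu`, `DType::F32`, and a tuple-to-`Shape` conversion provided elsewhere in the crate (none of which appear in this diff):

    // Sketch only: Device::Cpu, DType::F32 and Into<Shape> for tuples are assumed.
    let t = Tensor::zeros((2, 3), DType::F32, Device::Cpu);     // plain tensor
    let v = Tensor::zeros_var((2, 3), DType::F32, Device::Cpu); // tracked as a variable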
@@ -106,11 +123,19 @@ impl Tensor {
             shape,
             stride,
             op: None,
-            is_variable: false,
+            is_variable,
         };
         Ok(Self(Arc::new(tensor_)))
     }
 
+    pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+        Self::new_impl(array, device, false)
+    }
+
+    pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+        Self::new_impl(array, device, true)
+    }
+
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
         let lhs = self.shape();
         let rhs = rhs.shape();
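`new` and `var` follow the same pattern on top of `new_impl`, with `var` the gradient-tracked counterpart. A hedged usage sketch; the exact `crate::device::NdArray` implementations available at this commit are assumed (for instance, for fixed-size f32 arrays):

    // Sketch only: assumes NdArray is implemented for &[f32; N] host data.
    let x = Tensor::new(&[1f32, 2., 3.], Device::Cpu)?; // constant input, ignored by backprop
    let w = Tensor::var(&[0f32, 0., 0.], Device::Cpu)?; // variable, picked up during the walk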
@@ -235,7 +260,11 @@ impl Tensor {
             return (tg, nodes);
         }
         let mut track_grad = false;
-        let mut nodes = if let Some(op) = &node.op {
+        let mut nodes = if node.is_variable {
+            // Do not call recursively on the "leaf" nodes.
+            track_grad = true;
+            nodes
+        } else if let Some(op) = &node.op {
             match op {
                 Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
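The new branch makes variables act as leaves of the backprop walk: they set `track_grad` and stop the recursion, while any other node is worth tracking only if one of its op's inputs is. A self-contained sketch of that rule, using hypothetical `Node`/`Op` types rather than the crate's own:

    // Hypothetical standalone types; only the control flow mirrors the diff above.
    struct Node {
        is_variable: bool,
        op: Option<Op>,
    }

    enum Op {
        Add(Box<Node>, Box<Node>),
        Mul(Box<Node>, Box<Node>),
    }

    // A node needs gradients iff it is a variable or any input of its op does;
    // variables terminate the recursion, just like the `node.is_variable` branch.
    fn track_grad(node: &Node) -> bool {
        if node.is_variable {
            true
        } else if let Some(op) = &node.op {
            match op {
                Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => track_grad(lhs) || track_grad(rhs),
            }
        } else {
            false
        }
    }

    fn main() {
        let x = Node { is_variable: true, op: None };
        let c = Node { is_variable: false, op: None };
        let y = Node { is_variable: false, op: Some(Op::Add(Box::new(x), Box::new(c))) };
        assert!(track_grad(&y)); // y depends on the variable x, so it is tracked
    }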