mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Introduce the variables api used for adjusting parameters during the training loop. (#158)
* Add the variable api. * And add a comment.
This commit is contained in:
@ -56,6 +56,7 @@ mod storage;
|
||||
mod strided_index;
|
||||
mod tensor;
|
||||
pub mod utils;
|
||||
mod variable;
|
||||
|
||||
pub use cpu_backend::CpuStorage;
|
||||
pub use device::{Device, DeviceLocation};
|
||||
@ -67,6 +68,7 @@ pub use shape::{Shape, D};
|
||||
pub use storage::Storage;
|
||||
use strided_index::StridedIndex;
|
||||
pub use tensor::{Tensor, TensorId};
|
||||
pub use variable::Variable;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub use cuda_backend::{CudaDevice, CudaStorage};
|
||||
|
30
candle-core/src/variable.rs
Normal file
30
candle-core/src/variable.rs
Normal file
@ -0,0 +1,30 @@
|
||||
// Variables are wrappers around tensors that can be modified, they are typically used for holding
|
||||
// weights and being modified by gradient descent.
|
||||
// They are not cloneable by default to avoid having too many potential writers on the data.
|
||||
// We also do not expose a public way to create variables as this would break the invariant that
|
||||
// the tensor within a variable is actually with `is_variable` set to `true`.
|
||||
use crate::Tensor;
|
||||
|
||||
/// A variable is a wrapper around a tensor, however variables can have their content modified
|
||||
/// whereas tensors are immutable.
|
||||
#[derive(Debug)]
|
||||
pub struct Variable(Tensor);
|
||||
|
||||
impl std::ops::Deref for Variable {
|
||||
type Target = Tensor;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
pub fn as_tensor(&self) -> &Tensor {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Consumes this `Variable` and return the underlying tensor.
|
||||
pub fn into_inner(self) -> Tensor {
|
||||
self.0
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user