Introduce the variables API used for adjusting parameters during the training loop. (#158)

* Add the variable API.

* And add a comment.
Laurent Mazare
2023-07-13 14:09:51 +01:00
committed by GitHub
parent 7adc8c903a
commit 6991036bc5
2 changed files with 32 additions and 0 deletions
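For orientation before the diff itself: a rough, self-contained sketch of the training-loop pattern the commit message alludes to. Nothing here is candle's real API: the `Tensor` below is a tiny stub, and the `new` and `set` methods are hypothetical placeholders, since the `Variable` added in this commit deliberately exposes no public constructor or mutation method yet.

// Stub standing in for candle's Tensor type (immutable once created).
#[derive(Debug, Clone)]
struct Tensor(Vec<f32>);

// Sketch of the wrapper: trainable weights live in a Variable,
// everything else in the graph stays a plain Tensor.
#[derive(Debug)]
struct Variable(Tensor);

impl Variable {
    // Hypothetical constructor; the real commit keeps construction private.
    fn new(data: Vec<f32>) -> Self {
        Variable(Tensor(data))
    }

    fn as_tensor(&self) -> &Tensor {
        &self.0
    }

    // Hypothetical update hook an optimizer would call; not part of this commit.
    fn set(&mut self, t: Tensor) {
        self.0 = t;
    }
}

fn main() {
    let mut weight = Variable::new(vec![0.5, -0.3]);
    let grad = Tensor(vec![0.1, -0.2]); // pretend this came from backprop
    let lr = 0.01;

    // SGD-style step: read the current value, write back an updated tensor.
    let updated: Vec<f32> = weight
        .as_tensor()
        .0
        .iter()
        .zip(grad.0.iter())
        .map(|(w, g)| w - lr * g)
        .collect();
    weight.set(Tensor(updated));
    println!("{weight:?}");
}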


@@ -56,6 +56,7 @@ mod storage;
mod strided_index;
mod tensor;
pub mod utils;
mod variable;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
@@ -67,6 +68,7 @@ pub use shape::{Shape, D};
pub use storage::Storage;
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};
pub use variable::Variable;
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaStorage};


@@ -0,0 +1,30 @@
// Variables are wrappers around tensors that can be modified; they are typically used to hold
// weights that get updated by gradient descent.
// They are not cloneable by default to avoid having too many potential writers on the data.
// We also do not expose a public way to create variables as this would break the invariant that
// the tensor within a variable actually has `is_variable` set to `true`.
use crate::Tensor;
/// A variable is a wrapper around a tensor; however, variables can have their content modified
/// whereas tensors are immutable.
#[derive(Debug)]
pub struct Variable(Tensor);
impl std::ops::Deref for Variable {
type Target = Tensor;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Variable {
pub fn as_tensor(&self) -> &Tensor {
&self.0
}
/// Consumes this `Variable` and returns the underlying tensor.
pub fn into_inner(self) -> Tensor {
self.0
}
}
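To make the role of the `Deref` impl above concrete, here is a small self-contained illustration: a `&Variable` can be passed wherever a `&Tensor` is expected, with deref coercion doing the conversion. The `Tensor` is again a stub rather than candle's type, and the variable is constructed directly only because this is a standalone sketch; the commit itself keeps construction private.

use std::ops::Deref;

// Stub standing in for candle's Tensor type.
#[derive(Debug)]
struct Tensor(Vec<f32>);

#[derive(Debug)]
struct Variable(Tensor);

impl Deref for Variable {
    type Target = Tensor;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// A read-only routine written against plain tensors.
fn l2_norm(t: &Tensor) -> f32 {
    t.0.iter().map(|x| x * x).sum::<f32>().sqrt()
}

fn main() {
    // Direct construction here purely for the sketch.
    let w = Variable(Tensor(vec![3.0, 4.0]));

    // Deref coercion turns &Variable into &Tensor automatically.
    println!("norm = {}", l2_norm(&w));

    // Moving the tensor out, as `into_inner` does.
    let t: Tensor = w.0;
    println!("{t:?}");
}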