mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add the SGD optimizer (#160)
* Add the nn::optim and some conversion traits. * Add the backward_step function for SGD. * Get the SGD optimizer to work and add a test. * Make the test slighly simpler.
This commit is contained in:
@ -1,13 +1,12 @@
|
||||
// Variables are wrappers around tensors that can be modified, they are typically used for holding
|
||||
// weights and being modified by gradient descent.
|
||||
// They are not cloneable by default to avoid having too many potential writers on the data.
|
||||
// We also do not expose a public way to create variables as this would break the invariant that
|
||||
// the tensor within a variable is actually with `is_variable` set to `true`.
|
||||
// We do not expose a public way to create variables as this would break the invariant that the
|
||||
// tensor within a variable is actually with `is_variable` set to `true`.
|
||||
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
||||
|
||||
/// A variable is a wrapper around a tensor, however variables can have their content modified
|
||||
/// whereas tensors are immutable.
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Var(Tensor);
|
||||
|
||||
impl std::ops::Deref for Var {
|
||||
|
Reference in New Issue
Block a user