mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
1 Commits
0.9.0-alph
...
initialize
Author | SHA1 | Date | |
---|---|---|---|
0bb344f798 |
@ -1,5 +1,6 @@
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{linear, AdamW, Linear, ParamsAdamW, VarBuilder, VarMap};
|
||||
use candle_nn::init::{DefaultInit, ModelInitializer};
|
||||
use candle_nn::{AdamW, Linear, ParamsAdamW};
|
||||
|
||||
fn gen_data() -> Result<(Tensor, Tensor)> {
|
||||
// Generate some sample linear data.
|
||||
@ -15,14 +16,15 @@ fn main() -> Result<()> {
|
||||
let (sample_xs, sample_ys) = gen_data()?;
|
||||
|
||||
// Use backprop to run a linear regression between samples and get the coefficients back.
|
||||
let varmap = VarMap::new();
|
||||
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
|
||||
let model = linear(2, 1, vb.pp("linear"))?;
|
||||
// let varmap = VarMap::new();
|
||||
// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
|
||||
let mut initializer = DefaultInit::new(DType::F32, Device::Cpu);
|
||||
let model = Linear::init(&mut initializer, ((2, 1), true))?;
|
||||
let params = ParamsAdamW {
|
||||
lr: 0.1,
|
||||
..Default::default()
|
||||
};
|
||||
let mut opt = AdamW::new(varmap.all_vars(), params)?;
|
||||
let mut opt = AdamW::new(initializer.vars().to_vec(), params)?;
|
||||
for step in 0..10000 {
|
||||
let ys = model.forward(&sample_xs)?;
|
||||
let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
|
||||
|
@ -3,6 +3,50 @@
|
||||
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
|
||||
use candle::{DType, Device, Result, Shape, Tensor, Var};
|
||||
|
||||
pub trait Initializer<M> {
|
||||
type Config;
|
||||
fn init(&mut self, config: Self::Config) -> Result<M>;
|
||||
}
|
||||
|
||||
pub trait ModelInitializer: Sized {
|
||||
fn init<INIT: Initializer<Self>>(initializer: &mut INIT, config: INIT::Config) -> Result<Self> {
|
||||
initializer.init(config)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DefaultInit {
|
||||
vars: Vec<Var>,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl DefaultInit {
|
||||
pub fn new(dtype: DType, device: Device) -> Self {
|
||||
let vars = vec![];
|
||||
Self {
|
||||
dtype,
|
||||
device,
|
||||
vars,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn vars(&self) -> &[Var] {
|
||||
&self.vars
|
||||
}
|
||||
|
||||
pub fn push_var(&mut self, var: Var) {
|
||||
self.vars.push(var)
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of features as input or output of a layer.
|
||||
/// In Kaiming initialization, choosing `FanIn` preserves
|
||||
/// the magnitude of the variance of the weights in the
|
||||
|
@ -17,6 +17,7 @@
|
||||
//! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
|
||||
//! # Ok(()) }
|
||||
//! ```
|
||||
use crate::init::{DefaultInit, Initializer, ModelInitializer};
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -43,23 +44,45 @@ impl Linear {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create or initialize a new linear layer.
|
||||
impl ModelInitializer for Linear {}
|
||||
|
||||
impl Initializer<Linear> for DefaultInit {
|
||||
type Config = ((usize, usize), bool);
|
||||
|
||||
fn init(&mut self, (shape, has_bias): Self::Config) -> Result<Linear> {
|
||||
let dtype = self.dtype();
|
||||
let device = self.device().clone();
|
||||
let (out_dim, in_dim) = shape;
|
||||
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||
let ws = init_ws.var(shape, dtype, &device)?;
|
||||
self.push_var(ws.clone());
|
||||
let ws = ws.as_tensor().clone();
|
||||
if has_bias {
|
||||
let bound = 1. / (in_dim as f64).sqrt();
|
||||
let init_bs = crate::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
};
|
||||
let bs = init_bs.var(out_dim, dtype, &device)?;
|
||||
self.push_var(bs.clone());
|
||||
let bs = bs.as_tensor().clone();
|
||||
Ok(Linear::new(ws, Some(bs)))
|
||||
} else {
|
||||
Ok(Linear::new(ws, None))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a linear layer.
|
||||
///
|
||||
/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
|
||||
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
|
||||
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||
let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
|
||||
let bound = 1. / (in_dim as f64).sqrt();
|
||||
let init_bs = crate::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
};
|
||||
let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
|
||||
let ws = vs.get((out_dim, in_dim), "weight")?;
|
||||
let bs = vs.get(out_dim, "bias")?;
|
||||
Ok(Linear::new(ws, Some(bs)))
|
||||
}
|
||||
|
||||
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
|
||||
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||
let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
|
||||
let ws = vs.get((out_dim, in_dim), "weight")?;
|
||||
Ok(Linear::new(ws, None))
|
||||
}
|
||||
|
Reference in New Issue
Block a user