mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Llama more training (#297)
* Rework the var-builder to handle initializations. * Add some helper functions for layer creation. * Improve the layer initializations. * Get initialized variables. * Precompute the rot embeddings when training lamas.
This commit is contained in:
@ -1,5 +1,6 @@
|
|||||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{Embedding, Linear, VarBuilder};
|
use candle_nn::linear_no_bias as linear;
|
||||||
|
use candle_nn::{embedding, Embedding, Linear, VarBuilder};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
@ -43,8 +44,25 @@ pub struct Cache {
|
|||||||
|
|
||||||
impl Cache {
|
impl Cache {
|
||||||
pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?;
|
let n_elem = cfg.dim / cfg.n_heads;
|
||||||
let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?;
|
let theta: Vec<_> = (0..n_elem)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||||
|
.collect();
|
||||||
|
let theta = Tensor::new(theta.as_slice(), vb.device())?;
|
||||||
|
let idx_theta = Tensor::arange(0, cfg.seq_len as u32, vb.device())?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((cfg.seq_len, 1))?
|
||||||
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
|
let precomputed_cos = idx_theta.cos()?;
|
||||||
|
let precomputed_sin = idx_theta.sin()?;
|
||||||
|
|
||||||
|
let freq_cis_real = vb
|
||||||
|
.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")
|
||||||
|
.unwrap_or(precomputed_cos);
|
||||||
|
let freq_cis_imag = vb
|
||||||
|
.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")
|
||||||
|
.unwrap_or(precomputed_sin);
|
||||||
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
||||||
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -76,16 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
|||||||
xs / (xs.neg()?.exp()? + 1.0)?
|
xs / (xs.neg()?.exp()? + 1.0)?
|
||||||
}
|
}
|
||||||
|
|
||||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
|
||||||
let weight = vb.get((size2, size1), "weight")?;
|
|
||||||
Ok(Linear::new(weight, None))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
|
||||||
let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
|
|
||||||
Ok(Embedding::new(embeddings, cfg.dim))
|
|
||||||
}
|
|
||||||
|
|
||||||
struct RmsNorm {
|
struct RmsNorm {
|
||||||
scale: Tensor,
|
scale: Tensor,
|
||||||
eps: f64,
|
eps: f64,
|
||||||
@ -93,7 +101,7 @@ struct RmsNorm {
|
|||||||
|
|
||||||
impl RmsNorm {
|
impl RmsNorm {
|
||||||
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||||
let scale = vb.get(size, "weight")?;
|
let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
|
||||||
Ok(Self { scale, eps })
|
Ok(Self { scale, eps })
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -315,7 +323,7 @@ impl Llama {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
|
||||||
let wte = embedding(&cfg, vb.pp("model.embed_tokens"))?;
|
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
|
||||||
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||||
let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
||||||
let blocks: Vec<_> = (0..cfg.n_layers)
|
let blocks: Vec<_> = (0..cfg.n_layers)
|
||||||
|
@ -142,15 +142,15 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
|||||||
dataset.train_tokens.len(),
|
dataset.train_tokens.len(),
|
||||||
dataset.valid_tokens.len()
|
dataset.valid_tokens.len()
|
||||||
);
|
);
|
||||||
let vb = candle_nn::VarBuilder::zeros(DType::F32, &device);
|
let varmap = candle_nn::VarMap::new();
|
||||||
|
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
|
||||||
let config = Config::tiny();
|
let config = Config::tiny();
|
||||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||||
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||||
|
|
||||||
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||||
let model = Llama::load(vb, &cache, config)?;
|
let model = Llama::load(vb, &cache, config)?;
|
||||||
let all_vars = vec![]; // TODO: Propagate the variables from the VarBuilder to here.
|
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
|
||||||
let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
|
|
||||||
for (batch_index, batch) in batch_iter.enumerate() {
|
for (batch_index, batch) in batch_iter.enumerate() {
|
||||||
let (inp, tgt) = batch?;
|
let (inp, tgt) = batch?;
|
||||||
let logits = model.forward(&inp, 0)?;
|
let logits = model.forward(&inp, 0)?;
|
||||||
|
@ -4,128 +4,20 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
use candle::{DType, Device, Result, Shape, Tensor, Var, D};
|
use candle::{DType, Result, Tensor, D};
|
||||||
use candle_nn::{loss, ops, Init, Linear};
|
use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
const IMAGE_DIM: usize = 784;
|
const IMAGE_DIM: usize = 784;
|
||||||
const LABELS: usize = 10;
|
const LABELS: usize = 10;
|
||||||
|
|
||||||
struct TensorData {
|
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
|
||||||
tensors: std::collections::HashMap<String, Var>,
|
let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
|
||||||
pub dtype: DType,
|
let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
|
||||||
pub device: Device,
|
|
||||||
}
|
|
||||||
|
|
||||||
// A variant of candle_nn::VarBuilder for initializing variables before training.
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct VarStore {
|
|
||||||
data: Arc<Mutex<TensorData>>,
|
|
||||||
path: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl VarStore {
|
|
||||||
fn new(dtype: DType, device: Device) -> Self {
|
|
||||||
let data = TensorData {
|
|
||||||
tensors: std::collections::HashMap::new(),
|
|
||||||
dtype,
|
|
||||||
device,
|
|
||||||
};
|
|
||||||
Self {
|
|
||||||
data: Arc::new(Mutex::new(data)),
|
|
||||||
path: vec![],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn pp(&self, s: &str) -> Self {
|
|
||||||
let mut path = self.path.clone();
|
|
||||||
path.push(s.to_string());
|
|
||||||
Self {
|
|
||||||
data: self.data.clone(),
|
|
||||||
path,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
|
|
||||||
let shape = shape.into();
|
|
||||||
let path = if self.path.is_empty() {
|
|
||||||
tensor_name.to_string()
|
|
||||||
} else {
|
|
||||||
[&self.path.join("."), tensor_name].join(".")
|
|
||||||
};
|
|
||||||
let mut tensor_data = self.data.lock().unwrap();
|
|
||||||
if let Some(tensor) = tensor_data.tensors.get(&path) {
|
|
||||||
let tensor_shape = tensor.shape();
|
|
||||||
if &shape != tensor_shape {
|
|
||||||
candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
|
|
||||||
}
|
|
||||||
return Ok(tensor.as_tensor().clone());
|
|
||||||
}
|
|
||||||
let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
|
|
||||||
let tensor = var.as_tensor().clone();
|
|
||||||
tensor_data.tensors.insert(path, var);
|
|
||||||
Ok(tensor)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn all_vars(&self) -> Vec<Var> {
|
|
||||||
let tensor_data = self.data.lock().unwrap();
|
|
||||||
#[allow(clippy::map_clone)]
|
|
||||||
tensor_data
|
|
||||||
.tensors
|
|
||||||
.values()
|
|
||||||
.map(|c| c.clone())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
|
|
||||||
let tensor_data = self.data.lock().unwrap();
|
|
||||||
let data = tensor_data.tensors.iter().map(|(k, v)| (k, v.as_tensor()));
|
|
||||||
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
|
|
||||||
use candle::safetensors::Load;
|
|
||||||
|
|
||||||
let path = path.as_ref();
|
|
||||||
let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
|
|
||||||
let data = data.deserialize()?;
|
|
||||||
let mut tensor_data = self.data.lock().unwrap();
|
|
||||||
for (name, var) in tensor_data.tensors.iter_mut() {
|
|
||||||
match data.tensor(name) {
|
|
||||||
Ok(data) => {
|
|
||||||
let data: Tensor = data.load(var.device())?;
|
|
||||||
if let Err(err) = var.set(&data) {
|
|
||||||
candle::bail!("error setting {name} using data from {path:?}: {err}",)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => candle::bail!("cannot find tensor for {name}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
|
|
||||||
let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
|
|
||||||
let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
|
|
||||||
Ok(Linear::new(ws, Some(bs)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
|
|
||||||
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
|
|
||||||
let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
|
|
||||||
let bound = 1. / (in_dim as f64).sqrt();
|
|
||||||
let init_bs = Init::Uniform {
|
|
||||||
lo: -bound,
|
|
||||||
up: bound,
|
|
||||||
};
|
|
||||||
let bs = vs.get(out_dim, "bias", init_bs)?;
|
|
||||||
Ok(Linear::new(ws, Some(bs)))
|
Ok(Linear::new(ws, Some(bs)))
|
||||||
}
|
}
|
||||||
|
|
||||||
trait Model: Sized {
|
trait Model: Sized {
|
||||||
fn new(vs: VarStore) -> Result<Self>;
|
fn new(vs: VarBuilder) -> Result<Self>;
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,7 +26,7 @@ struct LinearModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Model for LinearModel {
|
impl Model for LinearModel {
|
||||||
fn new(vs: VarStore) -> Result<Self> {
|
fn new(vs: VarBuilder) -> Result<Self> {
|
||||||
let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
|
let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
|
||||||
Ok(Self { linear })
|
Ok(Self { linear })
|
||||||
}
|
}
|
||||||
@ -150,9 +42,9 @@ struct Mlp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Model for Mlp {
|
impl Model for Mlp {
|
||||||
fn new(vs: VarStore) -> Result<Self> {
|
fn new(vs: VarBuilder) -> Result<Self> {
|
||||||
let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
|
let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
|
||||||
let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
|
let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?;
|
||||||
Ok(Self { ln1, ln2 })
|
Ok(Self { ln1, ln2 })
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -180,17 +72,16 @@ fn training_loop<M: Model>(
|
|||||||
let train_images = m.train_images.to_device(&dev)?;
|
let train_images = m.train_images.to_device(&dev)?;
|
||||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||||
|
|
||||||
let mut vs = VarStore::new(DType::F32, dev.clone());
|
let mut varmap = VarMap::new();
|
||||||
|
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
||||||
let model = M::new(vs.clone())?;
|
let model = M::new(vs.clone())?;
|
||||||
|
|
||||||
if let Some(load) = &args.load {
|
if let Some(load) = &args.load {
|
||||||
println!("loading weights from {load}");
|
println!("loading weights from {load}");
|
||||||
vs.load(load)?
|
varmap.load(load)?
|
||||||
}
|
}
|
||||||
|
|
||||||
let all_vars = vs.all_vars();
|
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
|
||||||
let all_vars = all_vars.iter().collect::<Vec<_>>();
|
|
||||||
let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
|
|
||||||
let test_images = m.test_images.to_device(&dev)?;
|
let test_images = m.test_images.to_device(&dev)?;
|
||||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||||
for epoch in 1..args.epochs {
|
for epoch in 1..args.epochs {
|
||||||
@ -215,7 +106,7 @@ fn training_loop<M: Model>(
|
|||||||
}
|
}
|
||||||
if let Some(save) = &args.save {
|
if let Some(save) = &args.save {
|
||||||
println!("saving trained weights in {save}");
|
println!("saving trained weights in {save}");
|
||||||
vs.save(save)?
|
varmap.save(save)?
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -28,3 +28,15 @@ impl Embedding {
|
|||||||
Ok(values)
|
Ok(values)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
|
||||||
|
let embeddings = vb.get_or_init(
|
||||||
|
(in_size, out_size),
|
||||||
|
"weight",
|
||||||
|
crate::Init::Randn {
|
||||||
|
mean: 0.,
|
||||||
|
stdev: 1.,
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
Ok(Embedding::new(embeddings, out_size))
|
||||||
|
}
|
||||||
|
@ -62,3 +62,9 @@ impl LayerNorm {
|
|||||||
Ok(x)
|
Ok(x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
|
||||||
|
let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
|
||||||
|
let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
|
||||||
|
Ok(LayerNorm::new(weight, bias, eps))
|
||||||
|
}
|
||||||
|
@ -15,9 +15,9 @@ pub mod vision;
|
|||||||
|
|
||||||
pub use activation::Activation;
|
pub use activation::Activation;
|
||||||
pub use conv::{Conv1d, Conv1dConfig};
|
pub use conv::{Conv1d, Conv1dConfig};
|
||||||
pub use embedding::Embedding;
|
pub use embedding::{embedding, Embedding};
|
||||||
pub use init::Init;
|
pub use init::Init;
|
||||||
pub use layer_norm::LayerNorm;
|
pub use layer_norm::{layer_norm, LayerNorm};
|
||||||
pub use linear::Linear;
|
pub use linear::{linear, linear_no_bias, Linear};
|
||||||
pub use optim::SGD;
|
pub use optim::SGD;
|
||||||
pub use var_builder::VarBuilder;
|
pub use var_builder::{VarBuilder, VarMap};
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
//! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
|
//! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
|
||||||
//! # Ok(()) }
|
//! # Ok(()) }
|
||||||
//! ```
|
//! ```
|
||||||
use candle::Tensor;
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Linear {
|
pub struct Linear {
|
||||||
@ -42,3 +42,24 @@ impl Linear {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create or initialize a new linear layer.
|
||||||
|
///
|
||||||
|
/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
|
||||||
|
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
|
||||||
|
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||||
|
let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
|
||||||
|
let bound = 1. / (in_dim as f64).sqrt();
|
||||||
|
let init_bs = crate::Init::Uniform {
|
||||||
|
lo: -bound,
|
||||||
|
up: bound,
|
||||||
|
};
|
||||||
|
let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
|
||||||
|
Ok(Linear::new(ws, Some(bs)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
|
||||||
|
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
|
||||||
|
let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
|
||||||
|
Ok(Linear::new(ws, None))
|
||||||
|
}
|
||||||
|
@ -8,7 +8,7 @@ pub struct SGD {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SGD {
|
impl SGD {
|
||||||
pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
|
pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
|
||||||
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
|
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
|
||||||
Self {
|
Self {
|
||||||
vars,
|
vars,
|
||||||
@ -16,6 +16,13 @@ impl SGD {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
|
||||||
|
Self {
|
||||||
|
vars,
|
||||||
|
learning_rate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn empty(learning_rate: f64) -> Self {
|
pub fn empty(learning_rate: f64) -> Self {
|
||||||
Self {
|
Self {
|
||||||
vars: vec![],
|
vars: vec![],
|
||||||
|
@ -1,7 +1,87 @@
|
|||||||
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
|
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor, Var};
|
||||||
use safetensors::{slice::IndexOp, tensor::SafeTensors};
|
use safetensors::{slice::IndexOp, tensor::SafeTensors};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
|
||||||
|
/// and new variables can be added by providing some initialization config in case they are
|
||||||
|
/// missing.
|
||||||
|
/// `VarMap` structures can be serialized in the safetensors format.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct VarMap {
|
||||||
|
data: Arc<Mutex<HashMap<String, Var>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VarMap {
|
||||||
|
/// Create a new empty `VarMap`.
|
||||||
|
#[allow(clippy::new_without_default)]
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let data = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
Self { data }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve all the variables currently stored in the map.
|
||||||
|
pub fn all_vars(&self) -> Vec<Var> {
|
||||||
|
let tensor_data = self.data.lock().unwrap();
|
||||||
|
#[allow(clippy::map_clone)]
|
||||||
|
tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Save the map in the safetensors format.
|
||||||
|
pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
|
||||||
|
let tensor_data = self.data.lock().unwrap();
|
||||||
|
let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
|
||||||
|
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load some values from a safetensors file and modify the existing variables to have these
|
||||||
|
/// values.
|
||||||
|
///
|
||||||
|
/// Note that values for variables that are currently not in the map are not kept.
|
||||||
|
pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
|
||||||
|
let path = path.as_ref();
|
||||||
|
let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
|
||||||
|
let data = data.deserialize()?;
|
||||||
|
let mut tensor_data = self.data.lock().unwrap();
|
||||||
|
for (name, var) in tensor_data.iter_mut() {
|
||||||
|
match data.tensor(name) {
|
||||||
|
Ok(data) => {
|
||||||
|
let data: Tensor = data.load(var.device())?;
|
||||||
|
if let Err(err) = var.set(&data) {
|
||||||
|
candle::bail!("error setting {name} using data from {path:?}: {err}",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => candle::bail!("cannot find tensor for {name}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve or add a new variable.
|
||||||
|
pub fn get<S: Into<Shape>>(
|
||||||
|
&self,
|
||||||
|
shape: S,
|
||||||
|
path: &str,
|
||||||
|
init: crate::Init,
|
||||||
|
dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let shape = shape.into();
|
||||||
|
let mut tensor_data = self.data.lock().unwrap();
|
||||||
|
if let Some(tensor) = tensor_data.get(path) {
|
||||||
|
let tensor_shape = tensor.shape();
|
||||||
|
if &shape != tensor_shape {
|
||||||
|
candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
|
||||||
|
}
|
||||||
|
return Ok(tensor.as_tensor().clone());
|
||||||
|
}
|
||||||
|
let var = init.var(shape, dtype, device)?;
|
||||||
|
let tensor = var.as_tensor().clone();
|
||||||
|
tensor_data.insert(path.to_string(), var);
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
|
// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
|
||||||
// generics.
|
// generics.
|
||||||
@ -13,6 +93,7 @@ enum Tensors<'a> {
|
|||||||
Npz(candle::npy::NpzTensors),
|
Npz(candle::npy::NpzTensors),
|
||||||
TensorMap(HashMap<String, Tensor>),
|
TensorMap(HashMap<String, Tensor>),
|
||||||
Zeros,
|
Zeros,
|
||||||
|
VarMap(VarMap),
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TensorData<'a> {
|
struct TensorData<'a> {
|
||||||
@ -64,6 +145,14 @@ impl<'a> TensorData<'a> {
|
|||||||
dtype,
|
dtype,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
|
||||||
|
Self {
|
||||||
|
tensors: Tensors::VarMap(varmap.clone()),
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -99,6 +188,14 @@ impl<'a> VarBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
|
||||||
|
let data = TensorData::from_varmap(varmap, dtype, device);
|
||||||
|
Self {
|
||||||
|
data: Arc::new(data),
|
||||||
|
path: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn from_npz<P: AsRef<std::path::Path>>(
|
pub fn from_npz<P: AsRef<std::path::Path>>(
|
||||||
file: P,
|
file: P,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
@ -154,11 +251,7 @@ impl<'a> VarBuilder<'a> {
|
|||||||
world_size: usize,
|
world_size: usize,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let data = self.data.as_ref();
|
let data = self.data.as_ref();
|
||||||
let path = if self.path.is_empty() {
|
let path = self.path(tensor_name);
|
||||||
tensor_name.to_string()
|
|
||||||
} else {
|
|
||||||
[&self.path.join("."), tensor_name].join(".")
|
|
||||||
};
|
|
||||||
let tensor = match &self.data.tensors {
|
let tensor = match &self.data.tensors {
|
||||||
Tensors::SafeTensorWithRouting {
|
Tensors::SafeTensorWithRouting {
|
||||||
routing,
|
routing,
|
||||||
@ -205,19 +298,16 @@ impl<'a> VarBuilder<'a> {
|
|||||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||||
Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
|
Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
|
||||||
}
|
}
|
||||||
_ => unimplemented!(),
|
_ => candle::bail!("get_sharded is only available for safetensors"),
|
||||||
};
|
};
|
||||||
Ok(tensor)
|
Ok(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve the tensor associted with the current name and path.
|
||||||
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
|
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
|
||||||
let data = self.data.as_ref();
|
let data = self.data.as_ref();
|
||||||
let s: Shape = s.into();
|
let s: Shape = s.into();
|
||||||
let path = if self.path.is_empty() {
|
let path = self.path(tensor_name);
|
||||||
tensor_name.to_string()
|
|
||||||
} else {
|
|
||||||
[&self.path.join("."), tensor_name].join(".")
|
|
||||||
};
|
|
||||||
let tensor = match &self.data.tensors {
|
let tensor = match &self.data.tensors {
|
||||||
Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
|
Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
|
||||||
Tensors::TensorMap(ts) => ts
|
Tensors::TensorMap(ts) => ts
|
||||||
@ -229,6 +319,18 @@ impl<'a> VarBuilder<'a> {
|
|||||||
.bt()
|
.bt()
|
||||||
})?
|
})?
|
||||||
.clone(),
|
.clone(),
|
||||||
|
Tensors::VarMap(varmap) => {
|
||||||
|
let data = varmap.data.lock().unwrap();
|
||||||
|
data.get(&path)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
Error::CannotFindTensor {
|
||||||
|
path: path.to_string(),
|
||||||
|
}
|
||||||
|
.bt()
|
||||||
|
})?
|
||||||
|
.as_tensor()
|
||||||
|
.clone()
|
||||||
|
}
|
||||||
Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
|
Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
|
||||||
Error::CannotFindTensor {
|
Error::CannotFindTensor {
|
||||||
path: path.to_string(),
|
path: path.to_string(),
|
||||||
@ -261,4 +363,32 @@ impl<'a> VarBuilder<'a> {
|
|||||||
}
|
}
|
||||||
Ok(tensor)
|
Ok(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve the tensor associted with the current name and path or initialize a new tensor if
|
||||||
|
/// it's missing.
|
||||||
|
///
|
||||||
|
/// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`.
|
||||||
|
pub fn get_or_init<S: Into<Shape>>(
|
||||||
|
&self,
|
||||||
|
s: S,
|
||||||
|
tensor_name: &str,
|
||||||
|
init: crate::Init,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let data = self.data.as_ref();
|
||||||
|
match &self.data.tensors {
|
||||||
|
Tensors::VarMap(varmap) => {
|
||||||
|
let path = self.path(tensor_name);
|
||||||
|
varmap.get(s, &path, init, data.dtype, &data.device)
|
||||||
|
}
|
||||||
|
_ => self.get(s, tensor_name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn path(&self, tensor_name: &str) -> String {
|
||||||
|
if self.path.is_empty() {
|
||||||
|
tensor_name.to_string()
|
||||||
|
} else {
|
||||||
|
[&self.path.join("."), tensor_name].join(".")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ use candle_nn::{Linear, SGD};
|
|||||||
#[test]
|
#[test]
|
||||||
fn sgd_optim() -> Result<()> {
|
fn sgd_optim() -> Result<()> {
|
||||||
let x = Var::new(0f32, &Device::Cpu)?;
|
let x = Var::new(0f32, &Device::Cpu)?;
|
||||||
let sgd = SGD::new(&[&x], 0.1);
|
let sgd = SGD::new(vec![x.clone()], 0.1);
|
||||||
let xt = x.as_tensor();
|
let xt = x.as_tensor();
|
||||||
for _step in 0..100 {
|
for _step in 0..100 {
|
||||||
let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
|
let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
|
||||||
@ -54,7 +54,7 @@ fn sgd_linear_regression() -> Result<()> {
|
|||||||
// Now use backprop to run a linear regression between samples and get the coefficients back.
|
// Now use backprop to run a linear regression between samples and get the coefficients back.
|
||||||
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
|
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
|
||||||
let b = Var::new(0f32, &Device::Cpu)?;
|
let b = Var::new(0f32, &Device::Cpu)?;
|
||||||
let sgd = SGD::new(&[&w, &b], 0.004);
|
let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
|
||||||
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
|
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
|
||||||
for _step in 0..1000 {
|
for _step in 0..1000 {
|
||||||
let ys = lin.forward(&sample_xs)?;
|
let ys = lin.forward(&sample_xs)?;
|
||||||
|
Reference in New Issue
Block a user