Llama more training (#297)

* Rework the var-builder to handle initializations.

* Add some helper functions for layer creation.

* Improve the layer initializations.

* Get initialized variables.

* Precompute the rot embeddings when training lamas.
This commit is contained in:
Laurent Mazare
2023-08-01 19:53:41 +01:00
committed by GitHub
parent a27239f3d9
commit ff876c2103
10 changed files with 238 additions and 163 deletions

View File

@ -28,3 +28,15 @@ impl Embedding {
Ok(values)
}
}
pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
let embeddings = vb.get_or_init(
(in_size, out_size),
"weight",
crate::Init::Randn {
mean: 0.,
stdev: 1.,
},
)?;
Ok(Embedding::new(embeddings, out_size))
}

View File

@ -62,3 +62,9 @@ impl LayerNorm {
Ok(x)
}
}
pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
Ok(LayerNorm::new(weight, bias, eps))
}

View File

@ -15,9 +15,9 @@ pub mod vision;
pub use activation::Activation;
pub use conv::{Conv1d, Conv1dConfig};
pub use embedding::Embedding;
pub use embedding::{embedding, Embedding};
pub use init::Init;
pub use layer_norm::LayerNorm;
pub use linear::Linear;
pub use layer_norm::{layer_norm, LayerNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use optim::SGD;
pub use var_builder::VarBuilder;
pub use var_builder::{VarBuilder, VarMap};

View File

@ -17,7 +17,7 @@
//! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
//! # Ok(()) }
//! ```
use candle::Tensor;
use candle::{Result, Tensor};
#[derive(Debug)]
pub struct Linear {
@ -42,3 +42,24 @@ impl Linear {
}
}
}
/// Create or initialize a new linear layer.
///
/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
Ok(Linear::new(ws, None))
}

View File

@ -8,7 +8,7 @@ pub struct SGD {
}
impl SGD {
pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
Self {
vars,
@ -16,6 +16,13 @@ impl SGD {
}
}
pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
Self {
vars,
learning_rate,
}
}
pub fn empty(learning_rate: f64) -> Self {
Self {
vars: vec![],

View File

@ -1,7 +1,87 @@
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor, Var};
use safetensors::{slice::IndexOp, tensor::SafeTensors};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
/// and new variables can be added by providing some initialization config in case they are
/// missing.
/// `VarMap` structures can be serialized in the safetensors format.
#[derive(Clone)]
pub struct VarMap {
data: Arc<Mutex<HashMap<String, Var>>>,
}
impl VarMap {
/// Create a new empty `VarMap`.
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let data = Arc::new(Mutex::new(HashMap::new()));
Self { data }
}
/// Retrieve all the variables currently stored in the map.
pub fn all_vars(&self) -> Vec<Var> {
let tensor_data = self.data.lock().unwrap();
#[allow(clippy::map_clone)]
tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
}
/// Save the map in the safetensors format.
pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
let tensor_data = self.data.lock().unwrap();
let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
Ok(())
}
/// Load some values from a safetensors file and modify the existing variables to have these
/// values.
///
/// Note that values for variables that are currently not in the map are not kept.
pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
let path = path.as_ref();
let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
let data = data.deserialize()?;
let mut tensor_data = self.data.lock().unwrap();
for (name, var) in tensor_data.iter_mut() {
match data.tensor(name) {
Ok(data) => {
let data: Tensor = data.load(var.device())?;
if let Err(err) = var.set(&data) {
candle::bail!("error setting {name} using data from {path:?}: {err}",)
}
}
Err(_) => candle::bail!("cannot find tensor for {name}"),
}
}
Ok(())
}
/// Retrieve or add a new variable.
pub fn get<S: Into<Shape>>(
&self,
shape: S,
path: &str,
init: crate::Init,
dtype: DType,
device: &Device,
) -> Result<Tensor> {
let shape = shape.into();
let mut tensor_data = self.data.lock().unwrap();
if let Some(tensor) = tensor_data.get(path) {
let tensor_shape = tensor.shape();
if &shape != tensor_shape {
candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
}
return Ok(tensor.as_tensor().clone());
}
let var = init.var(shape, dtype, device)?;
let tensor = var.as_tensor().clone();
tensor_data.insert(path.to_string(), var);
Ok(tensor)
}
}
// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
// generics.
@ -13,6 +93,7 @@ enum Tensors<'a> {
Npz(candle::npy::NpzTensors),
TensorMap(HashMap<String, Tensor>),
Zeros,
VarMap(VarMap),
}
struct TensorData<'a> {
@ -64,6 +145,14 @@ impl<'a> TensorData<'a> {
dtype,
})
}
fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
Self {
tensors: Tensors::VarMap(varmap.clone()),
device: device.clone(),
dtype,
}
}
}
#[derive(Clone)]
@ -99,6 +188,14 @@ impl<'a> VarBuilder<'a> {
}
}
pub fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
let data = TensorData::from_varmap(varmap, dtype, device);
Self {
data: Arc::new(data),
path: vec![],
}
}
pub fn from_npz<P: AsRef<std::path::Path>>(
file: P,
dtype: DType,
@ -154,11 +251,7 @@ impl<'a> VarBuilder<'a> {
world_size: usize,
) -> Result<Tensor> {
let data = self.data.as_ref();
let path = if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
};
let path = self.path(tensor_name);
let tensor = match &self.data.tensors {
Tensors::SafeTensorWithRouting {
routing,
@ -205,19 +298,16 @@ impl<'a> VarBuilder<'a> {
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
}
_ => unimplemented!(),
_ => candle::bail!("get_sharded is only available for safetensors"),
};
Ok(tensor)
}
/// Retrieve the tensor associted with the current name and path.
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
let data = self.data.as_ref();
let s: Shape = s.into();
let path = if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
};
let path = self.path(tensor_name);
let tensor = match &self.data.tensors {
Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
Tensors::TensorMap(ts) => ts
@ -229,6 +319,18 @@ impl<'a> VarBuilder<'a> {
.bt()
})?
.clone(),
Tensors::VarMap(varmap) => {
let data = varmap.data.lock().unwrap();
data.get(&path)
.ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?
.as_tensor()
.clone()
}
Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
@ -261,4 +363,32 @@ impl<'a> VarBuilder<'a> {
}
Ok(tensor)
}
/// Retrieve the tensor associted with the current name and path or initialize a new tensor if
/// it's missing.
///
/// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`.
pub fn get_or_init<S: Into<Shape>>(
&self,
s: S,
tensor_name: &str,
init: crate::Init,
) -> Result<Tensor> {
let data = self.data.as_ref();
match &self.data.tensors {
Tensors::VarMap(varmap) => {
let path = self.path(tensor_name);
varmap.get(s, &path, init, data.dtype, &data.device)
}
_ => self.get(s, tensor_name),
}
}
fn path(&self, tensor_name: &str) -> String {
if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
}
}
}

View File

@ -8,7 +8,7 @@ use candle_nn::{Linear, SGD};
#[test]
fn sgd_optim() -> Result<()> {
let x = Var::new(0f32, &Device::Cpu)?;
let sgd = SGD::new(&[&x], 0.1);
let sgd = SGD::new(vec![x.clone()], 0.1);
let xt = x.as_tensor();
for _step in 0..100 {
let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
@ -54,7 +54,7 @@ fn sgd_linear_regression() -> Result<()> {
// Now use backprop to run a linear regression between samples and get the coefficients back.
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?;
let sgd = SGD::new(&[&w, &b], 0.004);
let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
for _step in 0..1000 {
let ys = lin.forward(&sample_xs)?;