Llama more training (#297)

* Rework the var-builder to handle initializations.

* Add some helper functions for layer creation.

* Improve the layer initializations.

* Get initialized variables.

* Precompute the rot embeddings when training lamas.
This commit is contained in:
Laurent Mazare
2023-08-01 19:53:41 +01:00
committed by GitHub
parent a27239f3d9
commit ff876c2103
10 changed files with 238 additions and 163 deletions

View File

@ -1,5 +1,6 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, Embedding, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -43,8 +44,25 @@ pub struct Cache {
impl Cache {
pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?;
let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?;
let n_elem = cfg.dim / cfg.n_heads;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), vb.device())?;
let idx_theta = Tensor::arange(0, cfg.seq_len as u32, vb.device())?
.to_dtype(DType::F32)?
.reshape((cfg.seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let precomputed_cos = idx_theta.cos()?;
let precomputed_sin = idx_theta.sin()?;
let freq_cis_real = vb
.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")
.unwrap_or(precomputed_cos);
let freq_cis_imag = vb
.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")
.unwrap_or(precomputed_sin);
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
Ok(Self {
@ -76,16 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
Ok(Embedding::new(embeddings, cfg.dim))
}
struct RmsNorm {
scale: Tensor,
eps: f64,
@ -93,7 +101,7 @@ struct RmsNorm {
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let scale = vb.get(size, "weight")?;
let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
Ok(Self { scale, eps })
}
@ -315,7 +323,7 @@ impl Llama {
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = embedding(&cfg, vb.pp("model.embed_tokens"))?;
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)

View File

@ -142,15 +142,15 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
dataset.train_tokens.len(),
dataset.valid_tokens.len()
);
let vb = candle_nn::VarBuilder::zeros(DType::F32, &device);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let all_vars = vec![]; // TODO: Propagate the variables from the VarBuilder to here.
let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0)?;

View File

@ -4,128 +4,20 @@ extern crate intel_mkl_src;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Result, Shape, Tensor, Var, D};
use candle_nn::{loss, ops, Init, Linear};
use std::sync::{Arc, Mutex};
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
struct TensorData {
tensors: std::collections::HashMap<String, Var>,
pub dtype: DType,
pub device: Device,
}
// A variant of candle_nn::VarBuilder for initializing variables before training.
#[derive(Clone)]
struct VarStore {
data: Arc<Mutex<TensorData>>,
path: Vec<String>,
}
impl VarStore {
fn new(dtype: DType, device: Device) -> Self {
let data = TensorData {
tensors: std::collections::HashMap::new(),
dtype,
device,
};
Self {
data: Arc::new(Mutex::new(data)),
path: vec![],
}
}
fn pp(&self, s: &str) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
Self {
data: self.data.clone(),
path,
}
}
fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
let shape = shape.into();
let path = if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
};
let mut tensor_data = self.data.lock().unwrap();
if let Some(tensor) = tensor_data.tensors.get(&path) {
let tensor_shape = tensor.shape();
if &shape != tensor_shape {
candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
}
return Ok(tensor.as_tensor().clone());
}
let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
let tensor = var.as_tensor().clone();
tensor_data.tensors.insert(path, var);
Ok(tensor)
}
fn all_vars(&self) -> Vec<Var> {
let tensor_data = self.data.lock().unwrap();
#[allow(clippy::map_clone)]
tensor_data
.tensors
.values()
.map(|c| c.clone())
.collect::<Vec<_>>()
}
fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
let tensor_data = self.data.lock().unwrap();
let data = tensor_data.tensors.iter().map(|(k, v)| (k, v.as_tensor()));
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
Ok(())
}
fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
use candle::safetensors::Load;
let path = path.as_ref();
let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
let data = data.deserialize()?;
let mut tensor_data = self.data.lock().unwrap();
for (name, var) in tensor_data.tensors.iter_mut() {
match data.tensor(name) {
Ok(data) => {
let data: Tensor = data.load(var.device())?;
if let Err(err) = var.set(&data) {
candle::bail!("error setting {name} using data from {path:?}: {err}",)
}
}
Err(_) => candle::bail!("cannot find tensor for {name}"),
}
}
Ok(())
}
}
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}
fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vs.get(out_dim, "bias", init_bs)?;
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}
trait Model: Sized {
fn new(vs: VarStore) -> Result<Self>;
fn new(vs: VarBuilder) -> Result<Self>;
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
@ -134,7 +26,7 @@ struct LinearModel {
}
impl Model for LinearModel {
fn new(vs: VarStore) -> Result<Self> {
fn new(vs: VarBuilder) -> Result<Self> {
let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
Ok(Self { linear })
}
@ -150,9 +42,9 @@ struct Mlp {
}
impl Model for Mlp {
fn new(vs: VarStore) -> Result<Self> {
let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
fn new(vs: VarBuilder) -> Result<Self> {
let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?;
Ok(Self { ln1, ln2 })
}
@ -180,17 +72,16 @@ fn training_loop<M: Model>(
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let mut vs = VarStore::new(DType::F32, dev.clone());
let mut varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = M::new(vs.clone())?;
if let Some(load) = &args.load {
println!("loading weights from {load}");
vs.load(load)?
varmap.load(load)?
}
let all_vars = vs.all_vars();
let all_vars = all_vars.iter().collect::<Vec<_>>();
let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..args.epochs {
@ -215,7 +106,7 @@ fn training_loop<M: Model>(
}
if let Some(save) = &args.save {
println!("saving trained weights in {save}");
vs.save(save)?
varmap.save(save)?
}
Ok(())
}