Add training for the llama2.c example (#296)

* Rework the commands and run inference by default.

* Add the training module and load the training dataset.

* Random dataset iterator (a minimal sketch follows this list).

* Proper valid-loss computation.

* Compute the evaluation loss.

* Add more substance to the training loop.
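For context, the random dataset iterator mentioned above amounts to sampling a start offset into a pre-tokenized corpus and carving out overlapping input/target windows of seq_len tokens. The sketch below is a minimal, hypothetical version in plain Rust; the names TokenDataset, random_sample, and random_batch, and the use of the rand crate, are assumptions for illustration and not the code from this commit.

use rand::Rng;

/// Hypothetical holder for a pre-tokenized training corpus.
struct TokenDataset {
    tokens: Vec<u32>, // the whole corpus, already tokenized
    seq_len: usize,   // e.g. Config::tiny().seq_len (256)
}

impl TokenDataset {
    /// Sample one (input, target) pair at a random offset.
    /// The target is the input shifted by one token, as in standard
    /// next-token-prediction training.
    fn random_sample(&self, rng: &mut impl Rng) -> (Vec<u32>, Vec<u32>) {
        let max_start = self.tokens.len() - self.seq_len - 1;
        let start = rng.gen_range(0..max_start);
        let input = self.tokens[start..start + self.seq_len].to_vec();
        let target = self.tokens[start + 1..start + self.seq_len + 1].to_vec();
        (input, target)
    }

    /// Collect `batch_size` random samples into flat buffers that a
    /// training loop can reshape into (batch_size, seq_len) tensors.
    fn random_batch(&self, batch_size: usize, rng: &mut impl Rng) -> (Vec<u32>, Vec<u32>) {
        let mut inputs = Vec::with_capacity(batch_size * self.seq_len);
        let mut targets = Vec::with_capacity(batch_size * self.seq_len);
        for _ in 0..batch_size {
            let (i, t) = self.random_sample(rng);
            inputs.extend(i);
            targets.extend(t);
        }
        (inputs, targets)
    }
}

In the actual example, batches produced this way would be fed through the model and scored with a cross-entropy loss, both for the training step and for the evaluation-loss pass the commits mention.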
Author: Laurent Mazare
Date: 2023-08-01 17:23:07 +01:00
Committed by: GitHub
Parent: babee9f011
Commit: a27239f3d9
6 changed files with 227 additions and 9 deletions


@@ -15,6 +15,21 @@ pub struct Config {
    pub norm_eps: f64,
}

impl Config {
    pub fn tiny() -> Self {
        Self {
            dim: 288,
            hidden_dim: 768,
            n_layers: 6,
            n_heads: 6,
            n_kv_heads: 6,
            vocab_size: 32000,
            seq_len: 256,
            norm_eps: 1e-5,
        }
    }
}

#[derive(Clone)]
pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,