Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Add training for the llama2.c example (#296)
* Rework the commands and run inference by default.
* Add the training module and load the training dataset.
* Random dataset iterator (see the sketch after this list).
* Proper valid-loss computation.
* Compute the evaluation loss.
* Add more substance to the training loop.
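The random dataset iterator itself is not part of the hunk shown below, so here is a minimal, self-contained sketch of the idea, assuming the training data is a single flat array of token ids (as in llama2.c's pretokenized TinyStories data). The type `RandomBatchIter` and its fields are hypothetical names, not the ones used in the commit.

```rust
use rand::Rng;

/// Endlessly yields random (input, target) windows of `seq_len` tokens,
/// where the target is the input shifted one token ahead.
struct RandomBatchIter<'a> {
    tokens: &'a [u32],
    seq_len: usize,
    rng: rand::rngs::ThreadRng,
}

impl<'a> Iterator for RandomBatchIter<'a> {
    type Item = (&'a [u32], &'a [u32]);

    fn next(&mut self) -> Option<Self::Item> {
        // Pick a start index so that both the window and its shifted target fit.
        let max_start = self.tokens.len().checked_sub(self.seq_len + 1)?;
        let start = self.rng.gen_range(0..=max_start);
        let input = &self.tokens[start..start + self.seq_len];
        let target = &self.tokens[start + 1..start + 1 + self.seq_len];
        Some((input, target))
    }
}

fn main() {
    let tokens: Vec<u32> = (0..10_000u32).map(|i| i % 32_000).collect();
    let mut batches = RandomBatchIter {
        tokens: &tokens,
        seq_len: 256,
        rng: rand::thread_rng(),
    };
    let (x, y) = batches.next().unwrap();
    // The target sequence is the input shifted left by one token.
    assert_eq!(&x[1..], &y[..y.len() - 1]);
}
```

Sampling windows uniformly at random, rather than scanning the corpus sequentially, keeps successive batches decorrelated, which is the usual motivation for an iterator like this.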
```diff
@@ -15,6 +15,21 @@ pub struct Config {
     pub norm_eps: f64,
 }
 
+impl Config {
+    pub fn tiny() -> Self {
+        Self {
+            dim: 288,
+            hidden_dim: 768,
+            n_layers: 6,
+            n_heads: 6,
+            n_kv_heads: 6,
+            vocab_size: 32000,
+            seq_len: 256,
+            norm_eps: 1e-5,
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
```
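The `tiny` hyperparameters (dim 288, 6 layers, 6 heads, 256-token context, 32000-token vocabulary) match those of the smallest llama2.c "stories" checkpoint, which is the scale this training example targets.

For the valid/eval-loss computation mentioned in the commit message, here is a minimal sketch under stated assumptions: the model's forward pass is abstracted as a closure (the real example's model API differs), and `candle_nn::loss::cross_entropy` is used, which expects `(N, vocab_size)` logits against `N` class indices.

```rust
use candle_core::{Result, Tensor};

/// Average cross-entropy over a set of held-out (input, target) batches.
/// `forward` stands in for the model's forward pass; a hypothetical signature.
fn evaluation_loss(
    forward: impl Fn(&Tensor) -> Result<Tensor>,
    batches: &[(Tensor, Tensor)],
) -> Result<f64> {
    let mut total = 0f64;
    for (inputs, targets) in batches {
        // Logits come out as (batch, seq_len, vocab_size); flatten to
        // (batch * seq_len, vocab_size) for the cross-entropy helper.
        let logits = forward(inputs)?;
        let (b, s, v) = logits.dims3()?;
        let logits = logits.reshape((b * s, v))?;
        let targets = targets.reshape((b * s,))?;
        let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
        total += loss.to_scalar::<f32>()? as f64;
    }
    Ok(total / batches.len() as f64)
}
```

Running the same function over training batches (without gradient updates) versus validation batches gives the train/valid loss pair the commit message refers to.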