Llama more training (#297)

* Rework the var-builder to handle initializations.

* Add some helper functions for layer creation (see the first sketch after this list).

* Improve the layer initializations.

* Get initialized variables.

* Precompute the rotary embeddings when training llamas (see the second sketch after this list).
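The sketches below expand on these points. First, how an initialization-aware var-builder and a layer-creation helper might fit together; this is a minimal sketch assuming candle-nn style names (VarMap, VarBuilder::from_varmap, linear, all_vars), not necessarily the exact API introduced by this commit:

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // The var-map records every variable the var-builder creates, so the
    // freshly initialized weights can later be handed to an optimizer.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    // Layer-creation helper: builds the weight and bias variables, with the
    // initialization handled through the var-builder.
    let layer: Linear = linear(8, 2, vb.pp("layer1"))?;
    let xs = Tensor::zeros((1, 8), DType::F32, &device)?;
    let ys = layer.forward(&xs)?;
    println!("{ys}");
    // "Get initialized variables": everything created above, ready to train.
    let _vars = varmap.all_vars();
    Ok(())
}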
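Second, a sketch of precomputing the rotary-embedding cos/sin tables once instead of rebuilding them on every forward pass. The 10000.0 base and the position-by-frequency outer product are the standard RoPE formulation, not necessarily this commit's exact code:

use candle_core::{DType, Device, Result, Tensor};

// head_dim must be even; max_seq_len bounds the positions we cache.
fn precompute_rope(head_dim: usize, max_seq_len: usize, device: &Device) -> Result<(Tensor, Tensor)> {
    // Inverse frequencies for each pair of channels: (head_dim / 2,).
    let theta: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    // Outer product of positions and inverse frequencies: (max_seq_len, head_dim / 2).
    let idx_theta = Tensor::arange(0u32, max_seq_len as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    Ok((idx_theta.cos()?, idx_theta.sin()?))
}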
Author: Laurent Mazare
Date: 2023-08-01 19:53:41 +01:00
Committed by: GitHub
Parent: a27239f3d9
Commit: ff876c2103

10 changed files with 238 additions and 163 deletions
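The diff below, from the optimizer module, renames the slice-based SGD constructor to from_slice and adds an owned Vec<Var> constructor plus an empty one. A minimal usage sketch; the three constructors match the diff, while the import paths (candle_core::Var, candle_nn::SGD) are assumptions:

use candle_core::{DType, Device, Result, Var};
use candle_nn::SGD; // import path assumed

fn build_optimizers() -> Result<()> {
    let w = Var::zeros((2, 4), DType::F32, &Device::Cpu)?;
    let b = Var::zeros(2, DType::F32, &Device::Cpu)?;
    // Renamed entry point: borrow the variables through a slice.
    let _from_refs = SGD::from_slice(&[&w, &b], 0.01);
    // New constructor: hand over an owned Vec<Var>.
    let _from_vec = SGD::new(vec![w, b], 0.01);
    // Or start with no variables at all and register them later.
    let _no_vars = SGD::empty(0.01);
    Ok(())
}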


@@ -8,7 +8,7 @@ pub struct SGD {
 }
 
 impl SGD {
-    pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
+    pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
         let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
         Self {
             vars,
@@ -16,6 +16,13 @@ impl SGD {
         }
     }
 
+    pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
+        Self {
+            vars,
+            learning_rate,
+        }
+    }
+
     pub fn empty(learning_rate: f64) -> Self {
         Self {
             vars: vec![],