Llama more training (#297)
* Rework the var-builder to handle initializations.
* Add some helper functions for layer creation.
* Improve the layer initializations.
* Get initialized variables.
* Precompute the rot embeddings when training llamas.
@@ -8,7 +8,7 @@ pub struct SGD {
 }
 
 impl SGD {
-    pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
+    pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
         let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
         Self {
             vars,
@@ -16,6 +16,13 @@ impl SGD {
         }
     }
 
+    pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
+        Self {
+            vars,
+            learning_rate,
+        }
+    }
+
     pub fn empty(learning_rate: f64) -> Self {
         Self {
             vars: vec![],
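A minimal usage sketch (not part of the commit) for the three constructors visible in the diff above, assuming the candle_core and candle_nn crate layout at this point in the project's history; the variable shapes and learning rate are arbitrary illustration values.

use candle_core::{DType, Device, Result, Var};
use candle_nn::SGD;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Two trainable variables; shapes chosen only for illustration.
    let w = Var::zeros((4, 4), DType::F32, &device)?;
    let b = Var::zeros(4, DType::F32, &device)?;

    // New constructor: takes ownership of a Vec<Var>.
    let _sgd_owned = SGD::new(vec![w.clone(), b.clone()], 0.01);

    // Renamed constructor: clones each Var out of a slice of references.
    let _sgd_from_refs = SGD::from_slice(&[&w, &b], 0.01);

    // Optimizer that starts without any tracked variables.
    let _sgd_empty = SGD::empty(0.01);

    Ok(())
}

The renamed from_slice keeps the old borrow-and-clone behavior, while the new new takes ownership of an already built Vec<Var>, avoiding an extra clone when the caller no longer needs the variables.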