Mirror of https://github.com/huggingface/candle.git, synced 2025-06-21 12:20:46 +00:00
Llama more training (#297)

* Rework the var-builder to handle initializations.
* Add some helper functions for layer creation.
* Improve the layer initializations.
* Get initialized variables.
* Precompute the rot embeddings when training llamas.
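The last bullet refers to precomputing the rotary ("rot") position-embedding tables once, up front, rather than rebuilding them on every training step. A minimal sketch of such a precomputation with candle tensors is below; the function name, shapes, and the `candle` crate path are assumptions for illustration, not the commit's actual code:

use candle::{DType, Device, Result, Tensor};

// Hypothetical one-time precompute of the rotary-embedding cos/sin tables.
fn precompute_rot(dim: usize, max_seq_len: usize, dev: &Device) -> Result<(Tensor, Tensor)> {
    // Inverse frequencies 1 / 10000^(i/dim) for each even channel index i.
    let theta: Vec<f32> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), dev)?;
    // Outer product of positions [0, max_seq_len) with the frequencies.
    let idx_theta = Tensor::arange(0, max_seq_len as u32, dev)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    Ok((idx_theta.cos()?, idx_theta.sin()?))
}

The diff itself touches the SGD tests, adapting them to the optimizer's new constructor, which takes owned `Var`s instead of a slice of borrows: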
@@ -8,7 +8,7 @@ use candle_nn::{Linear, SGD};
 #[test]
 fn sgd_optim() -> Result<()> {
     let x = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(&[&x], 0.1);
+    let sgd = SGD::new(vec![x.clone()], 0.1);
     let xt = x.as_tensor();
     for _step in 0..100 {
         let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
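For context, here is a sketch of this first test completed end to end under the new API. The `backward_step` call, the assertion, and the `candle` core-crate imports are assumptions reconstructed from the surrounding tests, not part of the diff:

use candle::{Device, Result, Var};
use candle_nn::SGD;

#[test]
fn sgd_optim() -> Result<()> {
    let x = Var::new(0f32, &Device::Cpu)?;
    // New API: the optimizer owns its vars instead of borrowing a slice.
    let sgd = SGD::new(vec![x.clone()], 0.1);
    let xt = x.as_tensor();
    for _step in 0..100 {
        // Quadratic loss (x - 4.2)^2, minimized at x = 4.2.
        let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
        // Assumed: backward_step computes gradients and applies the update.
        sgd.backward_step(&loss)?;
    }
    // After 100 steps x should sit very close to 4.2.
    assert!((x.as_tensor().to_scalar::<f32>()? - 4.2).abs() < 1e-3);
    Ok(())
}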
@@ -54,7 +54,7 @@ fn sgd_linear_regression() -> Result<()> {
     // Now use backprop to run a linear regression between samples and get the coefficients back.
     let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
     let b = Var::new(0f32, &Device::Cpu)?;
-    let sgd = SGD::new(&[&w, &b], 0.004);
+    let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
     let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
     for _step in 0..1000 {
         let ys = lin.forward(&sample_xs)?;
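And a sketch of this second test completed end to end with the new constructor. The data-generation lines, the loss, and the `backward_step` call are assumptions reconstructed from context; only the middle lines appear in the diff:

use candle::{Device, Result, Tensor, Var};
use candle_nn::{Linear, SGD};

#[test]
fn sgd_linear_regression() -> Result<()> {
    // Assumed setup: generate targets from a known line y = 3.x1 + x2 - 2.
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
    let gen = Linear::new(w_gen, Some(b_gen));
    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let sample_ys = gen.forward(&sample_xs)?;

    // Now use backprop to recover the coefficients, with the new SGD API.
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
    let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
    let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
    for _step in 0..1000 {
        let ys = lin.forward(&sample_xs)?;
        // Sum-of-squares loss between predictions and targets.
        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
        sgd.backward_step(&loss)?;
    }
    // w and b should now be close to the generating coefficients [3, 1] and -2.
    Ok(())
}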