mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add more gradient tests + bugfixes. (#211)
* Add more gradient tests + bugfixes. * More tests and fixes. * More tests.
This commit is contained in:
@ -40,7 +40,7 @@ pub fn main() -> Result<()> {
|
||||
let train_label_mask = Tensor::from_vec(train_label_mask, (train_labels.len(), LABELS), &dev)?;
|
||||
let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
|
||||
let bs = Var::zeros(LABELS, DType::F32, &dev)?;
|
||||
let sgd = candle_nn::SGD::new(&[&ws, &bs], 3e-1);
|
||||
let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
|
||||
let test_images = m.test_images;
|
||||
let test_labels = m.test_labels.to_vec1::<u8>()?;
|
||||
for epoch in 1..200 {
|
||||
|
Reference in New Issue
Block a user