Add more gradient tests + bugfixes. (#211)

* Add more gradient tests + bugfixes.

* More tests and fixes.

* More tests.
This commit is contained in:
Laurent Mazare
2023-07-21 07:52:39 +02:00
committed by GitHub
parent 4845d5cc64
commit c60831aad4
3 changed files with 60 additions and 4 deletions

View File

@ -40,7 +40,7 @@ pub fn main() -> Result<()> {
let train_label_mask = Tensor::from_vec(train_label_mask, (train_labels.len(), LABELS), &dev)?;
let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
let bs = Var::zeros(LABELS, DType::F32, &dev)?;
let sgd = candle_nn::SGD::new(&[&ws, &bs], 3e-1);
let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
let test_images = m.test_images;
let test_labels = m.test_labels.to_vec1::<u8>()?;
for epoch in 1..200 {