diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs index 2ba6adbb..8228e435 100644 --- a/candle-nn/tests/optim.rs +++ b/candle-nn/tests/optim.rs @@ -18,6 +18,30 @@ fn sgd_optim() -> Result<()> { Ok(()) } +/* The results of this test have been checked against the following PyTorch code. + import torch + from torch import optim + + w_gen = torch.tensor([[3., 1.]]) + b_gen = torch.tensor([-2.]) + + sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) + sample_ys = sample_xs.matmul(w_gen.t()) + b_gen + + m = torch.nn.Linear(2, 1) + with torch.no_grad(): + m.weight.zero_() + m.bias.zero_() + optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.) + for _step in range(1000): + optimizer.zero_grad() + ys = m(sample_xs) + loss = ((ys - sample_ys)**2).sum() + loss.backward() + optimizer.step() + print(m.weight) + print(m.bias) +*/ #[test] fn sgd_linear_regression() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2.