Add the pytorch version of the linear regression as a comment. (#163)

* Add the pytorch version of the linear regression.

* Typo.
This commit is contained in:
Laurent Mazare
2023-07-13 21:05:57 +01:00
committed by GitHub
parent 23e105cd94
commit 57be3638d8

View File

@ -18,6 +18,30 @@ fn sgd_optim() -> Result<()> {
Ok(())
}
/* The results of this test have been checked against the following PyTorch code.
import torch
from torch import optim
w_gen = torch.tensor([[3., 1.]])
b_gen = torch.tensor([-2.])
sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
m = torch.nn.Linear(2, 1)
with torch.no_grad():
m.weight.zero_()
m.bias.zero_()
optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.)
for _step in range(1000):
optimizer.zero_grad()
ys = m(sample_xs)
loss = ((ys - sample_ys)**2).sum()
loss.backward()
optimizer.step()
print(m.weight)
print(m.bias)
*/
#[test]
fn sgd_linear_regression() -> Result<()> {
// Generate some linear data, y = 3.x1 + x2 - 2.