Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 19:18:50 +00:00
Add the pytorch version of the linear regression as a comment. (#163)
* Add the pytorch version of the linear regression.
* Typo.
@@ -18,6 +18,30 @@ fn sgd_optim() -> Result<()> {
     Ok(())
 }
 
+/* The results of this test have been checked against the following PyTorch code.
+import torch
+from torch import optim
+
+w_gen = torch.tensor([[3., 1.]])
+b_gen = torch.tensor([-2.])
+
+sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
+sample_ys = sample_xs.matmul(w_gen.t()) + b_gen
+
+m = torch.nn.Linear(2, 1)
+with torch.no_grad():
+    m.weight.zero_()
+    m.bias.zero_()
+optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.)
+for _step in range(1000):
+    optimizer.zero_grad()
+    ys = m(sample_xs)
+    loss = ((ys - sample_ys)**2).sum()
+    loss.backward()
+    optimizer.step()
+print(m.weight)
+print(m.bias)
+*/
 #[test]
 fn sgd_linear_regression() -> Result<()> {
     // Generate some linear data, y = 3.x1 + x2 - 2.
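Since sample_ys is generated exactly as y = 3.x1 + x2 - 2, the SGD loop above should drive m.weight toward [[3., 1.]] and m.bias toward [-2.]. As a quick sanity check (not part of the commit; it assumes torch.linalg.lstsq is available, i.e. PyTorch >= 1.9), the same parameters can be recovered in closed form:

import torch

# Same data as in the comment above: y = 3.x1 + x2 - 2.
w_gen = torch.tensor([[3., 1.]])
b_gen = torch.tensor([-2.])
sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
sample_ys = sample_xs.matmul(w_gen.t()) + b_gen

# Append a column of ones so the bias is solved for alongside the weights.
xs_aug = torch.cat([sample_xs, torch.ones(4, 1)], dim=1)
solution = torch.linalg.lstsq(xs_aug, sample_ys).solution
print(solution)  # approx [[3.], [1.], [-2.]]

Because the objective is a convex least-squares problem, 1000 full-batch steps at lr=0.004 land close to this exact solution, which is presumably what the Rust test compares its recovered coefficients against.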