Add a python variant for the lstm test. (#682)

This commit is contained in:
Laurent Mazare
2023-08-30 23:32:08 +02:00
committed by GitHub
parent 1d0bb48fae
commit eaf760a751

View File

@ -7,6 +7,21 @@ extern crate accelerate_src;
use candle::{test_utils::to_vec2_round, DType, Device, Result, Tensor};
use candle_nn::RNN;
/* The following test can be verified against PyTorch using the following snippet.
import torch
from torch import nn
lstm = nn.LSTM(2, 3, 1)
lstm.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 24.).reshape(12, 2).cos())
lstm.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 36.).reshape(12, 3).sin())
lstm.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1, 1, -0.5, 2]))
lstm.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1, 1, -0.5, 2]).cos())
state = torch.zeros((1, 3)), torch.zeros((1, 3))
for inp in [3., 1., 4., 1., 5., 9., 2.]:
inp = torch.tensor([[inp, inp * 0.5]])
_out, state = lstm(inp, state)
print(state)
# (tensor([[ 0.9919, 0.1738, -0.1451]], grad_fn=...), tensor([[ 5.7250, 0.4458, -0.2908]], grad_fn=...))
*/
#[test]
fn lstm() -> Result<()> {
let cpu = &Device::Cpu;