From eaf760a75101b6aa891451566c78c98941a2a9f8 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 30 Aug 2023 23:32:08 +0200
Subject: [PATCH] Add a python variant for the lstm test. (#682)

---
 candle-nn/tests/rnn.rs | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/candle-nn/tests/rnn.rs b/candle-nn/tests/rnn.rs
index 0f0cca38..eda1a381 100644
--- a/candle-nn/tests/rnn.rs
+++ b/candle-nn/tests/rnn.rs
@@ -7,6 +7,21 @@ extern crate accelerate_src;
 use candle::{test_utils::to_vec2_round, DType, Device, Result, Tensor};
 use candle_nn::RNN;
 
+/* The following test can be verified against PyTorch using the following snippet.
+import torch
+from torch import nn
+lstm = nn.LSTM(2, 3, 1)
+lstm.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 24.).reshape(12, 2).cos())
+lstm.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 36.).reshape(12, 3).sin())
+lstm.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1, 1, -0.5, 2]))
+lstm.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1, 1, -0.5, 2]).cos())
+state = torch.zeros((1, 3)), torch.zeros((1, 3))
+for inp in [3., 1., 4., 1., 5., 9., 2.]:
+    inp = torch.tensor([[inp, inp * 0.5]])
+    _out, state = lstm(inp, state)
+print(state)
+# (tensor([[ 0.9919, 0.1738, -0.1451]], grad_fn=...), tensor([[ 5.7250, 0.4458, -0.2908]], grad_fn=...))
+*/
 #[test]
 fn lstm() -> Result<()> {
     let cpu = &Device::Cpu;
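
Note (not part of the patch): below is a rough Rust sketch of the candle-nn side that the PyTorch snippet above is meant to cross-check. It is a reconstruction under assumptions, not the file's actual contents; the helper names used here (candle_nn::lstm, VarBuilder::from_tensors, the RNN trait's zero_state/step, and LSTMState::h/c) are taken from candle-nn around this revision and should be checked against the real candle-nn/tests/rnn.rs.

// Sketch only: assumed reconstruction of the `lstm` test that the Python
// snippet in the comment is intended to verify. API names are assumptions
// based on candle-nn ~0.2; expected values come from the PyTorch output above.
use candle::{test_utils::to_vec2_round, DType, Device, Result, Tensor};
use candle_nn::RNN;

#[test]
fn lstm() -> Result<()> {
    let cpu = &Device::Cpu;
    // Same deterministic weights as in the PyTorch snippet.
    let w_ih = Tensor::arange(0f32, 24f32, cpu)?.reshape((12, 2))?.cos()?;
    let w_hh = Tensor::arange(0f32, 36f32, cpu)?.reshape((12, 3))?.sin()?;
    let b_ih = Tensor::new(
        &[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1., 1., -0.5, 2.],
        cpu,
    )?;
    let b_hh = b_ih.cos()?;
    // Feed the weights through a VarBuilder using PyTorch's parameter names.
    let tensors: std::collections::HashMap<String, Tensor> = [
        ("weight_ih_l0".to_string(), w_ih),
        ("weight_hh_l0".to_string(), w_hh),
        ("bias_ih_l0".to_string(), b_ih),
        ("bias_hh_l0".to_string(), b_hh),
    ]
    .into_iter()
    .collect();
    let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu);
    let lstm = candle_nn::lstm(2, 3, Default::default(), vb)?;
    // Run the same 7-step input sequence, one step at a time.
    let mut state = lstm.zero_state(1)?;
    for inp in [3f32, 1., 4., 1., 5., 9., 2.] {
        let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
        state = lstm.step(&inp, &state)?;
    }
    // Final (h, c) should match the values printed by the PyTorch snippet.
    assert_eq!(to_vec2_round(state.h(), 4)?, &[[0.9919, 0.1738, -0.1451]]);
    assert_eq!(to_vec2_round(state.c(), 4)?, &[[5.725, 0.4458, -0.2908]]);
    Ok(())
}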