Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Add a python variant for the lstm test. (#682)
@@ -7,6 +7,21 @@ extern crate accelerate_src;
use candle::{test_utils::to_vec2_round, DType, Device, Result, Tensor};
use candle_nn::RNN;

/* The following test can be verified against PyTorch using the following snippet.
import torch
from torch import nn
lstm = nn.LSTM(2, 3, 1)
lstm.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 24.).reshape(12, 2).cos())
lstm.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 36.).reshape(12, 3).sin())
lstm.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1, 1, -0.5, 2]))
lstm.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1, 1, -0.5, 2]).cos())
state = torch.zeros((1, 3)), torch.zeros((1, 3))
for inp in [3., 1., 4., 1., 5., 9., 2.]:
    inp = torch.tensor([[inp, inp * 0.5]])
    _out, state = lstm(inp, state)
print(state)
# (tensor([[ 0.9919, 0.1738, -0.1451]], grad_fn=...), tensor([[ 5.7250, 0.4458, -0.2908]], grad_fn=...))
*/
#[test]
fn lstm() -> Result<()> {
    let cpu = &Device::Cpu;
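For reference, below is a minimal stand-alone sketch of how the Rust side of this verification might look, assuming candle_nn's lstm constructor, VarBuilder::from_tensors, and the RNN trait's zero_state/step methods, fed the same deterministic weights as the PyTorch snippet. The function name lstm_check, the parameter names weight_ih_l0/weight_hh_l0/bias_ih_l0/bias_hh_l0, and the rounding to 4 digits are illustrative assumptions, not code taken from this commit's diff.

// Hypothetical sketch; not the exact test body from the commit.
use std::collections::HashMap;

use candle::{test_utils::to_vec2_round, DType, Device, Result, Tensor};
use candle_nn::RNN;

fn lstm_check() -> Result<()> {
    let cpu = &Device::Cpu;
    // Same deterministic weights as the PyTorch snippet above.
    let w_ih = Tensor::arange(0f32, 24f32, cpu)?.reshape((12, 2))?.cos()?;
    let w_hh = Tensor::arange(0f32, 36f32, cpu)?.reshape((12, 3))?.sin()?;
    let b_ih = Tensor::new(
        &[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1., 1., -0.5, 2.],
        cpu,
    )?;
    let b_hh = b_ih.cos()?;
    // Assumed parameter names; load them through a tensor-backed VarBuilder.
    let tensors: HashMap<String, Tensor> = [
        ("weight_ih_l0".to_string(), w_ih),
        ("weight_hh_l0".to_string(), w_hh),
        ("bias_ih_l0".to_string(), b_ih),
        ("bias_hh_l0".to_string(), b_hh),
    ]
    .into_iter()
    .collect();
    let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu);
    // 2 input features, 3 hidden units, mirroring nn.LSTM(2, 3, 1).
    let lstm = candle_nn::lstm(2, 3, Default::default(), vb)?;
    let mut state = lstm.zero_state(1)?;
    // Feed the same sequence of inputs, one step at a time.
    for inp in [3f32, 1., 4., 1., 5., 9., 2.] {
        let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
        state = lstm.step(&inp, &state)?;
    }
    // Expected values copied from the PyTorch output shown above.
    assert_eq!(to_vec2_round(state.h(), 4)?, &[[0.9919, 0.1738, -0.1451]]);
    assert_eq!(to_vec2_round(state.c(), 4)?, &[[5.725, 0.4458, -0.2908]]);
    Ok(())
}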