mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
More narrow testing.
This commit is contained in:
@ -116,5 +116,25 @@ fn narrow() -> Result<()> {
|
||||
tensor.narrow(2, 1, 2)?.to_vec3::<f32>()?,
|
||||
&[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]],
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.narrow(1, 1, 1)?.to_vec3::<f32>()?,
|
||||
&[[[1.0, 5.0, 9.0]], [[8.0, 2.0, 8.0]]],
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.narrow(0, 0, 1)?.to_vec3::<f32>()?,
|
||||
&[[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]],
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.narrow(0, 1, 1)?.to_vec3::<f32>()?,
|
||||
&[[[2.0, 1.0, 7.0], [8.0, 2.0, 8.0]]],
|
||||
);
|
||||
// The following has been checked against PyTorch via:
|
||||
// import torch
|
||||
// t = torch.tensor([[[3., 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]])
|
||||
// t.transpose(-1, -2).narrow(1, 1, 2)
|
||||
assert_eq!(
|
||||
tensor.t()?.narrow(1, 1, 2)?.to_vec3::<f32>()?,
|
||||
&[[[1.0, 5.0], [4.0, 9.0]], [[1.0, 2.0], [7.0, 8.0]]],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user