mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
More narrow testing.
This commit is contained in:
@ -116,5 +116,25 @@ fn narrow() -> Result<()> {
|
|||||||
tensor.narrow(2, 1, 2)?.to_vec3::<f32>()?,
|
tensor.narrow(2, 1, 2)?.to_vec3::<f32>()?,
|
||||||
&[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]],
|
&[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]],
|
||||||
);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.narrow(1, 1, 1)?.to_vec3::<f32>()?,
|
||||||
|
&[[[1.0, 5.0, 9.0]], [[8.0, 2.0, 8.0]]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.narrow(0, 0, 1)?.to_vec3::<f32>()?,
|
||||||
|
&[[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.narrow(0, 1, 1)?.to_vec3::<f32>()?,
|
||||||
|
&[[[2.0, 1.0, 7.0], [8.0, 2.0, 8.0]]],
|
||||||
|
);
|
||||||
|
// The following has been checked against PyTorch via:
|
||||||
|
// import torch
|
||||||
|
// t = torch.tensor([[[3., 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]])
|
||||||
|
// t.transpose(-1, -2).narrow(1, 1, 2)
|
||||||
|
assert_eq!(
|
||||||
|
tensor.t()?.narrow(1, 1, 2)?.to_vec3::<f32>()?,
|
||||||
|
&[[[1.0, 5.0], [4.0, 9.0]], [[1.0, 2.0], [7.0, 8.0]]],
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user