mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a basic implementation for slice-assign. (#1377)
This commit is contained in:
@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
|
||||
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slice_assign() -> Result<()> {
|
||||
let dev = Device::Cpu;
|
||||
|
||||
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
|
||||
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
|
||||
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
|
||||
assert_eq!(
|
||||
out.to_vec2::<u32>()?,
|
||||
&[
|
||||
[0, 1, 2, 3, 4],
|
||||
[5, 6, 7, 0, 1],
|
||||
[10, 11, 12, 2, 3],
|
||||
[15, 16, 17, 4, 5]
|
||||
]
|
||||
);
|
||||
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
|
||||
assert_eq!(
|
||||
out.to_vec2::<u32>()?,
|
||||
&[
|
||||
[0, 1, 2, 3, 4],
|
||||
[2, 3, 7, 8, 9],
|
||||
[4, 5, 12, 13, 14],
|
||||
[15, 16, 17, 18, 19]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user