mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Iteration over strided blocks (#175)
* Introduce the strided blocks. * Use the strided blocks to fasten the copy. * Add more testing.
This commit is contained in:
138
candle-core/tests/layout_tests.rs
Normal file
138
candle-core/tests/layout_tests.rs
Normal file
@ -0,0 +1,138 @@
|
||||
mod test_utils;
|
||||
use candle::{Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn contiguous(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::arange(0u32, 24u32, device)?.reshape((2, 3, 4))?;
|
||||
assert_eq!(
|
||||
tensor.to_vec3::<u32>()?,
|
||||
&[
|
||||
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
|
||||
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.t()?.contiguous()?.to_vec3::<u32>()?,
|
||||
&[
|
||||
[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]],
|
||||
[[12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.transpose(0, 1)?.contiguous()?.to_vec3::<u32>()?,
|
||||
&[
|
||||
[[0, 1, 2, 3], [12, 13, 14, 15]],
|
||||
[[4, 5, 6, 7], [16, 17, 18, 19]],
|
||||
[[8, 9, 10, 11], [20, 21, 22, 23]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.transpose(0, 1)?.flatten_all()?.to_vec1::<u32>()?,
|
||||
&[0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.i(1..)?
|
||||
.transpose(0, 1)?
|
||||
.contiguous()?
|
||||
.to_vec3::<u32>()?,
|
||||
&[[[12, 13, 14, 15]], [[16, 17, 18, 19]], [[20, 21, 22, 23]]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.transpose(0, 2)?.contiguous()?.to_vec3::<u32>()?,
|
||||
&[
|
||||
[[0, 12], [4, 16], [8, 20]],
|
||||
[[1, 13], [5, 17], [9, 21]],
|
||||
[[2, 14], [6, 18], [10, 22]],
|
||||
[[3, 15], [7, 19], [11, 23]]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
|
||||
|
||||
#[test]
|
||||
fn strided_blocks() -> Result<()> {
|
||||
use candle::Device::Cpu;
|
||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||
match tensor.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
assert_eq!(start_offset, 0);
|
||||
assert_eq!(len, 24);
|
||||
}
|
||||
candle::StridedBlocks::MultipleBlocks { .. } => {
|
||||
panic!("unexpected block structure")
|
||||
}
|
||||
};
|
||||
let tensor = Tensor::arange(0u32, 26u32, &Cpu)?
|
||||
.i(2..)?
|
||||
.reshape((2, 3, 4))?;
|
||||
match tensor.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
assert_eq!(start_offset, 2);
|
||||
assert_eq!(len, 24);
|
||||
}
|
||||
candle::StridedBlocks::MultipleBlocks { .. } => {
|
||||
panic!("unexpected block structure")
|
||||
}
|
||||
};
|
||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||
let tensor = tensor.i(1)?;
|
||||
match tensor.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
assert_eq!(start_offset, 12);
|
||||
assert_eq!(len, 12);
|
||||
}
|
||||
candle::StridedBlocks::MultipleBlocks { .. } => {
|
||||
panic!("unexpected block structure")
|
||||
}
|
||||
};
|
||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||
let tensor = tensor.i((.., 1))?;
|
||||
match tensor.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
assert_eq!(start_offset, 0);
|
||||
assert_eq!(len, 8);
|
||||
assert_eq!(tensor.to_vec2::<u32>()?, &[[4, 5, 6, 7], [16, 17, 18, 19]]);
|
||||
}
|
||||
candle::StridedBlocks::MultipleBlocks { .. } => {
|
||||
panic!("unexpected block structure")
|
||||
}
|
||||
};
|
||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||
match tensor.t()?.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { .. } => {
|
||||
panic!("unexpected block structure")
|
||||
}
|
||||
candle::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
assert_eq!(block_len, 1);
|
||||
assert_eq!(
|
||||
block_start_index.collect::<Vec<_>>(),
|
||||
&[
|
||||
0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15,
|
||||
19, 23
|
||||
]
|
||||
)
|
||||
}
|
||||
};
|
||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||
match tensor.transpose(0, 1)?.strided_blocks() {
|
||||
candle::StridedBlocks::SingleBlock { .. } => {
|
||||
panic!("unexpected block structure")
|
||||
}
|
||||
candle::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
assert_eq!(block_len, 4);
|
||||
assert_eq!(
|
||||
block_start_index.collect::<Vec<_>>(),
|
||||
&[0, 12, 4, 16, 8, 20]
|
||||
)
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user