mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Iteration over strided blocks (#175)
* Introduce the strided blocks. * Use the strided blocks to speed up the copy. * Add more testing.
This commit is contained in:
@ -190,13 +190,17 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
||||
dst_offset: usize,
|
||||
src_l: &Layout,
|
||||
) {
|
||||
match src_l.contiguous_offsets() {
|
||||
Some((o_dst1, o_dst2)) => {
|
||||
let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1);
|
||||
dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2])
|
||||
match src_l.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
let to_copy = (dst.len() - dst_offset).min(len);
|
||||
dst[dst_offset..dst_offset + to_copy]
|
||||
.copy_from_slice(&src[start_offset..start_offset + to_copy])
|
||||
}
|
||||
None => {
|
||||
for (dst_index, src_index) in src_l.strided_index().enumerate() {
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len: 1,
|
||||
} => {
|
||||
for (dst_index, src_index) in block_start_index.enumerate() {
|
||||
let dst_index = dst_index + dst_offset;
|
||||
if dst_index >= dst.len() {
|
||||
break;
|
||||
@ -204,6 +208,22 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
||||
dst[dst_index] = src[src_index]
|
||||
}
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
let mut dst_index = dst_offset;
|
||||
for src_index in block_start_index {
|
||||
let next_dst_index = dst_index + block_len;
|
||||
if dst_index >= dst.len() {
|
||||
break;
|
||||
}
|
||||
let to_copy = usize::min(block_len, dst.len() - dst_index);
|
||||
dst[dst_index..dst_index + to_copy]
|
||||
.copy_from_slice(&src[src_index..src_index + to_copy]);
|
||||
dst_index = next_dst_index
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user