mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Iteration over strided blocks (#175)
* Introduce the strided blocks. * Use the strided blocks to fasten the copy. * Add more testing.
This commit is contained in:
@ -51,6 +51,8 @@ impl Layout {
|
||||
}
|
||||
|
||||
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
||||
/// Note that this does not implies that the start offset is 0 or that there are no extra
|
||||
/// elements at the end of the storage.
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
self.shape.is_contiguous(&self.stride)
|
||||
}
|
||||
@ -146,6 +148,35 @@ impl Layout {
|
||||
}
|
||||
|
||||
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
||||
crate::StridedIndex::new(self)
|
||||
crate::StridedIndex::from_layout(self)
|
||||
}
|
||||
|
||||
pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
|
||||
let mut block_len = 1;
|
||||
let mut contiguous_dims = 0; // These are counted from the right.
|
||||
for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
|
||||
if stride != block_len {
|
||||
break;
|
||||
}
|
||||
block_len *= dim;
|
||||
contiguous_dims += 1;
|
||||
}
|
||||
let index_dims = self.dims().len() - contiguous_dims;
|
||||
if index_dims == 0 {
|
||||
crate::StridedBlocks::SingleBlock {
|
||||
start_offset: self.start_offset,
|
||||
len: block_len,
|
||||
}
|
||||
} else {
|
||||
let block_start_index = crate::StridedIndex::new(
|
||||
&self.dims()[..index_dims],
|
||||
&self.stride[..index_dims],
|
||||
self.start_offset,
|
||||
);
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user