mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Iteration over strided blocks (#175)
* Introduce the strided blocks. * Use the strided blocks to fasten the copy. * Add more testing.
This commit is contained in:
@ -190,13 +190,17 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
|||||||
dst_offset: usize,
|
dst_offset: usize,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
) {
|
) {
|
||||||
match src_l.contiguous_offsets() {
|
match src_l.strided_blocks() {
|
||||||
Some((o_dst1, o_dst2)) => {
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1);
|
let to_copy = (dst.len() - dst_offset).min(len);
|
||||||
dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2])
|
dst[dst_offset..dst_offset + to_copy]
|
||||||
|
.copy_from_slice(&src[start_offset..start_offset + to_copy])
|
||||||
}
|
}
|
||||||
None => {
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
for (dst_index, src_index) in src_l.strided_index().enumerate() {
|
block_start_index,
|
||||||
|
block_len: 1,
|
||||||
|
} => {
|
||||||
|
for (dst_index, src_index) in block_start_index.enumerate() {
|
||||||
let dst_index = dst_index + dst_offset;
|
let dst_index = dst_index + dst_offset;
|
||||||
if dst_index >= dst.len() {
|
if dst_index >= dst.len() {
|
||||||
break;
|
break;
|
||||||
@ -204,6 +208,22 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
|||||||
dst[dst_index] = src[src_index]
|
dst[dst_index] = src[src_index]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
} => {
|
||||||
|
let mut dst_index = dst_offset;
|
||||||
|
for src_index in block_start_index {
|
||||||
|
let next_dst_index = dst_index + block_len;
|
||||||
|
if dst_index >= dst.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let to_copy = usize::min(block_len, dst.len() - dst_index);
|
||||||
|
dst[dst_index..dst_index + to_copy]
|
||||||
|
.copy_from_slice(&src[src_index..src_index + to_copy]);
|
||||||
|
dst_index = next_dst_index
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,6 +51,8 @@ impl Layout {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
||||||
|
/// Note that this does not implies that the start offset is 0 or that there are no extra
|
||||||
|
/// elements at the end of the storage.
|
||||||
pub fn is_contiguous(&self) -> bool {
|
pub fn is_contiguous(&self) -> bool {
|
||||||
self.shape.is_contiguous(&self.stride)
|
self.shape.is_contiguous(&self.stride)
|
||||||
}
|
}
|
||||||
@ -146,6 +148,35 @@ impl Layout {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
||||||
crate::StridedIndex::new(self)
|
crate::StridedIndex::from_layout(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
|
||||||
|
let mut block_len = 1;
|
||||||
|
let mut contiguous_dims = 0; // These are counted from the right.
|
||||||
|
for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
|
||||||
|
if stride != block_len {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
block_len *= dim;
|
||||||
|
contiguous_dims += 1;
|
||||||
|
}
|
||||||
|
let index_dims = self.dims().len() - contiguous_dims;
|
||||||
|
if index_dims == 0 {
|
||||||
|
crate::StridedBlocks::SingleBlock {
|
||||||
|
start_offset: self.start_offset,
|
||||||
|
len: block_len,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let block_start_index = crate::StridedIndex::new(
|
||||||
|
&self.dims()[..index_dims],
|
||||||
|
&self.stride[..index_dims],
|
||||||
|
self.start_offset,
|
||||||
|
);
|
||||||
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ pub use indexer::IndexOp;
|
|||||||
pub use layout::Layout;
|
pub use layout::Layout;
|
||||||
pub use shape::{Shape, D};
|
pub use shape::{Shape, D};
|
||||||
pub use storage::Storage;
|
pub use storage::Storage;
|
||||||
use strided_index::StridedIndex;
|
pub use strided_index::{StridedBlocks, StridedIndex};
|
||||||
pub use tensor::{Tensor, TensorId};
|
pub use tensor::{Tensor, TensorId};
|
||||||
pub use variable::Var;
|
pub use variable::Var;
|
||||||
|
|
||||||
|
@ -3,28 +3,35 @@ use crate::Layout;
|
|||||||
/// An iterator over offset position for items of an N-dimensional arrays stored in a
|
/// An iterator over offset position for items of an N-dimensional arrays stored in a
|
||||||
/// flat buffer using some potential strides.
|
/// flat buffer using some potential strides.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct StridedIndex<'a> {
|
pub struct StridedIndex<'a> {
|
||||||
next_storage_index: Option<usize>,
|
next_storage_index: Option<usize>,
|
||||||
multi_index: Vec<usize>,
|
multi_index: Vec<usize>,
|
||||||
layout: &'a Layout,
|
dims: &'a [usize],
|
||||||
|
stride: &'a [usize],
|
||||||
|
start_offset: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> StridedIndex<'a> {
|
impl<'a> StridedIndex<'a> {
|
||||||
pub(crate) fn new(layout: &'a Layout) -> Self {
|
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
|
||||||
let dims = layout.dims();
|
|
||||||
let elem_count: usize = dims.iter().product();
|
let elem_count: usize = dims.iter().product();
|
||||||
let next_storage_index = if elem_count == 0 {
|
let next_storage_index = if elem_count == 0 {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
// This applies to the scalar case.
|
// This applies to the scalar case.
|
||||||
Some(layout.start_offset())
|
Some(start_offset)
|
||||||
};
|
};
|
||||||
StridedIndex {
|
StridedIndex {
|
||||||
next_storage_index,
|
next_storage_index,
|
||||||
multi_index: vec![0; dims.len()],
|
multi_index: vec![0; dims.len()],
|
||||||
layout,
|
dims,
|
||||||
|
stride,
|
||||||
|
start_offset,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn from_layout(l: &'a Layout) -> Self {
|
||||||
|
Self::new(l.dims(), l.stride(), l.start_offset())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Iterator for StridedIndex<'a> {
|
impl<'a> Iterator for StridedIndex<'a> {
|
||||||
@ -36,12 +43,7 @@ impl<'a> Iterator for StridedIndex<'a> {
|
|||||||
Some(storage_index) => storage_index,
|
Some(storage_index) => storage_index,
|
||||||
};
|
};
|
||||||
let mut updated = false;
|
let mut updated = false;
|
||||||
for (multi_i, max_i) in self
|
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
|
||||||
.multi_index
|
|
||||||
.iter_mut()
|
|
||||||
.zip(self.layout.dims().iter())
|
|
||||||
.rev()
|
|
||||||
{
|
|
||||||
let next_i = *multi_i + 1;
|
let next_i = *multi_i + 1;
|
||||||
if next_i < *max_i {
|
if next_i < *max_i {
|
||||||
*multi_i = next_i;
|
*multi_i = next_i;
|
||||||
@ -55,10 +57,10 @@ impl<'a> Iterator for StridedIndex<'a> {
|
|||||||
let next_storage_index = self
|
let next_storage_index = self
|
||||||
.multi_index
|
.multi_index
|
||||||
.iter()
|
.iter()
|
||||||
.zip(self.layout.stride().iter())
|
.zip(self.stride.iter())
|
||||||
.map(|(&x, &y)| x * y)
|
.map(|(&x, &y)| x * y)
|
||||||
.sum::<usize>()
|
.sum::<usize>()
|
||||||
+ self.layout.start_offset();
|
+ self.start_offset;
|
||||||
Some(next_storage_index)
|
Some(next_storage_index)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -66,3 +68,15 @@ impl<'a> Iterator for StridedIndex<'a> {
|
|||||||
Some(storage_index)
|
Some(storage_index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum StridedBlocks<'a> {
|
||||||
|
SingleBlock {
|
||||||
|
start_offset: usize,
|
||||||
|
len: usize,
|
||||||
|
},
|
||||||
|
MultipleBlocks {
|
||||||
|
block_start_index: StridedIndex<'a>,
|
||||||
|
block_len: usize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
@ -834,10 +834,20 @@ impl Tensor {
|
|||||||
Ok(from_storage(storage, shape, op, false))
|
Ok(from_storage(storage, shape, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
/// Returns an iterator over position of the elements in the storage when ranging over the
|
||||||
|
/// index tuples in lexicographic order.
|
||||||
|
pub fn strided_index(&self) -> crate::StridedIndex {
|
||||||
self.layout.strided_index()
|
self.layout.strided_index()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Similar to `strided_index` but returns the position of the start of each contiguous block
|
||||||
|
/// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator
|
||||||
|
/// will only return the start offset and the size would be the number of elements in the
|
||||||
|
/// tensor.
|
||||||
|
pub fn strided_blocks(&self) -> crate::StridedBlocks {
|
||||||
|
self.layout.strided_blocks()
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the data contained in a 1D tensor as a vector of scalar values.
|
/// Returns the data contained in a 1D tensor as a vector of scalar values.
|
||||||
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
||||||
if self.rank() != 1 {
|
if self.rank() != 1 {
|
||||||
|
138
candle-core/tests/layout_tests.rs
Normal file
138
candle-core/tests/layout_tests.rs
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
mod test_utils;
|
||||||
|
use candle::{Device, IndexOp, Result, Tensor};
|
||||||
|
|
||||||
|
fn contiguous(device: &Device) -> Result<()> {
|
||||||
|
let tensor = Tensor::arange(0u32, 24u32, device)?.reshape((2, 3, 4))?;
|
||||||
|
assert_eq!(
|
||||||
|
tensor.to_vec3::<u32>()?,
|
||||||
|
&[
|
||||||
|
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
|
||||||
|
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.t()?.contiguous()?.to_vec3::<u32>()?,
|
||||||
|
&[
|
||||||
|
[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]],
|
||||||
|
[[12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23]]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.transpose(0, 1)?.contiguous()?.to_vec3::<u32>()?,
|
||||||
|
&[
|
||||||
|
[[0, 1, 2, 3], [12, 13, 14, 15]],
|
||||||
|
[[4, 5, 6, 7], [16, 17, 18, 19]],
|
||||||
|
[[8, 9, 10, 11], [20, 21, 22, 23]]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.transpose(0, 1)?.flatten_all()?.to_vec1::<u32>()?,
|
||||||
|
&[0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor
|
||||||
|
.i(1..)?
|
||||||
|
.transpose(0, 1)?
|
||||||
|
.contiguous()?
|
||||||
|
.to_vec3::<u32>()?,
|
||||||
|
&[[[12, 13, 14, 15]], [[16, 17, 18, 19]], [[20, 21, 22, 23]]]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tensor.transpose(0, 2)?.contiguous()?.to_vec3::<u32>()?,
|
||||||
|
&[
|
||||||
|
[[0, 12], [4, 16], [8, 20]],
|
||||||
|
[[1, 13], [5, 17], [9, 21]],
|
||||||
|
[[2, 14], [6, 18], [10, 22]],
|
||||||
|
[[3, 15], [7, 19], [11, 23]]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn strided_blocks() -> Result<()> {
|
||||||
|
use candle::Device::Cpu;
|
||||||
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
|
match tensor.strided_blocks() {
|
||||||
|
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
|
assert_eq!(start_offset, 0);
|
||||||
|
assert_eq!(len, 24);
|
||||||
|
}
|
||||||
|
candle::StridedBlocks::MultipleBlocks { .. } => {
|
||||||
|
panic!("unexpected block structure")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tensor = Tensor::arange(0u32, 26u32, &Cpu)?
|
||||||
|
.i(2..)?
|
||||||
|
.reshape((2, 3, 4))?;
|
||||||
|
match tensor.strided_blocks() {
|
||||||
|
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
|
assert_eq!(start_offset, 2);
|
||||||
|
assert_eq!(len, 24);
|
||||||
|
}
|
||||||
|
candle::StridedBlocks::MultipleBlocks { .. } => {
|
||||||
|
panic!("unexpected block structure")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
|
let tensor = tensor.i(1)?;
|
||||||
|
match tensor.strided_blocks() {
|
||||||
|
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
|
assert_eq!(start_offset, 12);
|
||||||
|
assert_eq!(len, 12);
|
||||||
|
}
|
||||||
|
candle::StridedBlocks::MultipleBlocks { .. } => {
|
||||||
|
panic!("unexpected block structure")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
|
let tensor = tensor.i((.., 1))?;
|
||||||
|
match tensor.strided_blocks() {
|
||||||
|
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
|
assert_eq!(start_offset, 0);
|
||||||
|
assert_eq!(len, 8);
|
||||||
|
assert_eq!(tensor.to_vec2::<u32>()?, &[[4, 5, 6, 7], [16, 17, 18, 19]]);
|
||||||
|
}
|
||||||
|
candle::StridedBlocks::MultipleBlocks { .. } => {
|
||||||
|
panic!("unexpected block structure")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
|
match tensor.t()?.strided_blocks() {
|
||||||
|
candle::StridedBlocks::SingleBlock { .. } => {
|
||||||
|
panic!("unexpected block structure")
|
||||||
|
}
|
||||||
|
candle::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
} => {
|
||||||
|
assert_eq!(block_len, 1);
|
||||||
|
assert_eq!(
|
||||||
|
block_start_index.collect::<Vec<_>>(),
|
||||||
|
&[
|
||||||
|
0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15,
|
||||||
|
19, 23
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
|
match tensor.transpose(0, 1)?.strided_blocks() {
|
||||||
|
candle::StridedBlocks::SingleBlock { .. } => {
|
||||||
|
panic!("unexpected block structure")
|
||||||
|
}
|
||||||
|
candle::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
} => {
|
||||||
|
assert_eq!(block_len, 4);
|
||||||
|
assert_eq!(
|
||||||
|
block_start_index.collect::<Vec<_>>(),
|
||||||
|
&[0, 12, 4, 16, 8, 20]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
|
}
|
Reference in New Issue
Block a user