mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add the slice indexing.
This commit is contained in:
@ -17,7 +17,65 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
pub(crate) struct StridedIndex<'a> {
|
||||||
|
next_storage_index: Option<usize>,
|
||||||
|
multi_index: Vec<usize>,
|
||||||
|
dims: &'a [usize],
|
||||||
|
stride: &'a [usize],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> StridedIndex<'a> {
|
||||||
|
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
|
||||||
|
let elem_count: usize = dims.iter().product();
|
||||||
|
let next_storage_index = if elem_count == 0 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
// This applies to the scalar case.
|
||||||
|
Some(0)
|
||||||
|
};
|
||||||
|
StridedIndex {
|
||||||
|
next_storage_index,
|
||||||
|
multi_index: vec![0; dims.len()],
|
||||||
|
dims,
|
||||||
|
stride,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Iterator for StridedIndex<'a> {
|
||||||
|
type Item = usize;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let storage_index = match self.next_storage_index {
|
||||||
|
None => return None,
|
||||||
|
Some(storage_index) => storage_index,
|
||||||
|
};
|
||||||
|
let mut updated = false;
|
||||||
|
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
|
||||||
|
let next_i = *multi_i + 1;
|
||||||
|
if next_i < *max_i {
|
||||||
|
*multi_i = next_i;
|
||||||
|
updated = true;
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
*multi_i = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.next_storage_index = if updated {
|
||||||
|
let next_storage_index = self
|
||||||
|
.multi_index
|
||||||
|
.iter()
|
||||||
|
.zip(self.stride.iter())
|
||||||
|
.map(|(&x, &y)| x * y)
|
||||||
|
.sum();
|
||||||
|
Some(next_storage_index)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Some(storage_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Storage {
|
pub enum Storage {
|
||||||
Cpu(CpuStorage),
|
Cpu(CpuStorage),
|
||||||
@ -56,6 +114,7 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Support broadcasting?
|
||||||
pub(crate) fn add_impl(
|
pub(crate) fn add_impl(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
@ -95,6 +154,7 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Support broadcasting?
|
||||||
pub(crate) fn mul_impl(
|
pub(crate) fn mul_impl(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
|
@ -114,13 +114,28 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn strided_index(&self) -> crate::storage::StridedIndex {
|
||||||
|
crate::storage::StridedIndex::new(self.dims(), self.stride())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
||||||
// TODO: properly use the strides here.
|
if self.rank() != 1 {
|
||||||
todo!()
|
return Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: 1,
|
||||||
|
got: self.rank(),
|
||||||
|
shape: self.shape().clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
match &self.storage {
|
||||||
|
Storage::Cpu(cpu_storage) => {
|
||||||
|
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||||
|
Ok(self.strided_index().map(|i| data[i]).collect())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
|
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
|
||||||
// TODO: properly use the strides here.
|
// TODO: Similar to to_vec1 then reshape the resulting vec?
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,5 +9,7 @@ fn add() -> Result<()> {
|
|||||||
let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
|
let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
|
||||||
let dim1 = tensor.shape().r1()?;
|
let dim1 = tensor.shape().r1()?;
|
||||||
assert_eq!(dim1, 3);
|
assert_eq!(dim1, 3);
|
||||||
|
let content: Vec<f32> = tensor.to_vec1()?;
|
||||||
|
assert_eq!(content, [3., 1., 4.]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user