Add the slice indexing.

This commit is contained in:
laurent
2023-06-20 10:50:58 +01:00
parent 786544292d
commit 6c5fc767a8
3 changed files with 81 additions and 4 deletions

View File

@ -114,13 +114,28 @@ impl Tensor {
}
}
pub(crate) fn strided_index(&self) -> crate::storage::StridedIndex {
crate::storage::StridedIndex::new(self.dims(), self.stride())
}
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
// TODO: properly use the strides here.
todo!()
if self.rank() != 1 {
return Err(Error::UnexpectedNumberOfDims {
expected: 1,
got: self.rank(),
shape: self.shape().clone(),
});
}
match &self.storage {
Storage::Cpu(cpu_storage) => {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(self.strided_index().map(|i| data[i]).collect())
}
}
}
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
// TODO: properly use the strides here.
// TODO: Similar to to_vec1 then reshape the resulting vec?
todo!()
}