Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00).
Commit: "Bugfix for the contiguous strides." This commit is contained in:
@ -25,6 +25,16 @@ impl<S: crate::WithDType> NdArray for S {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(self.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType> NdArray for &[S] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
@ -35,6 +45,16 @@ impl<S: crate::WithDType> NdArray for &[S] {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from((M, N)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage_owned(self.concat())
|
||||
}
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
|
||||
match self {
|
||||
|
14
src/dtype.rs
14
src/dtype.rs
@ -18,7 +18,11 @@ impl DType {
|
||||
pub trait WithDType: Sized + Copy {
|
||||
const DTYPE: DType;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage;
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
Self::to_cpu_storage_owned(data.to_vec())
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||
}
|
||||
@ -26,8 +30,8 @@ pub trait WithDType: Sized + Copy {
|
||||
impl WithDType for f32 {
|
||||
const DTYPE: DType = DType::F32;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
CpuStorage::F32(data.to_vec())
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||
CpuStorage::F32(data)
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||
@ -44,8 +48,8 @@ impl WithDType for f32 {
|
||||
impl WithDType for f64 {
|
||||
const DTYPE: DType = DType::F64;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
CpuStorage::F64(data.to_vec())
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||
CpuStorage::F64(data)
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||
|
24
src/shape.rs
24
src/shape.rs
@ -109,7 +109,8 @@ impl Shape {
|
||||
/// The strides given in number of elements for a contiguous n-dimensional
|
||||
/// arrays using this shape.
|
||||
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
||||
self.0
|
||||
let mut stride: Vec<_> = self
|
||||
.0
|
||||
.iter()
|
||||
.rev()
|
||||
.scan(1, |prod, u| {
|
||||
@ -117,6 +118,25 @@ impl Shape {
|
||||
*prod *= u;
|
||||
Some(prod_pre_mult)
|
||||
})
|
||||
.collect()
|
||||
.collect();
|
||||
stride.reverse();
|
||||
stride
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Contiguous strides are row-major: the last dimension always has
    /// stride 1 and each earlier stride is the product of the later dims.
    #[test]
    fn stride() {
        let no_dims: [usize; 0] = [];
        assert_eq!(Shape::from(()).stride_contiguous(), no_dims);
        assert_eq!(Shape::from(42).stride_contiguous(), [1]);
        assert_eq!(Shape::from((42, 1337)).stride_contiguous(), [1337, 1]);
        assert_eq!(
            Shape::from((299, 792, 458)).stride_contiguous(),
            [458 * 792, 458, 1]
        );
    }
}
|
||||
|
@ -17,6 +17,7 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StridedIndex<'a> {
|
||||
next_storage_index: Option<usize>,
|
||||
multi_index: Vec<usize>,
|
||||
|
@ -11,7 +11,7 @@ fn zeros() -> Result<()> {
|
||||
|
||||
#[test]
|
||||
fn add_mul() -> Result<()> {
|
||||
let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
|
||||
let tensor = Tensor::new(&[3f32, 1., 4.], Device::Cpu)?;
|
||||
let dim1 = tensor.shape().r1()?;
|
||||
assert_eq!(dim1, 3);
|
||||
let content: Vec<f32> = tensor.to_vec1()?;
|
||||
@ -24,3 +24,14 @@ fn add_mul() -> Result<()> {
|
||||
assert_eq!(content, [36., 4., 64.]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Round-trips a 2x5 `f32` array through `Tensor::new` / `to_vec2` and
/// checks the reported rank-2 shape.
#[test]
fn tensor_2d() -> Result<()> {
    let rows = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
    let tensor = Tensor::new(rows, Device::Cpu)?;
    assert_eq!(tensor.shape().r2()?, (2, 5));
    let roundtrip: Vec<Vec<f32>> = tensor.to_vec2()?;
    assert_eq!(roundtrip, rows);
    Ok(())
}
|
||||
|
Reference in New Issue
Block a user