mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Bugfix for the contiguous strides.
This commit is contained in:
@ -25,6 +25,16 @@ impl<S: crate::WithDType> NdArray for S {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
|
||||||
|
fn shape(&self) -> Result<Shape> {
|
||||||
|
Ok(Shape::from(self.len()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> CpuStorage {
|
||||||
|
S::to_cpu_storage(self.as_slice())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<S: crate::WithDType> NdArray for &[S] {
|
impl<S: crate::WithDType> NdArray for &[S] {
|
||||||
fn shape(&self) -> Result<Shape> {
|
fn shape(&self) -> Result<Shape> {
|
||||||
Ok(Shape::from(self.len()))
|
Ok(Shape::from(self.len()))
|
||||||
@ -35,6 +45,16 @@ impl<S: crate::WithDType> NdArray for &[S] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
|
||||||
|
fn shape(&self) -> Result<Shape> {
|
||||||
|
Ok(Shape::from((M, N)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> CpuStorage {
|
||||||
|
S::to_cpu_storage_owned(self.concat())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Device {
|
impl Device {
|
||||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
|
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
|
||||||
match self {
|
match self {
|
||||||
|
14
src/dtype.rs
14
src/dtype.rs
@ -18,7 +18,11 @@ impl DType {
|
|||||||
pub trait WithDType: Sized + Copy {
|
pub trait WithDType: Sized + Copy {
|
||||||
const DTYPE: DType;
|
const DTYPE: DType;
|
||||||
|
|
||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage;
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||||
|
|
||||||
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
|
Self::to_cpu_storage_owned(data.to_vec())
|
||||||
|
}
|
||||||
|
|
||||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||||
}
|
}
|
||||||
@ -26,8 +30,8 @@ pub trait WithDType: Sized + Copy {
|
|||||||
impl WithDType for f32 {
|
impl WithDType for f32 {
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||||
CpuStorage::F32(data.to_vec())
|
CpuStorage::F32(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||||
@ -44,8 +48,8 @@ impl WithDType for f32 {
|
|||||||
impl WithDType for f64 {
|
impl WithDType for f64 {
|
||||||
const DTYPE: DType = DType::F64;
|
const DTYPE: DType = DType::F64;
|
||||||
|
|
||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||||
CpuStorage::F64(data.to_vec())
|
CpuStorage::F64(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||||
|
24
src/shape.rs
24
src/shape.rs
@ -109,7 +109,8 @@ impl Shape {
|
|||||||
/// The strides given in number of elements for a contiguous n-dimensional
|
/// The strides given in number of elements for a contiguous n-dimensional
|
||||||
/// arrays using this shape.
|
/// arrays using this shape.
|
||||||
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
||||||
self.0
|
let mut stride: Vec<_> = self
|
||||||
|
.0
|
||||||
.iter()
|
.iter()
|
||||||
.rev()
|
.rev()
|
||||||
.scan(1, |prod, u| {
|
.scan(1, |prod, u| {
|
||||||
@ -117,6 +118,25 @@ impl Shape {
|
|||||||
*prod *= u;
|
*prod *= u;
|
||||||
Some(prod_pre_mult)
|
Some(prod_pre_mult)
|
||||||
})
|
})
|
||||||
.collect()
|
.collect();
|
||||||
|
stride.reverse();
|
||||||
|
stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stride() {
|
||||||
|
let shape = Shape::from(());
|
||||||
|
assert_eq!(shape.stride_contiguous(), []);
|
||||||
|
let shape = Shape::from(42);
|
||||||
|
assert_eq!(shape.stride_contiguous(), [1]);
|
||||||
|
let shape = Shape::from((42, 1337));
|
||||||
|
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
||||||
|
let shape = Shape::from((299, 792, 458));
|
||||||
|
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub(crate) struct StridedIndex<'a> {
|
pub(crate) struct StridedIndex<'a> {
|
||||||
next_storage_index: Option<usize>,
|
next_storage_index: Option<usize>,
|
||||||
multi_index: Vec<usize>,
|
multi_index: Vec<usize>,
|
||||||
|
@ -11,7 +11,7 @@ fn zeros() -> Result<()> {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn add_mul() -> Result<()> {
|
fn add_mul() -> Result<()> {
|
||||||
let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
|
let tensor = Tensor::new(&[3f32, 1., 4.], Device::Cpu)?;
|
||||||
let dim1 = tensor.shape().r1()?;
|
let dim1 = tensor.shape().r1()?;
|
||||||
assert_eq!(dim1, 3);
|
assert_eq!(dim1, 3);
|
||||||
let content: Vec<f32> = tensor.to_vec1()?;
|
let content: Vec<f32> = tensor.to_vec1()?;
|
||||||
@ -24,3 +24,14 @@ fn add_mul() -> Result<()> {
|
|||||||
assert_eq!(content, [36., 4., 64.]);
|
assert_eq!(content, [36., 4., 64.]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tensor_2d() -> Result<()> {
|
||||||
|
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||||
|
let tensor = Tensor::new(data, Device::Cpu)?;
|
||||||
|
let dims = tensor.shape().r2()?;
|
||||||
|
assert_eq!(dims, (2, 5));
|
||||||
|
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||||
|
assert_eq!(content, data);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user