From caafef6cc14fc355af8401985e0b596b4a481bb7 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 14:32:02 +0100 Subject: [PATCH] Get the cpu tests to run. --- candle-core/src/cpu_backend.rs | 6 ++---- candle-core/src/dtype.rs | 12 ------------ candle-core/src/strided_index.rs | 2 +- candle-core/src/tensor.rs | 1 + candle-core/tests/tensor_tests.rs | 4 ++-- 5 files changed, 6 insertions(+), 19 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 1c5caa82..a5fdb826 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -229,10 +229,6 @@ impl CpuStorage { D::cpu_storage_as_slice(self) } - pub fn as_mut_slice(&mut self) -> Result<&mut [D]> { - D::cpu_storage_as_mut_slice(self) - } - pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { // TODO: find a way around the quadratic number of cases below. match (self, dtype) { @@ -581,6 +577,7 @@ impl CpuStorage { layout_f: &Layout, ) -> Result { // TODO: Support types that could be casted to a boolean. + // TODO: this should use the layout. let pred = self.as_slice::()?; match (t, f) { (Self::BF16(t), Self::BF16(f)) => { @@ -618,6 +615,7 @@ impl CpuStorage { hidden_size: usize, vocab_size: usize, ) -> Result { + // TODO: this should use the layout. let ids = self.as_slice::()?; map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) } diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 1711c2b4..fdbfdbba 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -41,7 +41,6 @@ pub trait WithDType: Sized + Copy { } fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; - fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>; fn cpu_storage_data(s: CpuStorage) -> Result>; } @@ -75,17 +74,6 @@ macro_rules! with_dtype { }), } } - - fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> { - match s { - CpuStorage::$dtype(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::$dtype, - got: s.dtype(), - msg: "unexpected dtype", - }), - } - } } }; } diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index f8dc522f..e6d2868b 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -17,7 +17,7 @@ impl<'a> StridedIndex<'a> { None } else { // This applies to the scalar case. - Some(0) + Some(layout.start_offset()) }; StridedIndex { next_storage_index, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index b04e90b1..93846160 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -317,6 +317,7 @@ impl Tensor { unary_op!(sqrt, Sqrt); unary_op!(gelu, Gelu); unary_op!(relu, Relu); + pub fn to_scalar(&self) -> Result { if self.rank() != 0 { return Err(Error::UnexpectedNumberOfDims { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 78ca4b05..8ac0c9f2 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> { let a_tt = a.t()?.contiguous()?.t()?; assert!(!a_tt.is_contiguous()); assert_eq!(a.dims(), a_tt.dims()); - assert_eq!(a_tt.stride(), &[6, 1, 2]); + assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]); let b_tt = b.t()?.contiguous()?.t()?; assert!(!b_tt.is_contiguous()); assert_eq!(b.dims(), b_tt.dims()); - assert_eq!(b_tt.stride(), &[6, 1, 3]); + assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]); assert_eq!(a_tt.matmul(&b)?.to_vec3::()?, &expected); assert_eq!(a.matmul(&b_tt)?.to_vec3::()?, &expected);