Get the cpu tests to run.

This commit is contained in:
laurent
2023-06-28 14:32:02 +01:00
parent 14449ff80c
commit caafef6cc1
5 changed files with 6 additions and 19 deletions

View File

@ -229,10 +229,6 @@ impl CpuStorage {
D::cpu_storage_as_slice(self) D::cpu_storage_as_slice(self)
} }
pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
D::cpu_storage_as_mut_slice(self)
}
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> { pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
// TODO: find a way around the quadratic number of cases below. // TODO: find a way around the quadratic number of cases below.
match (self, dtype) { match (self, dtype) {
@ -581,6 +577,7 @@ impl CpuStorage {
layout_f: &Layout, layout_f: &Layout,
) -> Result<Self> { ) -> Result<Self> {
// TODO: Support types that could be casted to a boolean. // TODO: Support types that could be casted to a boolean.
// TODO: this should use the layout.
let pred = self.as_slice::<u32>()?; let pred = self.as_slice::<u32>()?;
match (t, f) { match (t, f) {
(Self::BF16(t), Self::BF16(f)) => { (Self::BF16(t), Self::BF16(f)) => {
@ -618,6 +615,7 @@ impl CpuStorage {
hidden_size: usize, hidden_size: usize,
vocab_size: usize, vocab_size: usize,
) -> Result<Self> { ) -> Result<Self> {
// TODO: this should use the layout.
let ids = self.as_slice::<u32>()?; let ids = self.as_slice::<u32>()?;
map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
} }

View File

@ -41,7 +41,6 @@ pub trait WithDType: Sized + Copy {
} }
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>; fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
} }
@ -75,17 +74,6 @@ macro_rules! with_dtype {
}), }),
} }
} }
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
match s {
CpuStorage::$dtype(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}
} }
}; };
} }

View File

@ -17,7 +17,7 @@ impl<'a> StridedIndex<'a> {
None None
} else { } else {
// This applies to the scalar case. // This applies to the scalar case.
Some(0) Some(layout.start_offset())
}; };
StridedIndex { StridedIndex {
next_storage_index, next_storage_index,

View File

@ -317,6 +317,7 @@ impl Tensor {
unary_op!(sqrt, Sqrt); unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu); unary_op!(gelu, Gelu);
unary_op!(relu, Relu); unary_op!(relu, Relu);
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> { pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
if self.rank() != 0 { if self.rank() != 0 {
return Err(Error::UnexpectedNumberOfDims { return Err(Error::UnexpectedNumberOfDims {

View File

@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
let a_tt = a.t()?.contiguous()?.t()?; let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous()); assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims()); assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]); assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?; let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous()); assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims()); assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]); assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected); assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected); assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);