mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Get the cpu tests to run.
This commit is contained in:
@ -229,10 +229,6 @@ impl CpuStorage {
|
|||||||
D::cpu_storage_as_slice(self)
|
D::cpu_storage_as_slice(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
|
|
||||||
D::cpu_storage_as_mut_slice(self)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||||
// TODO: find a way around the quadratic number of cases below.
|
// TODO: find a way around the quadratic number of cases below.
|
||||||
match (self, dtype) {
|
match (self, dtype) {
|
||||||
@ -581,6 +577,7 @@ impl CpuStorage {
|
|||||||
layout_f: &Layout,
|
layout_f: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
// TODO: Support types that could be casted to a boolean.
|
// TODO: Support types that could be casted to a boolean.
|
||||||
|
// TODO: this should use the layout.
|
||||||
let pred = self.as_slice::<u32>()?;
|
let pred = self.as_slice::<u32>()?;
|
||||||
match (t, f) {
|
match (t, f) {
|
||||||
(Self::BF16(t), Self::BF16(f)) => {
|
(Self::BF16(t), Self::BF16(f)) => {
|
||||||
@ -618,6 +615,7 @@ impl CpuStorage {
|
|||||||
hidden_size: usize,
|
hidden_size: usize,
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
// TODO: this should use the layout.
|
||||||
let ids = self.as_slice::<u32>()?;
|
let ids = self.as_slice::<u32>()?;
|
||||||
map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
|
map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,6 @@ pub trait WithDType: Sized + Copy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||||
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
|
|
||||||
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
|
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,17 +74,6 @@ macro_rules! with_dtype {
|
|||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
|
|
||||||
match s {
|
|
||||||
CpuStorage::$dtype(data) => Ok(data),
|
|
||||||
_ => Err(Error::UnexpectedDType {
|
|
||||||
expected: DType::$dtype,
|
|
||||||
got: s.dtype(),
|
|
||||||
msg: "unexpected dtype",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@ impl<'a> StridedIndex<'a> {
|
|||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
// This applies to the scalar case.
|
// This applies to the scalar case.
|
||||||
Some(0)
|
Some(layout.start_offset())
|
||||||
};
|
};
|
||||||
StridedIndex {
|
StridedIndex {
|
||||||
next_storage_index,
|
next_storage_index,
|
||||||
|
@ -317,6 +317,7 @@ impl Tensor {
|
|||||||
unary_op!(sqrt, Sqrt);
|
unary_op!(sqrt, Sqrt);
|
||||||
unary_op!(gelu, Gelu);
|
unary_op!(gelu, Gelu);
|
||||||
unary_op!(relu, Relu);
|
unary_op!(relu, Relu);
|
||||||
|
|
||||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||||
if self.rank() != 0 {
|
if self.rank() != 0 {
|
||||||
return Err(Error::UnexpectedNumberOfDims {
|
return Err(Error::UnexpectedNumberOfDims {
|
||||||
|
@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
|
|||||||
let a_tt = a.t()?.contiguous()?.t()?;
|
let a_tt = a.t()?.contiguous()?.t()?;
|
||||||
assert!(!a_tt.is_contiguous());
|
assert!(!a_tt.is_contiguous());
|
||||||
assert_eq!(a.dims(), a_tt.dims());
|
assert_eq!(a.dims(), a_tt.dims());
|
||||||
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);
|
||||||
|
|
||||||
let b_tt = b.t()?.contiguous()?.t()?;
|
let b_tt = b.t()?.contiguous()?.t()?;
|
||||||
assert!(!b_tt.is_contiguous());
|
assert!(!b_tt.is_contiguous());
|
||||||
assert_eq!(b.dims(), b_tt.dims());
|
assert_eq!(b.dims(), b_tt.dims());
|
||||||
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);
|
||||||
|
|
||||||
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
||||||
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||||
|
Reference in New Issue
Block a user