Add the casting operation.

This commit is contained in:
laurent
2023-06-23 21:22:07 +01:00
parent 8ed350dc94
commit 5d44e76e3f
6 changed files with 83 additions and 4 deletions

View File

@ -13,12 +13,12 @@ pub enum CpuStorage {
F64(Vec<f64>),
}
fn unary_map<T: Copy, F: FnMut(T) -> T>(
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
shape: &Shape,
stride: &[usize],
vs: &[T],
mut f: F,
) -> Vec<T> {
) -> Vec<U> {
if shape.is_contiguous(stride) {
vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
} else {
@ -105,6 +105,48 @@ impl CpuStorage {
D::cpu_storage_as_mut_slice(self)
}
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
// TODO: find a way around the quadratic number of cases below.
match (self, dtype) {
(Self::U32(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v as f32);
Ok(Self::F32(data))
}
(Self::F32(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v);
Ok(Self::F32(data))
}
(Self::F64(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v as f32);
Ok(Self::F32(data))
}
(Self::U32(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v);
Ok(Self::U32(data))
}
(Self::F32(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v as u32);
Ok(Self::U32(data))
}
(Self::F64(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U32(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v as f64);
Ok(Self::F64(data))
}
(Self::F32(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v as f64);
Ok(Self::F64(data))
}
(Self::F64(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v);
Ok(Self::F64(data))
}
}
}
pub(crate) fn affine_impl(
&self,
shape: &Shape,