mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add the casting operation.
This commit is contained in:
@ -13,12 +13,12 @@ pub enum CpuStorage {
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
|
||||
fn unary_map<T: Copy, F: FnMut(T) -> T>(
|
||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
vs: &[T],
|
||||
mut f: F,
|
||||
) -> Vec<T> {
|
||||
) -> Vec<U> {
|
||||
if shape.is_contiguous(stride) {
|
||||
vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
|
||||
} else {
|
||||
@ -105,6 +105,48 @@ impl CpuStorage {
|
||||
D::cpu_storage_as_mut_slice(self)
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
// TODO: find a way around the quadratic number of cases below.
|
||||
match (self, dtype) {
|
||||
(Self::U32(storage), DType::F32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F32(storage), DType::U32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F64(storage), DType::U32) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::F64) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F64) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F64) => {
|
||||
let data = unary_map(shape, stride, storage, |v| v);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
|
Reference in New Issue
Block a user