diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 4105d0de..b9300534 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -321,6 +321,10 @@ impl CpuStorage { pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { // TODO: find a way around the quadratic number of cases below. match (self, dtype) { + (Self::U8(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } (Self::U32(storage), DType::BF16) => { let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) @@ -341,6 +345,10 @@ impl CpuStorage { let data = unary_map(storage, layout, bf16::from_f64); Ok(Self::BF16(data)) } + (Self::U8(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } (Self::U32(storage), DType::F16) => { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) @@ -361,6 +369,10 @@ impl CpuStorage { let data = unary_map(storage, layout, f16::from_f64); Ok(Self::F16(data)) } + (Self::U8(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } (Self::U32(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -381,6 +393,34 @@ impl CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } + (Self::U8(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::U8(data)) + } + (Self::BF16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::F16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::F32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::F64(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::U8(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::U32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } (Self::U32(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v); Ok(Self::U32(data)) @@ -401,6 +441,10 @@ impl CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } + (Self::U8(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } (Self::U32(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -421,7 +465,6 @@ impl CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } - _ => todo!("implement cast for {:?} {dtype:?}", self.dtype()), } }