mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Extend supported dtypes for metal (im2col & upsample_2d) (#1938)
* update im2col dtype implementations * update dtypes for upsample
This commit is contained in:
@ -1038,6 +1038,10 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "im2col_f32",
|
||||
DType::F16 => "im2col_f16",
|
||||
DType::BF16 => "im2col_bf16",
|
||||
DType::U8 => "im2col_u8",
|
||||
DType::U32 => "im2col_u32",
|
||||
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_im2col_strided(
|
||||
@ -1250,6 +1254,10 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "upsample_nearest2d_f32",
|
||||
DType::F16 => "upsample_nearest2d_f16",
|
||||
DType::BF16 => "upsample_nearest2d_bf16",
|
||||
DType::U8 => "upsample_nearest2d_u8",
|
||||
DType::U32 => "upsample_nearest2d_u32",
|
||||
dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"),
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user