mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Merge branch 'main' into ivarflakstad/metal-prng
This commit is contained in:
@ -355,6 +355,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_f32",
|
||||
DType::F16 => "affine_f16",
|
||||
DType::BF16 => "affine_bf16",
|
||||
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_affine(
|
||||
@ -373,6 +374,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_f32_strided",
|
||||
DType::F16 => "affine_f16_strided",
|
||||
DType::BF16 => "affine_bf16_strided",
|
||||
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_affine_strided(
|
||||
@ -808,6 +810,7 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U8, DType::BF16) => "where_u8_bf16",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(DType::U8, DType::I64) => "where_u8_i64",
|
||||
(DType::U8, DType::U32) => "where_u8_u32",
|
||||
|
Reference in New Issue
Block a user