mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
cargo fmt
This commit is contained in:
@ -583,7 +583,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::U8, DType::F32) => "cast_u8_f32",
|
(DType::U8, DType::F32) => "cast_u8_f32",
|
||||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||||
(left, right) => crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented"),
|
(left, right) => {
|
||||||
|
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_cast_contiguous(
|
candle_metal_kernels::call_cast_contiguous(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -604,7 +606,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::U8, DType::F32) => "cast_u8_f32_strided",
|
(DType::U8, DType::F32) => "cast_u8_f32_strided",
|
||||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||||
(left, right) => crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented"),
|
(left, right) => {
|
||||||
|
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_cast_strided(
|
candle_metal_kernels::call_cast_strided(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -663,7 +667,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||||
("uround", DType::F16) => contiguous::round::HALF,
|
("uround", DType::F16) => contiguous::round::HALF,
|
||||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||||
(name, dtype) => crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented"),
|
(name, dtype) => {
|
||||||
|
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_unary_contiguous(
|
candle_metal_kernels::call_unary_contiguous(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -704,7 +710,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uceil", DType::F16) => strided::ceil::HALF,
|
("uceil", DType::F16) => strided::ceil::HALF,
|
||||||
("ufloor", DType::F16) => strided::floor::HALF,
|
("ufloor", DType::F16) => strided::floor::HALF,
|
||||||
("uround", DType::F16) => strided::round::HALF,
|
("uround", DType::F16) => strided::round::HALF,
|
||||||
(name, dtype) => crate::bail!("Metal strided unary {name} {dtype:?} not implemented"),
|
(name, dtype) => {
|
||||||
|
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_unary_strided(
|
candle_metal_kernels::call_unary_strided(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -1092,7 +1100,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
(DType::U32, DType::F32) => "is_u32_f32",
|
(DType::U32, DType::F32) => "is_u32_f32",
|
||||||
(DType::U32, DType::F16) => "is_u32_f16",
|
(DType::U32, DType::F16) => "is_u32_f16",
|
||||||
(left, right) => crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented"),
|
(left, right) => {
|
||||||
|
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
candle_metal_kernels::call_index_select(
|
candle_metal_kernels::call_index_select(
|
||||||
@ -1288,7 +1298,9 @@ impl MetalStorage {
|
|||||||
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||||
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||||
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||||
(name, dtype) => crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented"),
|
(name, dtype) => {
|
||||||
|
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let buffer = device.new_buffer(el_count, dtype, op)?;
|
let buffer = device.new_buffer(el_count, dtype, op)?;
|
||||||
candle_metal_kernels::call_binary_contiguous(
|
candle_metal_kernels::call_binary_contiguous(
|
||||||
@ -1331,7 +1343,9 @@ impl MetalStorage {
|
|||||||
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||||
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||||
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||||
(name, dtype) => crate::bail!("Metal strided binary {name} {dtype:?} not implemented"),
|
(name, dtype) => {
|
||||||
|
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let buffer = device.new_buffer(el_count, dtype, op)?;
|
let buffer = device.new_buffer(el_count, dtype, op)?;
|
||||||
candle_metal_kernels::call_binary_strided(
|
candle_metal_kernels::call_binary_strided(
|
||||||
|
Reference in New Issue
Block a user