mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Addressing a lot of comments.
This commit is contained in:
@ -482,11 +482,14 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
if !(sum_dims.len() == 1
|
||||
&& sum_dims[0] == layout.shape().rank() - 1
|
||||
&& layout.stride()[sum_dims[0]] == 1)
|
||||
{
|
||||
crate::bail!("Non last dim reduce op not supported yet");
|
||||
if sum_dims.len() != 1 {
|
||||
crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet.");
|
||||
}
|
||||
if sum_dims[0] != layout.shape().rank() - 1 {
|
||||
crate::bail!("Non last dim reduce op {op:?} not implemented yet");
|
||||
}
|
||||
if layout.stride()[sum_dims[0]] != 1 {
|
||||
crate::bail!("Non contiguous reduce op {op:?} not implemented yet");
|
||||
}
|
||||
|
||||
let device = self.device.clone();
|
||||
@ -524,7 +527,7 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
if dtype == DType::U32 {
|
||||
crate::bail!("Implement return index reduce op");
|
||||
crate::bail!("reduce op {name} is not implemented yet.");
|
||||
}
|
||||
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
@ -790,12 +793,16 @@ impl BackendStorage for MetalStorage {
|
||||
let buffer = self.device.new_buffer(el, dtype, "where")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
if t.dtype() != f.dtype() {
|
||||
crate::bail!("Invalid ternary different dtypes for values");
|
||||
crate::bail!(
|
||||
"Invalid where: different dtypes for values {:?} != {:?}",
|
||||
t.dtype(),
|
||||
f.dtype()
|
||||
);
|
||||
}
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
|
||||
(left, right) => crate::bail!("where {left:?} - {right:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
&device.device,
|
||||
|
Reference in New Issue
Block a user