Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Addressing a lot of comments.
@@ -482,11 +482,14 @@ impl BackendStorage for MetalStorage {
     }
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
-        if !(sum_dims.len() == 1
-            && sum_dims[0] == layout.shape().rank() - 1
-            && layout.stride()[sum_dims[0]] == 1)
-        {
-            crate::bail!("Non last dim reduce op not supported yet");
+        if sum_dims.len() != 1 {
+            crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet.");
+        }
+        if sum_dims[0] != layout.shape().rank() - 1 {
+            crate::bail!("Non last dim reduce op {op:?} not implemented yet");
+        }
+        if layout.stride()[sum_dims[0]] != 1 {
+            crate::bail!("Non contiguous reduce op {op:?} not implemented yet");
         }
 
         let device = self.device.clone();
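Note on the reduce_op rework above: the single combined condition is split into three checks so the error reports which constraint failed (multi-dim reduce, non-last-dim reduce, or non-contiguous reduce). A minimal sketch of the same ordering, using plain slices instead of candle's Layout type (the helper name and signature are illustrative, not part of candle):

// Illustrative stand-in for the three checks above; `rank` and `stride`
// replace the real Layout accessors.
fn check_reduce(sum_dims: &[usize], rank: usize, stride: &[usize]) -> Result<(), String> {
    if sum_dims.len() != 1 {
        return Err("reduce over multiple dimensions is not implemented yet".to_string());
    }
    if sum_dims[0] != rank - 1 {
        return Err("non last dim reduce op not implemented yet".to_string());
    }
    if stride[sum_dims[0]] != 1 {
        return Err("non contiguous reduce op not implemented yet".to_string());
    }
    Ok(())
}

fn main() {
    // A contiguous (3, 2) tensor reduced over its last axis passes every check.
    assert!(check_reduce(&[1], 2, &[2, 1]).is_ok());
    // Reducing over the first axis now hits the dedicated "non last dim" error.
    assert!(check_reduce(&[0], 2, &[2, 1]).is_err());
}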
@@ -524,7 +527,7 @@ impl BackendStorage for MetalStorage {
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
         if dtype == DType::U32 {
-            crate::bail!("Implement return index reduce op");
+            crate::bail!("reduce op {name} is not implemented yet.");
         }
         let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
@@ -790,12 +793,16 @@ impl BackendStorage for MetalStorage {
         let buffer = self.device.new_buffer(el, dtype, "where")?;
         let command_buffer = self.device.command_buffer()?;
         if t.dtype() != f.dtype() {
-            crate::bail!("Invalid ternary different dtypes for values");
+            crate::bail!(
+                "Invalid where: different dtypes for values {:?} != {:?}",
+                t.dtype(),
+                f.dtype()
+            );
         }
         let name = match (self.dtype, t.dtype()) {
             (DType::U8, DType::F32) => "where_u8_f32",
             (DType::U8, DType::F16) => "where_u8_f16",
-            (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
+            (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"),
         };
         candle_metal_kernels::call_where_cond_strided(
             &device.device,
@@ -597,6 +597,7 @@ pub fn call_last_softmax(
     length: usize,
     elements_to_sum: usize,
     input: &Buffer,
+    input_offset: usize,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -604,7 +605,10 @@ pub fn call_last_softmax(
     encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, elements_to_sum, input, output));
+    set_params!(
+        encoder,
+        (length, elements_to_sum, (input, input_offset), output)
+    );
 
     let out_length = length / elements_to_sum;
 
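The new `input_offset` argument above is a byte offset: passing the input as a `(buffer, offset)` pair lets the kernel start reading partway into the buffer instead of requiring the data to begin at byte 0. The tests below pass `0` because their buffers hold exactly the data under test; a caller with an offset view converts its element offset into bytes first, as the candle-nn hunk further down does. A small hedged sketch of that conversion (the helper name is illustrative):

// Hypothetical helper mirroring the element-offset to byte-offset conversion a
// caller performs before handing the buffer to call_last_softmax.
fn byte_offset(start_offset_in_elements: usize, dtype_size_in_bytes: usize) -> usize {
    start_offset_in_elements * dtype_size_in_bytes
}

fn main() {
    // An f32 view starting 7 elements into its backing buffer.
    assert_eq!(byte_offset(7, 4), 28);
    // An f16 view starting 6 elements in.
    assert_eq!(byte_offset(6, 2), 12);
    // A view that starts at the beginning, as in the updated tests.
    assert_eq!(byte_offset(0, 4), 0);
}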
@ -312,7 +312,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
|||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
"affine_float",
|
"affine_f32",
|
||||||
size,
|
size,
|
||||||
&input,
|
&input,
|
||||||
&output,
|
&output,
|
||||||
@@ -346,7 +346,7 @@ fn run_affine_strided<T: Clone>(
         &device,
         command_buffer,
         &kernels,
-        "affine_float_strided",
+        "affine_f32_strided",
         shape,
         &input,
         strides,
@@ -608,6 +608,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
         v.len(),
         last_dim,
         &input,
+        0,
         &output,
     )
     .unwrap();
@@ -622,7 +623,7 @@ fn reduce_sum() {
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let out_length = 1;
 
-    let results = run_reduce(&v, out_length, "fast_sum_float");
+    let results = run_reduce(&v, out_length, "fast_sum_f32");
     assert_eq!(approx(results, 4), vec![21.0]);
 }
 
@@ -631,7 +632,7 @@ fn reduce_sum2() {
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let out_length = 2;
 
-    let results = run_reduce(&v, out_length, "fast_sum_float");
+    let results = run_reduce(&v, out_length, "fast_sum_f32");
     assert_eq!(approx(results, 4), vec![6.0, 15.0]);
 }
 
@@ -639,7 +640,7 @@ fn reduce_sum2() {
 fn softmax() {
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = run_softmax(&v, last_dim, "softmax_f32");
     assert_eq!(
         approx(results, 4),
         vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@@ -651,7 +652,7 @@ fn softmax() {
     for i in 0..n {
         v[i * last_dim] = 20.0;
     }
-    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = run_softmax(&v, last_dim, "softmax_f32");
     let results = approx(results, 4);
     println!("{results:?}");
     assert_eq!(
@@ -665,7 +666,7 @@ fn softmax() {
 
     let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
     let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = run_softmax(&v, last_dim, "softmax_f32");
     assert_eq!(
         approx(results, 4),
         vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@@ -673,7 +674,7 @@ fn softmax() {
 
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let last_dim = 3;
-    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = run_softmax(&v, last_dim, "softmax_f32");
     assert_eq!(
         approx(results, 4),
         vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
@@ -684,7 +685,7 @@ fn softmax() {
         .map(|v| f16::from_f32(*v))
         .collect::<Vec<_>>();
     let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_half");
+    let results = run_softmax(&v, last_dim, "softmax_f16");
     assert_eq!(
         approx_f16(results, 4),
         vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
@@ -695,7 +696,7 @@ fn softmax() {
         .map(|v| bf16::from_f32(*v))
         .collect::<Vec<_>>();
     let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_bfloat");
+    let results = run_softmax(&v, last_dim, "softmax_bf16");
     assert_eq!(
         approx_bf16(results, 4),
         vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
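The test renames above (affine, fast_sum, softmax) follow the kernel naming move from `_float` / `_half` / `_bfloat` suffixes to dtype-style `_f32` / `_f16` / `_bf16` suffixes, which keeps the names aligned with candle's DType spelling. A hedged sketch of how such names can be composed (purely illustrative; the Metal backend itself matches on DType directly when it picks a kernel):

// Illustrative only: compose "<base>_<dtype>" kernel names in the new style.
#[derive(Debug, Clone, Copy)]
enum ExampleDType {
    F32,
    F16,
    BF16,
}

fn kernel_name(base: &str, dtype: ExampleDType) -> String {
    let suffix = match dtype {
        ExampleDType::F32 => "f32",
        ExampleDType::F16 => "f16",
        ExampleDType::BF16 => "bf16",
    };
    format!("{base}_{suffix}")
}

fn main() {
    assert_eq!(kernel_name("softmax", ExampleDType::F32), "softmax_f32");
    assert_eq!(kernel_name("fast_sum", ExampleDType::BF16), "fast_sum_bf16");
}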
@@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
 
         let n = layout.stride().len();
-        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
             candle::bail!("Non contiguous softmax-last-dim is not implemented");
         }
 
@@ -235,6 +235,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
             elem_count,
             last_dim,
             storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
             &mut output,
         )
         .unwrap();
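With the two SoftmaxLastDim hunks above, a contiguous view that does not start at element 0 no longer bails: the contiguity check drops the `start_offset() == 0` requirement and the matching byte offset (`start_offset * size_in_bytes`) is forwarded to the kernel. A hedged end-to-end illustration of the case this unlocks, assuming the crates are used as `candle_core` / `candle_nn` and a Metal-capable machine is available:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Assumes a Metal device; swap in Device::Cpu to run the same logic elsewhere.
    let device = Device::new_metal(0)?;
    let t = Tensor::arange(0f32, 12f32, &device)?.reshape((3, 4))?;
    // Narrowing along the first dim keeps the view contiguous but gives it a
    // non-zero start_offset; previously this bailed with "Non contiguous
    // softmax-last-dim", now it dispatches with the corresponding byte offset.
    let tail = t.narrow(0, 1, 2)?;
    let sm = candle_nn::ops::softmax_last_dim(&tail)?;
    println!("{sm}");
    Ok(())
}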