From 4eeaf205d6d0577805a41dc7ae2457be1862726a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 14 Dec 2023 19:37:03 +0100
Subject: [PATCH] Fix softmax for long sequences (missing barrier).

---
 candle-core/src/metal_backend.rs      |  2 +-
 candle-metal-kernels/src/reduce.metal | 15 ++++----
 candle-metal-kernels/src/tests.rs     | 51 +++++++++++++++++++++------
 3 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 4bc80823..d38796a1 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -126,7 +126,7 @@ impl MetalDevice {
         }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
         let new_buffer = Arc::new(new_buffer);
-        // subbuffers.push(new_buffer.clone());
+        subbuffers.push(new_buffer.clone());
         // println!("Created tensor {size} {name}");
         for subbuffers in buffers.values_mut() {
             let newbuffers = subbuffers
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 3a402427..53e4664a 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -32,7 +32,7 @@ kernel void NAME( \
     uint block_dim [[ threads_per_threadgroup ]] \
 ) { \
     \
-    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    threadgroup T shared_memory[THREADGROUP_SIZE]; \
     \
     shared_memory[tid] = 0; \
     /* \
@@ -67,6 +67,7 @@ kernel void NAME( \
         threadgroup_barrier(mem_flags::mem_none); \
     } \
     \
+    threadgroup_barrier(mem_flags::mem_none); \
     dst[dst_id] = shared_memory[0]; \
 } \
 
@@ -95,10 +96,12 @@ kernel void NAME( \
     \
     threadgroup_barrier(mem_flags::mem_threadgroup); \
     \
+    float tmp = 0; \
     while (idx < stop_idx) { \
-        shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
+        tmp = MAX(tmp, src[idx]); \
         idx += block_dim; \
     } \
+    shared_memory[tid] = tmp; \
     \
     threadgroup_barrier(mem_flags::mem_threadgroup); \
     \
@@ -112,12 +115,13 @@ kernel void NAME( \
     \
     float _max = shared_memory[0]; \
     \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
     shared_memory[tid] = 0; \
     \
     idx = start_idx + tid; \
     while (idx < stop_idx) { \
-        const T val = T(exp(src[idx] - _max)); \
-        dst[idx] = val; \
+        const float val = exp(float(src[idx]) - _max); \
+        dst[idx] = T(val); \
         shared_memory[tid] += val; \
         idx += block_dim; \
     } \
@@ -125,10 +129,9 @@ kernel void NAME( \
         if (tid < s) { \
             shared_memory[tid] += shared_memory[tid + s]; \
         } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
     } \
     \
-    const T inv_acc = T(1/shared_memory[0]); \
+    const T inv_acc = T(1.0/shared_memory[0]); \
     idx = start_idx + tid; \
     while (idx < stop_idx) { \
         dst[idx] *= inv_acc; \
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8f3e2d43..75c2f013 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -37,7 +37,8 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
 
 fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -59,7 +60,8 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
 
 fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
@@ -94,7 +96,8 @@ fn run_strided(
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
     let output = new_buffer(&device, v);
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     call_unary_strided(
         &device,
         command_buffer,
@@ -247,7 +250,8 @@ fn binary_add_f32() {
 
 fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -294,7 +298,8 @@ fn cast_u32_f32() {
 
 fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
 
@@ -329,7 +334,8 @@ fn run_affine_strided(
     add: f64,
 ) -> Vec<T> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
 
@@ -457,7 +463,8 @@ fn run_index_select(
         _ => unimplemented!(),
     };
 
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     call_index_select(
         &device,
         &command_buffer,
@@ -559,7 +566,8 @@ fn cos_f16() {
 
 fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -586,7 +594,8 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
 
 fn run_softmax<T: Clone>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -636,6 +645,24 @@ fn softmax() {
         vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
     );
 
+    let last_dim = 4096;
+    let n = 200;
+    let mut v = vec![0.0; n * last_dim];
+    for i in 0..n {
+        v[i * last_dim] = 20.0;
+    }
+    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = approx(results, 4);
+    println!("{results:?}");
+    assert_eq!(
+        results.iter().map(|&s| s.round() as usize).sum::<usize>(),
+        n
+    );
+    assert_eq!(results[0], 1.0);
+    assert_eq!(results[1], 0.0);
+    assert_eq!(results[last_dim], 1.0);
+    assert_eq!(results[2 * last_dim], 1.0);
+
     let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
     let last_dim = 6;
     let results = run_softmax(&v, last_dim, "softmax_float");
@@ -686,7 +713,8 @@ fn run_where_cond(
     name: &'static str,
 ) -> Vec<T> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
@@ -762,7 +790,8 @@ fn run_gemm(
     rhs_offset: usize,
 ) -> Vec<T> {
     let device = device();
-    let kernels = Kernels::new();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
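
Why the added barrier fixes long sequences: in the softmax kernel each thread
reads the row max from shared_memory[0] and then immediately reuses its own
shared_memory slot as the accumulator for the exp sum. Without a barrier
between the read and the reset, a fast thread can zero (or start accumulating
into) shared_memory[0] before a slow thread has read _max, and the longer the
row, the more loop iterations separate fast and slow threads. The CPU sketch
below mirrors the kernel's barrier discipline, with std::sync::Barrier standing
in for threadgroup_barrier and atomics for threadgroup memory; the names and
structure are illustrative only, not candle code.

use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;

fn main() {
    let block_dim = 4usize;
    let src: Vec<f32> = (0..16).map(|i| i as f32).collect();
    // f32 values bit-packed into atomics so every "thread" can write its
    // own slot and read slot 0 without a data race.
    let shared: Arc<Vec<AtomicU32>> = Arc::new(
        (0..block_dim).map(|_| AtomicU32::new(0f32.to_bits())).collect(),
    );
    let barrier = Arc::new(Barrier::new(block_dim));

    let handles: Vec<_> = (0..block_dim)
        .map(|tid| {
            let shared = Arc::clone(&shared);
            let barrier = Arc::clone(&barrier);
            let src = src.clone();
            thread::spawn(move || {
                // Phase 1: strided partial max, kept in a register (the
                // patch's `float tmp`) and written to shared memory once.
                let mut tmp = f32::NEG_INFINITY;
                let mut idx = tid;
                while idx < src.len() {
                    tmp = tmp.max(src[idx]);
                    idx += block_dim;
                }
                shared[tid].store(tmp.to_bits(), Ordering::Relaxed);
                barrier.wait();

                // Phase 2: tree-reduce the per-thread maxima into slot 0.
                let mut s = block_dim / 2;
                while s > 0 {
                    if tid < s {
                        let a = f32::from_bits(shared[tid].load(Ordering::Relaxed));
                        let b = f32::from_bits(shared[tid + s].load(Ordering::Relaxed));
                        shared[tid].store(a.max(b).to_bits(), Ordering::Relaxed);
                    }
                    barrier.wait();
                    s /= 2;
                }

                let max = f32::from_bits(shared[0].load(Ordering::Relaxed));
                // This wait() is the barrier the patch adds: without it a
                // fast thread falls through to the store below and clobbers
                // slot 0 while a slow thread is still reading `max`.
                barrier.wait();
                shared[tid].store(0f32.to_bits(), Ordering::Relaxed);
                // ... the exp/sum/normalize phases reuse the same pattern.
                assert_eq!(max, 15.0);
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
}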
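Separately, the kernel now computes and accumulates the exponentials in float
even when T is f16/bf16 (`const float val = exp(float(src[idx]) - _max)`): over
a 4096-wide row, a half-precision running sum loses mass to rounding long
before the row ends. A rough, self-contained illustration of that effect,
assuming the half crate as a dependency (this snippet is not part of the
patch):

use half::f16;

fn main() {
    let n = 4096;
    let term = 1.0f32 / n as f32; // 2^-12, one softmax term of a flat row
    let mut acc16 = 0.0f32;
    let mut acc32 = 0.0f32;
    for _ in 0..n {
        // Round the running sum to f16 after every add, as a T-typed
        // shared-memory accumulator effectively would.
        acc16 = f16::from_f32(acc16 + term).to_f32();
        acc32 += term;
    }
    // The f16 sum stalls at 0.5 once `term` drops below half an ulp of the
    // running total, while the f32 sum reaches 1.0 exactly.
    println!("f16 accumulator: {acc16}, f32 accumulator: {acc32}");
}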