From ece4c69a681215837fd5a008e2ee652394daa8ed Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 15 Dec 2023 01:35:08 +0100
Subject: [PATCH] Fixing softmax.

---
 candle-core/src/metal_backend.rs            | 10 ++++++----
 candle-metal-kernels/src/reduce.metal       | 11 +++++++----
 candle-nn/src/ops.rs                        |  2 +-
 candle-transformers/src/models/mixformer.rs |  4 ----
 4 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index d38796a1..b8b951f0 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -113,21 +113,23 @@ impl MetalDevice {
         self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
 
-    fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions, name: &str) -> Arc<Buffer> {
-        // println!("Creating new buffer {name}");
+    fn _new_buffer(
+        &self,
+        size: NSUInteger,
+        option: MTLResourceOptions,
+        _name: &str,
+    ) -> Arc<Buffer> {
         let mut buffers = self.buffers.try_write().unwrap();
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
 
         for sub in &mut *subbuffers {
             if Arc::strong_count(sub) == 1 {
-                // println!("Reusing tensor {size} {name}");
                 return sub.clone();
             }
         }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
         let new_buffer = Arc::new(new_buffer);
         subbuffers.push(new_buffer.clone());
-        // println!("Created tensor {size} {name}");
         for subbuffers in buffers.values_mut() {
             let newbuffers = subbuffers
                 .iter()
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 53e4664a..3633fdcf 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -67,7 +67,6 @@ kernel void NAME( \
         threadgroup_barrier(mem_flags::mem_none); \
     } \
     \
-    threadgroup_barrier(mem_flags::mem_none); \
     dst[dst_id] = shared_memory[0]; \
 } \
 
@@ -94,11 +93,10 @@ kernel void NAME(
     size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
     size_t idx = start_idx + tid; \
     \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
     \
-    float tmp = 0; \
+    float tmp = -INFINITY; \
     while (idx < stop_idx) { \
-        tmp = MAX(tmp, src[idx]); \
+        tmp = MAX(tmp, float(src[idx])); \
         idx += block_dim; \
     } \
     shared_memory[tid] = tmp; \
@@ -109,12 +107,15 @@ kernel void NAME(
         if (tid < s) { \
             shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
         } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
     } \
     \
+    /* wait for shared_memory[0] to be filled */ \
     threadgroup_barrier(mem_flags::mem_threadgroup); \
     \
     float _max = shared_memory[0]; \
     \
+    /* prevent tid=0 from overwriting _max before other threads have written it */ \
     threadgroup_barrier(mem_flags::mem_threadgroup); \
     shared_memory[tid] = 0; \
     \
@@ -125,10 +126,12 @@ kernel void NAME(
         shared_memory[tid] += val; \
         idx += block_dim; \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
     for (uint s = block_dim / 2; s > 0; s >>= 1) { \
         if (tid < s) { \
             shared_memory[tid] += shared_memory[tid + s]; \
         } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
     } \
     \
     const T inv_acc = T(1.0/shared_memory[0]); \
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index e002d931..f00d8e2f 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
 
         let n = layout.stride().len();
-        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
             candle::bail!("Non contiguous softmax-last-dim is not implemented");
         }
 
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 3f9aa47d..e4e4f619 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -272,10 +272,6 @@ impl MHA {
     }
 
     fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
-        // let view = xs.to_string();
-        // if view.contains("NaN") {
-        //     panic!("NaN");
-        // }
         let _enter = self.span.enter();
         let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self