Fixing softmax.

Nicolas Patry
2023-12-15 01:35:08 +01:00
parent 4eeaf205d6
commit ece4c69a68
4 changed files with 14 additions and 13 deletions

View File

@@ -113,21 +113,23 @@ impl MetalDevice {
         self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
 
-    fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions, name: &str) -> Arc<Buffer> {
-        // println!("Creating new buffer {name}");
+    fn _new_buffer(
+        &self,
+        size: NSUInteger,
+        option: MTLResourceOptions,
+        _name: &str,
+    ) -> Arc<Buffer> {
         let mut buffers = self.buffers.try_write().unwrap();
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
 
         for sub in &mut *subbuffers {
             if Arc::strong_count(sub) == 1 {
-                // println!("Reusing tensor {size} {name}");
                 return sub.clone();
             }
         }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
         let new_buffer = Arc::new(new_buffer);
         subbuffers.push(new_buffer.clone());
-        // println!("Created tensor {size} {name}");
         for subbuffers in buffers.values_mut() {
             let newbuffers = subbuffers
                 .iter()
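For context, the reuse check above hands back a pooled buffer only when its Arc::strong_count is 1, i.e. the pool holds the sole reference and no tensor is still using that allocation. Below is a minimal standalone sketch of that pattern, using plain Vec<u8> blocks instead of Metal buffers; BufferPool and get_buffer are illustrative names, not candle's API.

use std::collections::HashMap;
use std::sync::Arc;

// Hypothetical pool keyed by allocation size; candle additionally keys on MTLResourceOptions.
struct BufferPool {
    buffers: HashMap<usize, Vec<Arc<Vec<u8>>>>,
}

impl BufferPool {
    fn new() -> Self {
        Self { buffers: HashMap::new() }
    }

    fn get_buffer(&mut self, size: usize) -> Arc<Vec<u8>> {
        let subbuffers = self.buffers.entry(size).or_insert_with(Vec::new);
        // A strong count of 1 means only the pool still holds this buffer,
        // so it is safe to hand it out again.
        for sub in subbuffers.iter() {
            if Arc::strong_count(sub) == 1 {
                return sub.clone();
            }
        }
        // Nothing free: allocate, remember it in the pool, and return it.
        let new_buffer = Arc::new(vec![0u8; size]);
        subbuffers.push(new_buffer.clone());
        new_buffer
    }
}

fn main() {
    let mut pool = BufferPool::new();
    let a = pool.get_buffer(1024);
    drop(a); // releasing the handle makes the buffer reusable
    let _b = pool.get_buffer(1024); // hands back the same allocation
}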

View File

@@ -67,7 +67,6 @@ kernel void NAME( \
         threadgroup_barrier(mem_flags::mem_none); \
     } \
     \
-    threadgroup_barrier(mem_flags::mem_none); \
     dst[dst_id] = shared_memory[0]; \
 } \
@@ -94,11 +93,10 @@ kernel void NAME(
     size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
     size_t idx = start_idx + tid; \
     \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
     \
-    float tmp = 0; \
+    float tmp = -INFINITY; \
     while (idx < stop_idx) { \
-        tmp = MAX(tmp, src[idx]); \
+        tmp = MAX(tmp, float(src[idx])); \
         idx += block_dim; \
     } \
     shared_memory[tid] = tmp; \
@@ -109,12 +107,15 @@ kernel void NAME(
         if (tid < s) { \
             shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
         } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
     } \
     \
+    /* wait for shared_memory[0] to be filled */ \
     threadgroup_barrier(mem_flags::mem_threadgroup); \
     \
     float _max = shared_memory[0]; \
     \
+    /* prevent tid=0 from overwriting _max before other threads have written it */ \
     threadgroup_barrier(mem_flags::mem_threadgroup); \
     shared_memory[tid] = 0; \
     \
@@ -125,10 +126,12 @@ kernel void NAME(
         shared_memory[tid] += val; \
         idx += block_dim; \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
     for (uint s = block_dim / 2; s > 0; s >>= 1) { \
         if (tid < s) { \
             shared_memory[tid] += shared_memory[tid + s]; \
         } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
     } \
     \
     const T inv_acc = T(1.0/shared_memory[0]); \
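The kernel changes above follow the usual numerically stable softmax recipe: reduce to the row maximum first (seeded with -INFINITY rather than 0, so all-negative rows are handled), then accumulate exp(x - max), then scale by the inverse of the sum; the extra threadgroup_barrier calls make every thread see the completed max/sum before the next phase starts. A scalar Rust sketch of the same three phases (single-threaded, so the barriers have no analogue here; softmax_last_dim is an illustrative name):

// Numerically stable softmax over one row, mirroring the kernel's three phases:
// (1) max reduction, (2) sum of exp(x - max), (3) normalization.
fn softmax_last_dim(row: &[f32]) -> Vec<f32> {
    // Phase 1: seed with -INFINITY (not 0.0) so negative-only rows reduce correctly.
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Phase 2: accumulate exp(x - max); subtracting the max keeps exp in range.
    let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Phase 3: multiply by the reciprocal of the sum (the kernel's inv_acc).
    let inv_acc = 1.0 / sum;
    exps.iter().map(|&e| e * inv_acc).collect()
}

fn main() {
    // With a 0-seeded max, a row of large negative values would give exp(x - 0) ~ 0
    // everywhere and a zero sum; the -INFINITY seed avoids that failure mode.
    let row = [-1000.0f32, -1001.0, -1002.0];
    println!("{:?}", softmax_last_dim(&row));
}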

View File

@@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         let n = layout.stride().len();
-        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
             candle::bail!("Non contiguous softmax-last-dim is not implemented");
         }
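The tightened guard above makes the Metal path bail unless the whole layout is contiguous, not only the last dimension. A small sketch of what such a row-major contiguity check boils down to (a standalone is_c_contiguous helper, not candle's Layout API):

// Returns true when `stride` describes a row-major (C-contiguous) layout of `shape`:
// the last dimension has stride 1 and each earlier stride equals the product of the
// dimensions after it. Size-1 dimensions are skipped since their stride is irrelevant.
fn is_c_contiguous(shape: &[usize], stride: &[usize]) -> bool {
    let mut expected = 1usize;
    for (&dim, &st) in shape.iter().zip(stride.iter()).rev() {
        if dim > 1 && st != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

fn main() {
    // A 2x3 row-major tensor has strides [3, 1].
    assert!(is_c_contiguous(&[2, 3], &[3, 1]));
    // Its transposed view has shape [3, 2] with strides [1, 3] and is not contiguous,
    // so a last-dim softmax kernel could not be launched on it directly.
    assert!(!is_c_contiguous(&[3, 2], &[1, 3]));
    println!("contiguity checks passed");
}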

View File

@@ -272,10 +272,6 @@ impl MHA {
     }
 
     fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
-        // let view = xs.to_string();
-        // if view.contains("NaN") {
-        //     panic!("NaN");
-        // }
         let _enter = self.span.enter();
         let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self