Fix softmax for long sequences (missing barrier).

This commit is contained in:
Nicolas Patry
2023-12-14 19:37:03 +01:00
parent f419a38e1a
commit 4eeaf205d6
3 changed files with 50 additions and 18 deletions

View File

@ -126,7 +126,7 @@ impl MetalDevice {
}
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
// subbuffers.push(new_buffer.clone());
subbuffers.push(new_buffer.clone());
// println!("Created tensor {size} {name}");
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers

View File

@ -32,7 +32,7 @@ kernel void NAME( \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
/* \
@ -67,6 +67,7 @@ kernel void NAME( \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
dst[dst_id] = shared_memory[0]; \
} \
@ -95,10 +96,12 @@ kernel void NAME(
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float tmp = 0; \
while (idx < stop_idx) { \
shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
tmp = MAX(tmp, src[idx]); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
@ -112,12 +115,13 @@ kernel void NAME(
\
float _max = shared_memory[0]; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const T val = T(exp(src[idx] - _max)); \
dst[idx] = val; \
const float val = exp(float(src[idx]) - _max); \
dst[idx] = T(val); \
shared_memory[tid] += val; \
idx += block_dim; \
} \
@ -125,10 +129,9 @@ kernel void NAME(
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1/shared_memory[0]); \
const T inv_acc = T(1.0/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \

View File

@ -37,7 +37,8 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -59,7 +60,8 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@ -94,7 +96,8 @@ fn run_strided<T: Clone>(
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_unary_strided(
&device,
command_buffer,
@ -247,7 +250,8 @@ fn binary_add_f32() {
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -294,7 +298,8 @@ fn cast_u32_f32() {
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@ -329,7 +334,8 @@ fn run_affine_strided<T: Clone>(
add: f64,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@ -457,7 +463,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
_ => unimplemented!(),
};
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select(
&device,
&command_buffer,
@ -559,7 +566,8 @@ fn cos_f16() {
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -586,7 +594,8 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -636,6 +645,24 @@ fn softmax() {
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
let last_dim = 4096;
let n = 200;
let mut v = vec![0.0; n * last_dim];
for i in 0..n {
v[i * last_dim] = 20.0;
}
let results = run_softmax(&v, last_dim, "softmax_float");
let results = approx(results, 4);
println!("{results:?}");
assert_eq!(
results.iter().map(|&s| s.round() as usize).sum::<usize>(),
n
);
assert_eq!(results[0], 1.0);
assert_eq!(results[1], 0.0);
assert_eq!(results[last_dim], 1.0);
assert_eq!(results[2 * last_dim], 1.0);
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float");
@ -686,7 +713,8 @@ fn run_where_cond<I: Clone, T: Clone>(
name: &'static str,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@ -762,7 +790,8 @@ fn run_gemm<T: Clone>(
rhs_offset: usize,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;