mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 12:28:06 +00:00
Fixing softmax.
This commit is contained in:
@ -67,7 +67,6 @@ kernel void NAME( \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
@ -94,11 +93,10 @@ kernel void NAME(
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float tmp = 0; \
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, src[idx]); \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
@ -109,12 +107,15 @@ kernel void NAME(
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
@ -125,10 +126,12 @@ kernel void NAME(
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
|
Reference in New Issue
Block a user