mirror of https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Fixing softmax.
@@ -113,21 +113,23 @@ impl MetalDevice {
         self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
 
-    fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions, name: &str) -> Arc<Buffer> {
-        // println!("Creating new buffer {name}");
+    fn _new_buffer(
+        &self,
+        size: NSUInteger,
+        option: MTLResourceOptions,
+        _name: &str,
+    ) -> Arc<Buffer> {
         let mut buffers = self.buffers.try_write().unwrap();
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
 
         for sub in &mut *subbuffers {
             if Arc::strong_count(sub) == 1 {
-                // println!("Reusing tensor {size} {name}");
                 return sub.clone();
             }
         }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
         let new_buffer = Arc::new(new_buffer);
         subbuffers.push(new_buffer.clone());
-        // println!("Created tensor {size} {name}");
         for subbuffers in buffers.values_mut() {
             let newbuffers = subbuffers
                 .iter()
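The heart of the rewritten `_new_buffer` is the `Arc::strong_count(sub) == 1` test: a strong count of 1 means the pool's own `Vec` holds the only reference, so no tensor or in-flight kernel can still be using the buffer and it is safe to hand out again. A minimal CPU sketch of the same pooling idea, with a plain `Vec<u8>` standing in for a Metal buffer (the `Pool` type here is hypothetical, not candle's API):

    use std::collections::HashMap;
    use std::sync::Arc;

    struct Pool {
        // keyed by size, like the (size, option) key above
        buffers: HashMap<usize, Vec<Arc<Vec<u8>>>>,
    }

    impl Pool {
        fn get(&mut self, size: usize) -> Arc<Vec<u8>> {
            let subbuffers = self.buffers.entry(size).or_insert(vec![]);
            for sub in subbuffers.iter() {
                // strong_count == 1: only the pool still holds this buffer,
                // so nothing can be reading or writing it -> reuse it
                if Arc::strong_count(sub) == 1 {
                    return sub.clone();
                }
            }
            // no free buffer of this size: allocate and remember it
            let new_buffer = Arc::new(vec![0u8; size]);
            subbuffers.push(new_buffer.clone());
            new_buffer
        }
    }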
@@ -67,7 +67,6 @@ kernel void NAME( \
         threadgroup_barrier(mem_flags::mem_none); \
     } \
     \
-    threadgroup_barrier(mem_flags::mem_none); \
     dst[dst_id] = shared_memory[0]; \
 } \
 
@@ -94,11 +93,10 @@ kernel void NAME(
     size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
     size_t idx = start_idx + tid; \
     \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
     \
-    float tmp = 0; \
+    float tmp = -INFINITY; \
     while (idx < stop_idx) { \
-        tmp = MAX(tmp, src[idx]); \
+        tmp = MAX(tmp, float(src[idx])); \
         idx += block_dim; \
     } \
     shared_memory[tid] = tmp; \
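The substantive fix in this hunk is `float tmp = -INFINITY;`: the running maximum must start at the identity element of MAX. Seeded with 0, any all-negative row (say [-3, -1, -2]) would report a maximum of 0, so the subsequent exp(x - max) would shrink every term instead of centering the row on its true peak. The added `float(src[idx])` cast also promotes half-precision inputs to f32 before comparing. A CPU sketch of the numerically stable softmax the kernel computes:

    fn softmax(row: &[f32]) -> Vec<f32> {
        // start the running max at -INFINITY, the identity for max,
        // so all-negative rows are handled correctly
        let max = row.iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x));
        let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|&e| e / sum).collect()
    }

    fn main() {
        // with the max wrongly seeded to 0, this row would be normalized
        // against exp(x - 0), i.e. three heavily shrunken terms
        println!("{:?}", softmax(&[-3.0f32, -1.0, -2.0])); // sums to 1.0
    }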
@@ -109,12 +107,15 @@ kernel void NAME(
         if (tid < s) { \
             shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
         } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
     } \
     \
+    /* wait for shared_memory[0] to be filled */ \
     threadgroup_barrier(mem_flags::mem_threadgroup); \
     \
     float _max = shared_memory[0]; \
     \
+    /* prevent tid=0 from overwriting _max before other threads have written it */ \
     threadgroup_barrier(mem_flags::mem_threadgroup); \
     shared_memory[tid] = 0; \
     \
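The barrier added inside the loop is what makes the tree reduction correct: each halving round reads the slots the previous round wrote, so every round must end with a fence, and further fences are needed around the point where all threads read `shared_memory[0]`. A sequential Rust sketch of the same reduction, with comments marking where the GPU barriers land; the next hunk applies identical reasoning to the sum reduction:

    // assumes shared.len() is a power of two, as block_dim is on the GPU
    fn tree_reduce_max(shared: &mut [f32]) -> f32 {
        let mut s = shared.len() / 2;
        while s > 0 {
            for tid in 0..s {
                // on the GPU all tids run concurrently
                shared[tid] = shared[tid].max(shared[tid + s]);
            }
            // barrier here: round s/2 reads slots that round s just wrote
            s /= 2;
        }
        shared[0] // barriers around this read keep tid 0 from racing ahead
    }

    fn main() {
        let mut shared = vec![-3.0, -1.5, -7.0, -2.0, -0.5, -4.0, -6.0, -8.0];
        assert_eq!(tree_reduce_max(&mut shared), -0.5);
    }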
@@ -125,10 +126,12 @@ kernel void NAME(
         shared_memory[tid] += val; \
         idx += block_dim; \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
     for (uint s = block_dim / 2; s > 0; s >>= 1) { \
         if (tid < s) { \
             shared_memory[tid] += shared_memory[tid + s]; \
         } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
     } \
     \
     const T inv_acc = T(1.0/shared_memory[0]); \
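In this hunk's unchanged tail, `inv_acc` is the reciprocal of the accumulated sum, computed once so the per-element write can multiply instead of divide. The same step on the CPU, as a sketch:

    // exps holds exp(x - max) per element; scale by 1/sum computed once
    fn normalize(exps: &mut [f32]) {
        let inv_acc = 1.0f32 / exps.iter().sum::<f32>();
        for e in exps.iter_mut() {
            *e *= inv_acc;
        }
    }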
@@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
 
         let n = layout.stride().len();
-        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
             candle::bail!("Non contiguous softmax-last-dim is not implemented");
         }
 
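The strengthened guard matters because `stride[n - 1] == 1 && start_offset == 0` does not by itself imply contiguity: a view narrowed along its last dimension keeps unit stride and zero offset while leaving gaps between rows. A minimal sketch of a row-major contiguity check (assumed to mirror what `Layout::is_contiguous` verifies; not candle's exact code):

    fn is_contiguous(shape: &[usize], stride: &[usize]) -> bool {
        // row-major: the expected stride grows right-to-left by each dim size
        let mut expected = 1;
        for (&dim, &st) in shape.iter().zip(stride.iter()).rev() {
            if st != expected {
                return false;
            }
            expected *= dim;
        }
        true
    }

    fn main() {
        // a (2, 3) row-major tensor has strides [3, 1]: contiguous
        assert!(is_contiguous(&[2, 3], &[3, 1]));
        // narrowing the last dim to (2, 2) keeps stride [3, 1] and offset 0,
        // so the old check passed, yet a gap now sits between the rows
        assert!(!is_contiguous(&[2, 2], &[3, 1]));
    }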
@@ -272,10 +272,6 @@ impl MHA {
     }
 
     fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
-        // let view = xs.to_string();
-        // if view.contains("NaN") {
-        //     panic!("NaN");
-        // }
         let _enter = self.span.enter();
         let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self