mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Optimize copy-2d for metal. (#2024)
* Optimize copy-2d for metal.
* Add a hacky stopping rule for moondream.
This commit is contained in:
@ -123,7 +123,7 @@ impl TextGeneration {
|
|||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
generated_tokens += 1;
|
generated_tokens += 1;
|
||||||
if next_token == eos_token {
|
if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||||
|
@ -40,6 +40,44 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
|
|||||||
(thread_group_count, thread_group_size)
|
(thread_group_count, thread_group_size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
|
||||||
|
/// Picks a power-of-two block size for a three-dimensional extent.
///
/// Each returned extent is `1 << pows`, where the exponent for a dimension is
/// grown one step at a time, cycling through `dim0`, `dim1`, `dim2`, for as
/// long as the dimension is at least twice the current extent
/// (`dim >= 1 << (pows + 1)`). Growth stops as soon as the exponents sum to
/// 10 (i.e. `width * height * depth == 1024`) or a full pass over the three
/// dimensions grows none of them.
///
/// A dimension of 0 or 1 always yields an extent of 1, and no returned extent
/// exceeds its dimension unless that dimension is 0.
///
/// NOTE(review): the cap of 10 presumably corresponds to a 1024
/// threads-per-threadgroup limit on the target devices; it is not checked
/// against the pipeline here — confirm against the callers.
fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
||||||
|
let mut pows0 = 0u64;
|
||||||
|
let mut pows1 = 0u64;
|
||||||
|
let mut pows2 = 0u64;
|
||||||
|
// Running total of pows0 + pows1 + pows2; every exponent bump also bumps this.
let mut sum = 0u64;
|
||||||
|
loop {
|
||||||
|
// Snapshot used at the end of the pass to detect that no exponent grew.
let presum = sum;
|
||||||
|
// Check all the pows
|
||||||
|
if dim0 >= (1 << (pows0 + 1)) {
|
||||||
|
pows0 += 1;
|
||||||
|
sum += 1;
|
||||||
|
}
|
||||||
|
// Stop immediately once the budget is used, before later dimensions can grow.
if sum == 10 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if dim1 >= (1 << (pows1 + 1)) {
|
||||||
|
pows1 += 1;
|
||||||
|
sum += 1;
|
||||||
|
}
|
||||||
|
if sum == 10 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if dim2 >= (1 << (pows2 + 1)) {
|
||||||
|
pows2 += 1;
|
||||||
|
sum += 1;
|
||||||
|
}
|
||||||
|
// Done when nothing grew during this pass or the exponent budget is exhausted.
if sum == presum || sum == 10 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Convert the exponents back into extents (each a power of two).
MTLSize {
|
||||||
|
width: 1 << pows0,
|
||||||
|
height: 1 << pows1,
|
||||||
|
depth: 1 << pows2,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
|
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
|
||||||
<P as EncoderParam>::set_param(encoder, position, data)
|
<P as EncoderParam>::set_param(encoder, position, data)
|
||||||
}
|
}
|
||||||
@ -396,21 +434,24 @@ pub fn call_copy2d(
|
|||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
(
|
(
|
||||||
d1,
|
d1 as i64,
|
||||||
d2,
|
d2 as i64,
|
||||||
src_s,
|
src_s as i64,
|
||||||
dst_s,
|
dst_s as i64,
|
||||||
(input, src_o_in_bytes),
|
(input, src_o_in_bytes),
|
||||||
(output, dst_o_in_bytes)
|
(output, dst_o_in_bytes)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
let width: usize = d1 * d2;
|
let grid_dims = MTLSize {
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
width: d1 as u64,
|
||||||
|
height: d2 as u64,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
let group_dims = get_block_dims(d1 as u64, d2 as u64, 1);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -104,21 +104,17 @@ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
|||||||
|
|
||||||
#define COPY2D(FN_NAME, TYPENAME) \
|
#define COPY2D(FN_NAME, TYPENAME) \
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
constant size_t &d1, \
|
constant int64_t &d1, \
|
||||||
constant size_t &d2, \
|
constant int64_t &d2, \
|
||||||
constant size_t &src_s, \
|
constant int64_t &src_s, \
|
||||||
constant size_t &dst_s, \
|
constant int64_t &dst_s, \
|
||||||
device const TYPENAME *input, \
|
device const TYPENAME *input, \
|
||||||
device TYPENAME *output, \
|
device TYPENAME *output, \
|
||||||
uint tid [[ thread_position_in_grid ]] \
|
uint2 idx [[thread_position_in_grid]] \
|
||||||
) { \
|
) { \
|
||||||
if (tid >= d1 * d2) { \
|
if (idx.x >= d1 || idx.y >= d2) return; \
|
||||||
return; \
|
int64_t src_idx = idx.x * src_s + idx.y; \
|
||||||
} \
|
int64_t dst_idx = idx.x * dst_s + idx.y; \
|
||||||
size_t idx1 = tid / d2; \
|
|
||||||
size_t idx2 = tid - idx1 * d2; \
|
|
||||||
size_t src_idx = idx1 * src_s + idx2; \
|
|
||||||
size_t dst_idx = idx1 * dst_s + idx2; \
|
|
||||||
output[dst_idx] = input[src_idx]; \
|
output[dst_idx] = input[src_idx]; \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user