mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
1 Commits
0.9.0
...
copy-multi
Author | SHA1 | Date | |
---|---|---|---|
09fafcfa99 |
@ -406,7 +406,7 @@ pub fn call_copy2d(
|
|||||||
);
|
);
|
||||||
|
|
||||||
let width: usize = d1 * d2;
|
let width: usize = d1 * d2;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width / 4);
|
||||||
|
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
|
@ -112,6 +112,7 @@ kernel void FN_NAME( \
|
|||||||
device TYPENAME *output, \
|
device TYPENAME *output, \
|
||||||
uint tid [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
|
tid *= 4; \
|
||||||
if (tid >= d1 * d2) { \
|
if (tid >= d1 * d2) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
@ -120,6 +121,9 @@ kernel void FN_NAME( \
|
|||||||
size_t src_idx = idx1 * src_s + idx2; \
|
size_t src_idx = idx1 * src_s + idx2; \
|
||||||
size_t dst_idx = idx1 * dst_s + idx2; \
|
size_t dst_idx = idx1 * dst_s + idx2; \
|
||||||
output[dst_idx] = input[src_idx]; \
|
output[dst_idx] = input[src_idx]; \
|
||||||
|
output[dst_idx+1] = input[src_idx+1]; \
|
||||||
|
output[dst_idx+2] = input[src_idx+2]; \
|
||||||
|
output[dst_idx+3] = input[src_idx+3]; \
|
||||||
}
|
}
|
||||||
|
|
||||||
COPY2D(copy2d_f32, float)
|
COPY2D(copy2d_f32, float)
|
||||||
|
Reference in New Issue
Block a user