Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00).
Speeding up copies using blit.
This commit is contained in:
@ -868,12 +868,21 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if src_l.is_contiguous(){
|
||||
command_buffer.set_label("copy_contiguous");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, self.buffer.length() - src_offset);
|
||||
blit.end_encoding();
|
||||
|
||||
}else{
|
||||
let src_shape = src_l.shape();
|
||||
let el_count = src_shape.elem_count();
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let command_buffer = self.device.command_buffer();
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
@ -894,7 +903,8 @@ impl BackendStorage for MetalStorage {
|
||||
dst_offset * dst.dtype.size_in_bytes(),
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.set_label("copy");
|
||||
command_buffer.set_label("copy_strided");
|
||||
}
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user