Speeding up copies using blit.

This commit is contained in:
Nicolas Patry
2023-11-19 23:00:10 +01:00
parent 7052b9c884
commit c93a17694b

View File

@ -868,12 +868,21 @@ impl BackendStorage for MetalStorage {
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer();
if src_l.is_contiguous(){
command_buffer.set_label("copy_contiguous");
let blit = command_buffer.new_blit_command_encoder();
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, self.buffer.length() - src_offset);
blit.end_encoding();
}else{
let src_shape = src_l.shape();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
let command_buffer = self.device.command_buffer();
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
@ -894,7 +903,8 @@ impl BackendStorage for MetalStorage {
dst_offset * dst.dtype.size_in_bytes(),
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy");
command_buffer.set_label("copy_strided");
}
drop(command_buffer);
self.device.commit();
Ok(())