BlitEncoder added to affine for copying buffer contents quickly.

This commit is contained in:
Ivar Flakstad
2023-11-06 08:23:36 +01:00
parent 8124d1003f
commit 634a4e7168

View File

@ -90,33 +90,45 @@ impl BackendStorage for MetalStorage {
/// Affine transform entry point for this storage backend.
///
/// Presumably meant to compute `x * mul + add` element-wise over the elements
/// described by `layout` -- TODO confirm against the `affine` kernel source.
///
/// NOTE(review): the whole GPU implementation below is disabled inside a block
/// comment. As written, `layout`, `mul` and `add` are ignored and the method
/// returns a clone of `self` with its contents unchanged.
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
// Only the commented-out implementation below uses this handle; it is unused
// in the code that currently runs.
let device = self.device().clone();
/*
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
// TODO: Don't load library every time
let library = device.new_library_with_source(AFFINE, &CompileOptions::new()).unwrap();
let library = device
.new_library_with_source(AFFINE, &CompileOptions::new())
.unwrap();
let function = library.get_function("affine", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let output_size = el * self.dtype.size_in_bytes();
let output_buffer = device.new_buffer(output_size, self.dtype);
let src_length = self.buffer.length() as usize - layout.start_offset();
let src = self.device.new_buffer(src_length, self.dtype);
let blit_encoder = self.device.command_buffer.new_blit_command_encoder();
blit_encoder.copy_from_buffer(
self.buffer.as_ref(),
layout.start_offset() as NSUInteger,
output_buffer.as_ref(),
0,
(src_length * self.dtype.size_in_bytes()) as NSUInteger,
);
blit_encoder.end_encoding();
let encoder = device.command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let output_size = el * self.dtype.size_in_bytes();
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
let output_buffer = device.new_buffer(output_size, self.dtype);
encoder.set_bytes(0, 4, void_ptr(&el));
encoder.set_bytes(1, 4, void_ptr(&dims));
let info = [dims, layout.stride()].concat();
let info_len = (info.len() * mem::size_of::<usize>()) as NSUInteger;
encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast());
encoder.set_buffer(3, Some(&self.buffer), 0);
encoder.set_buffer(3, Some(&src), 0);
encoder.set_buffer(4, Some(&output_buffer), 0);
encoder.set_bytes(5, 4, void_ptr(&(mul as f32)));
@ -136,7 +148,6 @@ impl BackendStorage for MetalStorage {
encoder.dispatch_threads(grid_size, thread_group_size);
encoder.end_encoding();
*/
// Placeholder result: hands back a clone of the input storage; no affine
// computation has been applied to it.
Ok(self.clone())
}