mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
BlitEncoder added to affine for copying buffer contents quickly.
This commit is contained in:
@ -90,33 +90,45 @@ impl BackendStorage for MetalStorage {
|
|||||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
|
|
||||||
/*
|
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
|
|
||||||
// TODO: Don't load library every time
|
// TODO: Don't load library every time
|
||||||
let library = device.new_library_with_source(AFFINE, &CompileOptions::new()).unwrap();
|
let library = device
|
||||||
|
.new_library_with_source(AFFINE, &CompileOptions::new())
|
||||||
|
.unwrap();
|
||||||
let function = library.get_function("affine", None).unwrap();
|
let function = library.get_function("affine", None).unwrap();
|
||||||
let pipeline = device
|
let pipeline = device
|
||||||
.new_compute_pipeline_state_with_function(&function)
|
.new_compute_pipeline_state_with_function(&function)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
let output_size = el * self.dtype.size_in_bytes();
|
||||||
|
let output_buffer = device.new_buffer(output_size, self.dtype);
|
||||||
|
|
||||||
|
let src_length = self.buffer.length() as usize - layout.start_offset();
|
||||||
|
let src = self.device.new_buffer(src_length, self.dtype);
|
||||||
|
let blit_encoder = self.device.command_buffer.new_blit_command_encoder();
|
||||||
|
blit_encoder.copy_from_buffer(
|
||||||
|
self.buffer.as_ref(),
|
||||||
|
layout.start_offset() as NSUInteger,
|
||||||
|
output_buffer.as_ref(),
|
||||||
|
0,
|
||||||
|
(src_length * self.dtype.size_in_bytes()) as NSUInteger,
|
||||||
|
);
|
||||||
|
blit_encoder.end_encoding();
|
||||||
|
|
||||||
let encoder = device.command_buffer.new_compute_command_encoder();
|
let encoder = device.command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let output_size = el * self.dtype.size_in_bytes();
|
|
||||||
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||||
|
|
||||||
let output_buffer = device.new_buffer(output_size, self.dtype);
|
|
||||||
|
|
||||||
encoder.set_bytes(0, 4, void_ptr(&el));
|
encoder.set_bytes(0, 4, void_ptr(&el));
|
||||||
encoder.set_bytes(1, 4, void_ptr(&dims));
|
encoder.set_bytes(1, 4, void_ptr(&dims));
|
||||||
let info = [dims, layout.stride()].concat();
|
let info = [dims, layout.stride()].concat();
|
||||||
let info_len = (info.len() * mem::size_of::<usize>()) as NSUInteger;
|
let info_len = (info.len() * mem::size_of::<usize>()) as NSUInteger;
|
||||||
encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast());
|
encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast());
|
||||||
|
|
||||||
encoder.set_buffer(3, Some(&self.buffer), 0);
|
encoder.set_buffer(3, Some(&src), 0);
|
||||||
encoder.set_buffer(4, Some(&output_buffer), 0);
|
encoder.set_buffer(4, Some(&output_buffer), 0);
|
||||||
|
|
||||||
encoder.set_bytes(5, 4, void_ptr(&(mul as f32)));
|
encoder.set_bytes(5, 4, void_ptr(&(mul as f32)));
|
||||||
@ -136,7 +148,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
encoder.dispatch_threads(grid_size, thread_group_size);
|
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
*/
|
|
||||||
|
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user