mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
BlitEncoder added to affine for copying buffer contents quickly.
This commit is contained in:
@ -90,33 +90,45 @@ impl BackendStorage for MetalStorage {
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
|
||||
/*
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
|
||||
// TODO: Don't load library every time
|
||||
let library = device.new_library_with_source(AFFINE, &CompileOptions::new()).unwrap();
|
||||
let library = device
|
||||
.new_library_with_source(AFFINE, &CompileOptions::new())
|
||||
.unwrap();
|
||||
let function = library.get_function("affine", None).unwrap();
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
|
||||
let output_size = el * self.dtype.size_in_bytes();
|
||||
let output_buffer = device.new_buffer(output_size, self.dtype);
|
||||
|
||||
let src_length = self.buffer.length() as usize - layout.start_offset();
|
||||
let src = self.device.new_buffer(src_length, self.dtype);
|
||||
let blit_encoder = self.device.command_buffer.new_blit_command_encoder();
|
||||
blit_encoder.copy_from_buffer(
|
||||
self.buffer.as_ref(),
|
||||
layout.start_offset() as NSUInteger,
|
||||
output_buffer.as_ref(),
|
||||
0,
|
||||
(src_length * self.dtype.size_in_bytes()) as NSUInteger,
|
||||
);
|
||||
blit_encoder.end_encoding();
|
||||
|
||||
let encoder = device.command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let output_size = el * self.dtype.size_in_bytes();
|
||||
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||
|
||||
let output_buffer = device.new_buffer(output_size, self.dtype);
|
||||
|
||||
encoder.set_bytes(0, 4, void_ptr(&el));
|
||||
encoder.set_bytes(1, 4, void_ptr(&dims));
|
||||
let info = [dims, layout.stride()].concat();
|
||||
let info_len = (info.len() * mem::size_of::<usize>()) as NSUInteger;
|
||||
encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast());
|
||||
|
||||
encoder.set_buffer(3, Some(&self.buffer), 0);
|
||||
encoder.set_buffer(3, Some(&src), 0);
|
||||
encoder.set_buffer(4, Some(&output_buffer), 0);
|
||||
|
||||
encoder.set_bytes(5, 4, void_ptr(&(mul as f32)));
|
||||
@ -136,7 +148,6 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
*/
|
||||
|
||||
Ok(self.clone())
|
||||
}
|
||||
|
Reference in New Issue
Block a user