Removing the fences speeds everything up and *is* correct this time...

This commit is contained in:
Nicolas Patry
2024-01-05 19:26:30 +01:00
parent 7b4389099a
commit 9130b6c4b6
4 changed files with 70 additions and 71 deletions

View File

@ -88,7 +88,7 @@ pub struct MetalDevice {
/// execution order to be linear. /// execution order to be linear.
/// It could be relaxed in some circumstances by managing the dependencies ourselves in the /// It could be relaxed in some circumstances by managing the dependencies ourselves in the
/// compute graph. /// compute graph.
fence: metal::Fence, // fence: metal::Fence,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`], both fences need to match /// Heavily used by [`candle_metal_kernels`], both fences need to match
kernels: Arc<candle_metal_kernels::Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
@ -131,9 +131,9 @@ impl MetalDevice {
&self.device &self.device
} }
pub(crate) fn fence(&self) -> &metal::Fence { // pub(crate) fn fence(&self) -> &metal::Fence {
&self.fence // &self.fence
} // }
pub fn command_queue(&self) -> &CommandQueue { pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue &self.command_queue
@ -225,10 +225,10 @@ impl MetalDevice {
let command_buffer = self.command_buffer()?; let command_buffer = self.command_buffer()?;
command_buffer.set_label("with_data"); command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence); // blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit"); blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence); // blit.update_fence(&self.fence);
blit.end_encoding(); blit.end_encoding();
// This is necessary, for mmaped safetensors // This is necessary, for mmaped safetensors
@ -251,7 +251,7 @@ impl MetalDevice {
let command_buffer = self.command_buffer()?; let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros"); command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence); // blit.wait_for_fence(&self.fence);
blit.fill_buffer( blit.fill_buffer(
&buffer, &buffer,
metal::NSRange { metal::NSRange {
@ -260,7 +260,7 @@ impl MetalDevice {
}, },
0, 0,
); );
blit.update_fence(&self.fence); // blit.update_fence(&self.fence);
blit.end_encoding(); blit.end_encoding();
Ok(buffer) Ok(buffer)
} }
@ -1486,9 +1486,9 @@ impl MetalStorage {
command_buffer.set_label("to_cpu"); command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu"); blit.set_label("blit_to_cpu");
blit.wait_for_fence(&self.device.fence); // blit.wait_for_fence(&self.device.fence);
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.update_fence(&self.device.fence); // blit.update_fence(&self.device.fence);
blit.end_encoding(); blit.end_encoding();
} }
self.device.wait_until_completed()?; self.device.wait_until_completed()?;
@ -1506,16 +1506,16 @@ impl BackendDevice for MetalDevice {
command_buffer.enqueue(); command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer)); let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0)); let command_buffer_index = Arc::new(RwLock::new(0));
let fence = device.new_fence(); // let fence = device.new_fence();
let kernels = Arc::new(Kernels::new(fence.clone())); let kernels = Arc::new(Kernels::new());
let buffers = Arc::new(RwLock::new(HashMap::new())); let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?, Ok(val) => val.parse()?,
_ => 20, _ => 10,
}; };
Ok(Self { Ok(Self {
device, device,
fence, // fence,
command_queue, command_queue,
command_buffer, command_buffer,
command_buffer_index, command_buffer_index,

View File

@ -32,9 +32,9 @@ impl QMetalStorage {
command_buffer.set_label("to_cpu"); command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu"); blit.set_label("blit_to_cpu");
blit.wait_for_fence(&self.device.fence()); // blit.wait_for_fence(&self.device.fence());
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.update_fence(&self.device.fence()); // blit.update_fence(&self.device.fence());
blit.end_encoding(); blit.end_encoding();
self.device.wait_until_completed()?; self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count]; let mut out = vec![0.0; elem_count];

View File

@ -250,7 +250,7 @@ fn main() -> Result<()> {
let vb = let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QMistral::new(&config, vb)?; let model = QMistral::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu) (Model::Quantized(model), device)
} else { } else {
let dtype = if device.is_cuda() { let dtype = if device.is_cuda() {
DType::BF16 DType::BF16

View File

@ -219,17 +219,17 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
pub struct Kernels { pub struct Kernels {
libraries: RwLock<Libraries>, libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>, pipelines: RwLock<Pipelines>,
fence: metal::Fence, // fence: metal::Fence,
} }
impl Kernels { impl Kernels {
pub fn new(fence: metal::Fence) -> Self { pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new()); let libraries = RwLock::new(Libraries::new());
let pipelines = RwLock::new(Pipelines::new()); let pipelines = RwLock::new(Pipelines::new());
Self { Self {
libraries, libraries,
pipelines, pipelines,
fence, // fence,
} }
} }
@ -350,7 +350,7 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output)); set_params!(encoder, (length, input, output));
@ -359,7 +359,7 @@ pub fn call_unary_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -381,7 +381,7 @@ pub fn call_unary_strided(
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -403,7 +403,7 @@ pub fn call_unary_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -422,7 +422,7 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output)); set_params!(encoder, (length, left, right, output));
@ -433,7 +433,7 @@ pub fn call_binary_contiguous(
encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -458,7 +458,7 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product(); let width: usize = shape.iter().product();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -483,7 +483,7 @@ pub fn call_binary_strided(
encoder.use_resource(right_input, metal::MTLResourceUsage::Read); encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -502,7 +502,7 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output)); set_params!(encoder, (length, (input, input_offset), output));
@ -511,7 +511,7 @@ pub fn call_cast_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -531,7 +531,7 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -553,7 +553,7 @@ pub fn call_cast_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -573,7 +573,7 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -602,7 +602,7 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -624,7 +624,7 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -660,7 +660,7 @@ pub fn call_reduce_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -679,7 +679,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -710,7 +710,7 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -730,7 +730,7 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output)); set_params!(encoder, (size, mul, add, input, output));
@ -739,7 +739,7 @@ pub fn call_affine(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -762,7 +762,7 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -783,7 +783,7 @@ pub fn call_affine_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -802,7 +802,7 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output)); set_params!(encoder, (size, mul, input, output));
@ -811,7 +811,7 @@ pub fn call_powf(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -833,7 +833,7 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -853,7 +853,7 @@ pub fn call_powf_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -872,7 +872,7 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output)); set_params!(encoder, (size, mul, input, output));
@ -881,7 +881,7 @@ pub fn call_elu(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -903,7 +903,7 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -923,7 +923,7 @@ pub fn call_elu_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -945,7 +945,7 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
@ -974,7 +974,7 @@ pub fn call_where_cond_strided(
encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1001,7 +1001,7 @@ pub fn call_index_select(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1024,7 +1024,7 @@ pub fn call_index_select(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1053,7 +1053,7 @@ pub fn call_gather(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1076,7 +1076,7 @@ pub fn call_gather(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1105,7 +1105,7 @@ pub fn call_scatter_add(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1128,7 +1128,7 @@ pub fn call_scatter_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1158,7 +1158,7 @@ pub fn call_index_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1182,7 +1182,7 @@ pub fn call_index_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1386,7 +1386,7 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes; let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@ -1430,7 +1430,7 @@ pub fn call_gemm(
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size); encoder.dispatch_thread_groups(grid_size, group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1455,7 +1455,7 @@ pub fn call_im2col1d_strided(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1475,7 +1475,7 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1505,7 +1505,7 @@ pub fn call_im2col_strided(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1527,7 +1527,7 @@ pub fn call_im2col_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1553,7 +1553,7 @@ pub fn call_upsample_nearest_2d(
let scale_h = shape[3] as f32 / out_h as f32; let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1571,7 +1571,7 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1710,10 +1710,9 @@ pub fn call_quantized_matmul_t(
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); //encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
// println!("{b} {m} {n} {k}");
set_params!( set_params!(
encoder, encoder,
( (
@ -1744,7 +1743,7 @@ pub fn call_quantized_matmul_t(
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.update_fence(&kernels.fence); //encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())