mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Removing the fences speeds everything up and *is* correct this time...
This commit is contained in:
@ -88,7 +88,7 @@ pub struct MetalDevice {
|
|||||||
/// execution order to be linear.
|
/// execution order to be linear.
|
||||||
/// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
|
/// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
|
||||||
/// compute graph.
|
/// compute graph.
|
||||||
fence: metal::Fence,
|
// fence: metal::Fence,
|
||||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||||
/// Heavily used by [`candle_metal_kernels`], both fences need to match
|
/// Heavily used by [`candle_metal_kernels`], both fences need to match
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
@ -131,9 +131,9 @@ impl MetalDevice {
|
|||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn fence(&self) -> &metal::Fence {
|
// pub(crate) fn fence(&self) -> &metal::Fence {
|
||||||
&self.fence
|
// &self.fence
|
||||||
}
|
// }
|
||||||
|
|
||||||
pub fn command_queue(&self) -> &CommandQueue {
|
pub fn command_queue(&self) -> &CommandQueue {
|
||||||
&self.command_queue
|
&self.command_queue
|
||||||
@ -225,10 +225,10 @@ impl MetalDevice {
|
|||||||
let command_buffer = self.command_buffer()?;
|
let command_buffer = self.command_buffer()?;
|
||||||
command_buffer.set_label("with_data");
|
command_buffer.set_label("with_data");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
blit.wait_for_fence(&self.fence);
|
// blit.wait_for_fence(&self.fence);
|
||||||
blit.set_label("with_data_blit");
|
blit.set_label("with_data_blit");
|
||||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||||
blit.update_fence(&self.fence);
|
// blit.update_fence(&self.fence);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
|
|
||||||
// This is necessary, for mmaped safetensors
|
// This is necessary, for mmaped safetensors
|
||||||
@ -251,7 +251,7 @@ impl MetalDevice {
|
|||||||
let command_buffer = self.command_buffer()?;
|
let command_buffer = self.command_buffer()?;
|
||||||
command_buffer.set_label("zeros");
|
command_buffer.set_label("zeros");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
blit.wait_for_fence(&self.fence);
|
// blit.wait_for_fence(&self.fence);
|
||||||
blit.fill_buffer(
|
blit.fill_buffer(
|
||||||
&buffer,
|
&buffer,
|
||||||
metal::NSRange {
|
metal::NSRange {
|
||||||
@ -260,7 +260,7 @@ impl MetalDevice {
|
|||||||
},
|
},
|
||||||
0,
|
0,
|
||||||
);
|
);
|
||||||
blit.update_fence(&self.fence);
|
// blit.update_fence(&self.fence);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
Ok(buffer)
|
Ok(buffer)
|
||||||
}
|
}
|
||||||
@ -1543,9 +1543,9 @@ impl MetalStorage {
|
|||||||
command_buffer.set_label("to_cpu");
|
command_buffer.set_label("to_cpu");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
blit.set_label("blit_to_cpu");
|
blit.set_label("blit_to_cpu");
|
||||||
blit.wait_for_fence(&self.device.fence);
|
// blit.wait_for_fence(&self.device.fence);
|
||||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||||
blit.update_fence(&self.device.fence);
|
// blit.update_fence(&self.device.fence);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
}
|
}
|
||||||
self.device.wait_until_completed()?;
|
self.device.wait_until_completed()?;
|
||||||
@ -1563,16 +1563,16 @@ impl BackendDevice for MetalDevice {
|
|||||||
command_buffer.enqueue();
|
command_buffer.enqueue();
|
||||||
let command_buffer = Arc::new(RwLock::new(command_buffer));
|
let command_buffer = Arc::new(RwLock::new(command_buffer));
|
||||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
let command_buffer_index = Arc::new(RwLock::new(0));
|
||||||
let fence = device.new_fence();
|
// let fence = device.new_fence();
|
||||||
let kernels = Arc::new(Kernels::new(fence.clone()));
|
let kernels = Arc::new(Kernels::new());
|
||||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||||
Ok(val) => val.parse()?,
|
Ok(val) => val.parse()?,
|
||||||
_ => 20,
|
_ => 10,
|
||||||
};
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
fence,
|
// fence,
|
||||||
command_queue,
|
command_queue,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
command_buffer_index,
|
command_buffer_index,
|
||||||
|
@ -32,9 +32,9 @@ impl QMetalStorage {
|
|||||||
command_buffer.set_label("to_cpu");
|
command_buffer.set_label("to_cpu");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
blit.set_label("blit_to_cpu");
|
blit.set_label("blit_to_cpu");
|
||||||
blit.wait_for_fence(&self.device.fence());
|
// blit.wait_for_fence(&self.device.fence());
|
||||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||||
blit.update_fence(&self.device.fence());
|
// blit.update_fence(&self.device.fence());
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
self.device.wait_until_completed()?;
|
self.device.wait_until_completed()?;
|
||||||
let mut out = vec![0.0; elem_count];
|
let mut out = vec![0.0; elem_count];
|
||||||
|
@ -250,7 +250,7 @@ fn main() -> Result<()> {
|
|||||||
let vb =
|
let vb =
|
||||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||||
let model = QMistral::new(&config, vb)?;
|
let model = QMistral::new(&config, vb)?;
|
||||||
(Model::Quantized(model), Device::Cpu)
|
(Model::Quantized(model), device)
|
||||||
} else {
|
} else {
|
||||||
let dtype = if device.is_cuda() {
|
let dtype = if device.is_cuda() {
|
||||||
DType::BF16
|
DType::BF16
|
||||||
|
@ -219,17 +219,17 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
|
|||||||
pub struct Kernels {
|
pub struct Kernels {
|
||||||
libraries: RwLock<Libraries>,
|
libraries: RwLock<Libraries>,
|
||||||
pipelines: RwLock<Pipelines>,
|
pipelines: RwLock<Pipelines>,
|
||||||
fence: metal::Fence,
|
// fence: metal::Fence,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Kernels {
|
impl Kernels {
|
||||||
pub fn new(fence: metal::Fence) -> Self {
|
pub fn new() -> Self {
|
||||||
let libraries = RwLock::new(Libraries::new());
|
let libraries = RwLock::new(Libraries::new());
|
||||||
let pipelines = RwLock::new(Pipelines::new());
|
let pipelines = RwLock::new(Pipelines::new());
|
||||||
Self {
|
Self {
|
||||||
libraries,
|
libraries,
|
||||||
pipelines,
|
pipelines,
|
||||||
fence,
|
// fence,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -350,7 +350,7 @@ pub fn call_unary_contiguous(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, input, output));
|
set_params!(encoder, (length, input, output));
|
||||||
@ -359,7 +359,7 @@ pub fn call_unary_contiguous(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -381,7 +381,7 @@ pub fn call_unary_strided(
|
|||||||
|
|
||||||
let num_dims: usize = shape.len();
|
let num_dims: usize = shape.len();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
@ -403,7 +403,7 @@ pub fn call_unary_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -422,7 +422,7 @@ pub fn call_binary_contiguous(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, left, right, output));
|
set_params!(encoder, (length, left, right, output));
|
||||||
@ -433,7 +433,7 @@ pub fn call_binary_contiguous(
|
|||||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -458,7 +458,7 @@ pub fn call_binary_strided(
|
|||||||
let num_dims: usize = shape.len();
|
let num_dims: usize = shape.len();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
let width: usize = shape.iter().product();
|
let width: usize = shape.iter().product();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
@ -483,7 +483,7 @@ pub fn call_binary_strided(
|
|||||||
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -502,7 +502,7 @@ pub fn call_cast_contiguous(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, (input, input_offset), output));
|
set_params!(encoder, (length, (input, input_offset), output));
|
||||||
@ -511,7 +511,7 @@ pub fn call_cast_contiguous(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -531,7 +531,7 @@ pub fn call_cast_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
@ -553,7 +553,7 @@ pub fn call_cast_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -573,7 +573,7 @@ pub fn call_reduce_contiguous(
|
|||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -602,7 +602,7 @@ pub fn call_reduce_contiguous(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -624,7 +624,7 @@ pub fn call_reduce_strided(
|
|||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -660,7 +660,7 @@ pub fn call_reduce_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -679,7 +679,7 @@ pub fn call_last_softmax(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -710,7 +710,7 @@ pub fn call_last_softmax(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -730,7 +730,7 @@ pub fn call_affine(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, add, input, output));
|
set_params!(encoder, (size, mul, add, input, output));
|
||||||
@ -739,7 +739,7 @@ pub fn call_affine(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -762,7 +762,7 @@ pub fn call_affine_strided(
|
|||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -783,7 +783,7 @@ pub fn call_affine_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -802,7 +802,7 @@ pub fn call_powf(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, input, output));
|
set_params!(encoder, (size, mul, input, output));
|
||||||
@ -811,7 +811,7 @@ pub fn call_powf(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -833,7 +833,7 @@ pub fn call_powf_strided(
|
|||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -853,7 +853,7 @@ pub fn call_powf_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -872,7 +872,7 @@ pub fn call_elu(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, input, output));
|
set_params!(encoder, (size, mul, input, output));
|
||||||
@ -881,7 +881,7 @@ pub fn call_elu(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -903,7 +903,7 @@ pub fn call_elu_strided(
|
|||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -923,7 +923,7 @@ pub fn call_elu_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -945,7 +945,7 @@ pub fn call_where_cond_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
@ -974,7 +974,7 @@ pub fn call_where_cond_strided(
|
|||||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1001,7 +1001,7 @@ pub fn call_index_select(
|
|||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1024,7 +1024,7 @@ pub fn call_index_select(
|
|||||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1053,7 +1053,7 @@ pub fn call_gather(
|
|||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1076,7 +1076,7 @@ pub fn call_gather(
|
|||||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1105,7 +1105,7 @@ pub fn call_scatter_add(
|
|||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1128,7 +1128,7 @@ pub fn call_scatter_add(
|
|||||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1158,7 +1158,7 @@ pub fn call_index_add(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1182,7 +1182,7 @@ pub fn call_index_add(
|
|||||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1386,7 +1386,7 @@ pub fn call_gemm(
|
|||||||
let block_bytes = block_elements * bytes;
|
let block_bytes = block_elements * bytes;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||||
@ -1430,7 +1430,7 @@ pub fn call_gemm(
|
|||||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1455,7 +1455,7 @@ pub fn call_im2col1d_strided(
|
|||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
@ -1475,7 +1475,7 @@ pub fn call_im2col1d_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1505,7 +1505,7 @@ pub fn call_im2col_strided(
|
|||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
@ -1527,7 +1527,7 @@ pub fn call_im2col_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1553,7 +1553,7 @@ pub fn call_upsample_nearest_2d(
|
|||||||
let scale_h = shape[3] as f32 / out_h as f32;
|
let scale_h = shape[3] as f32 / out_h as f32;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
@ -1571,7 +1571,7 @@ pub fn call_upsample_nearest_2d(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1710,10 +1710,9 @@ pub fn call_quantized_matmul_t(
|
|||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
//encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
// println!("{b} {m} {n} {k}");
|
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
(
|
(
|
||||||
@ -1744,7 +1743,7 @@ pub fn call_quantized_matmul_t(
|
|||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
||||||
encoder.update_fence(&kernels.fence);
|
//encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Reference in New Issue
Block a user