mirror of https://github.com/huggingface/candle.git
Cleanup the fence.
@@ -84,13 +84,8 @@ pub struct MetalDevice {
     command_buffer_index: Arc<RwLock<usize>>,
     /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
     compute_per_buffer: usize,
-    /// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
-    /// execution order to be linear.
-    /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
-    /// compute graph.
-    // fence: metal::Fence,
     /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
-    /// Heavily used by [`candle_metal_kernels`], both fences need to match
+    /// Heavily used by [`candle_metal_kernels`]
     kernels: Arc<candle_metal_kernels::Kernels>,
     /// Simple allocator struct.
     /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
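Note: the doc comment removed above described the scheme this commit retires: every compute and blit encoder waited on a shared `metal::Fence` and then updated it, forcing linear execution. The following is a minimal sketch of that pattern, reconstructed only from the calls visible elsewhere in this diff (`device.new_fence()`, `wait_for_fence`, `update_fence`); it is illustrative, not candle code.

```rust
// Sketch of the fence-serialized encoding pattern this commit removes.
use metal::{Device, MTLResourceOptions};

fn fence_serialized_copy(device: &Device) {
    let fence = device.new_fence();
    let queue = device.new_command_queue();
    let command_buffer = queue.new_command_buffer();

    let src = device.new_buffer(1024, MTLResourceOptions::StorageModeManaged);
    let dst = device.new_buffer(1024, MTLResourceOptions::StorageModeManaged);

    let blit = command_buffer.new_blit_command_encoder();
    // Wait for previously fenced work, then signal the fence when done:
    // this is what forced the linear execution order described in the old comment.
    blit.wait_for_fence(&fence);
    blit.copy_from_buffer(&src, 0, &dst, 0, src.length());
    blit.update_fence(&fence);
    blit.end_encoding();

    command_buffer.commit();
    command_buffer.wait_until_completed();
}
```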
@@ -131,10 +126,6 @@ impl MetalDevice {
         &self.device
     }
 
-    // pub(crate) fn fence(&self) -> &metal::Fence {
-    //     &self.fence
-    // }
-
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
     }
@@ -225,10 +216,8 @@ impl MetalDevice {
         let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
-        // blit.wait_for_fence(&self.fence);
         blit.set_label("with_data_blit");
         blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-        // blit.update_fence(&self.fence);
         blit.end_encoding();
 
         // This is necessary, for mmaped safetensors
@@ -251,7 +240,6 @@ impl MetalDevice {
         let command_buffer = self.command_buffer()?;
         command_buffer.set_label("zeros");
         let blit = command_buffer.new_blit_command_encoder();
-        // blit.wait_for_fence(&self.fence);
         blit.fill_buffer(
             &buffer,
             metal::NSRange {
@@ -260,7 +248,6 @@ impl MetalDevice {
             },
             0,
         );
-        // blit.update_fence(&self.fence);
         blit.end_encoding();
         Ok(buffer)
     }
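Note: the `with_data` hunk above stages host data through a temporary buffer (`tmp`) into the destination buffer (`real`) with a labelled blit copy; the surviving comment explains this is needed for mmaped safetensors, so the GPU buffer does not alias the memory mapping. A hedged, self-contained sketch of that staging idea follows; `stage_bytes` is a hypothetical helper, the buffer names are taken from the diff, and the storage modes are assumptions.

```rust
// Staging-copy sketch: write host bytes into a CPU-visible buffer, then blit
// them into a separate destination buffer inside one labelled command buffer.
// No fence is needed: work inside a single command buffer already runs in order.
use metal::{Device, MTLResourceOptions};

fn stage_bytes(device: &Device, data: &[u8]) -> metal::Buffer {
    let queue = device.new_command_queue();
    let command_buffer = queue.new_command_buffer();
    command_buffer.set_label("with_data");

    // CPU-visible staging buffer holding a copy of `data`.
    let tmp = device.new_buffer_with_data(
        data.as_ptr() as *const core::ffi::c_void,
        data.len() as u64,
        MTLResourceOptions::StorageModeManaged,
    );
    // Destination buffer that outlives the mmaped source.
    let real = device.new_buffer(data.len() as u64, MTLResourceOptions::StorageModeManaged);

    let blit = command_buffer.new_blit_command_encoder();
    blit.set_label("with_data_blit");
    blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
    blit.end_encoding();

    command_buffer.commit();
    command_buffer.wait_until_completed();
    real
}
```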
@@ -1543,9 +1530,7 @@ impl MetalStorage {
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
-            // blit.wait_for_fence(&self.device.fence);
             blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-            // blit.update_fence(&self.device.fence);
             blit.end_encoding();
         }
         self.device.wait_until_completed()?;
@@ -1563,7 +1548,6 @@ impl BackendDevice for MetalDevice {
         command_buffer.enqueue();
         let command_buffer = Arc::new(RwLock::new(command_buffer));
         let command_buffer_index = Arc::new(RwLock::new(0));
-        // let fence = device.new_fence();
         let kernels = Arc::new(Kernels::new());
         let buffers = Arc::new(RwLock::new(HashMap::new()));
         let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
@@ -1572,7 +1556,6 @@ impl BackendDevice for MetalDevice {
         };
         Ok(Self {
             device,
-            // fence,
             command_queue,
             command_buffer,
             command_buffer_index,
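Note: the constructor above reads `CANDLE_METAL_COMPUTE_PER_BUFFER` to fill the `compute_per_buffer` field from the first hunk, i.e. how many compute encoders are packed into one command buffer. The match arms are outside this diff, so the sketch below only illustrates the kind of parsing involved; the default value is an assumption.

```rust
// Hedged sketch of the env override; the real arms and default are not shown in this diff.
fn compute_per_buffer_from_env() -> usize {
    match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
        // Use the override when it parses as a number...
        Ok(val) => val.parse().unwrap_or(10),
        // ...and fall back to a fixed default otherwise (assumed value).
        Err(_) => 10,
    }
}
```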
 
@@ -32,9 +32,7 @@ impl QMetalStorage {
         command_buffer.set_label("to_cpu");
         let blit = command_buffer.new_blit_command_encoder();
         blit.set_label("blit_to_cpu");
-        // blit.wait_for_fence(&self.device.fence());
         blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-        // blit.update_fence(&self.device.fence());
         blit.end_encoding();
         self.device.wait_until_completed()?;
         let mut out = vec![0.0; elem_count];
 
@@ -219,7 +219,6 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
 pub struct Kernels {
     libraries: RwLock<Libraries>,
     pipelines: RwLock<Pipelines>,
-    // fence: metal::Fence,
 }
 
 impl Kernels {
@@ -229,7 +228,6 @@ impl Kernels {
         Self {
             libraries,
             pipelines,
-            // fence,
         }
     }
 
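Note: with the fence field gone, `Kernels::new` no longer takes a fence argument; callers (including the test hunks further down) construct it with no arguments. A small before/after sketch, grounded in the lines shown in this diff:

```rust
// Constructor change as seen in this commit. The commented lines show the
// pre-commit form that the removed test code below used.
use candle_metal_kernels::Kernels;

fn main() {
    // let fence = device.new_fence();      // old: a metal::Fence was required
    // let kernels = Kernels::new(fence);   // old constructor signature
    let kernels = Kernels::new(); // new constructor, as used in MetalDevice::new
    let _ = kernels;
}
```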
@@ -350,7 +348,6 @@ pub fn call_unary_contiguous(
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (length, input, output));
@@ -359,7 +356,6 @@ pub fn call_unary_contiguous(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
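Note: every launcher in this file follows the same shape, and the commit deletes the two commented fence lines from each one, leaving: load pipeline, create encoder, set pipeline state, bind parameters, declare resource usage, dispatch, end encoding. The sketch below shows that generic sequence with the `metal` crate directly; the pipeline and buffers are passed in so it stays self-contained, whereas candle's real launchers obtain the pipeline via `kernels.load_pipeline(...)` and bind arguments with the `set_params!` macro. Buffer indices and the thread-group math are illustrative assumptions.

```rust
// Generic post-cleanup launcher shape: encode, bind, dispatch, end. No fences.
use metal::{Buffer, CommandBufferRef, ComputePipelineState, MTLResourceUsage, MTLSize};

fn encode_unary_pass(
    command_buffer: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    input: &Buffer,
    output: &Buffer,
    length: usize,
) {
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    // Bind the element count and the two buffers (argument indices are illustrative).
    encoder.set_bytes(
        0,
        std::mem::size_of::<usize>() as u64,
        &length as *const usize as *const _,
    );
    encoder.set_buffer(1, Some(input), 0);
    encoder.set_buffer(2, Some(output), 0);
    encoder.use_resource(input, MTLResourceUsage::Read);
    encoder.use_resource(output, MTLResourceUsage::Write);
    // One thread per element, grouped by the pipeline's preferred width.
    let width = pipeline.thread_execution_width();
    let thread_group_size = MTLSize::new(width, 1, 1);
    let thread_group_count = MTLSize::new((length as u64 + width - 1) / width, 1, 1);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
}
```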
@@ -381,7 +377,6 @@ pub fn call_unary_strided(
 
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let length: usize = shape.iter().product();
@@ -403,7 +398,6 @@ pub fn call_unary_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -422,7 +416,6 @@ pub fn call_binary_contiguous(
     let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (length, left, right, output));
@@ -433,7 +426,6 @@ pub fn call_binary_contiguous(
     encoder.use_resource(right, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -458,7 +450,6 @@ pub fn call_binary_strided(
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
     let width: usize = shape.iter().product();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let length: usize = shape.iter().product();
@@ -483,7 +474,6 @@ pub fn call_binary_strided(
     encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -502,7 +492,6 @@ pub fn call_cast_contiguous(
     let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (length, (input, input_offset), output));
@@ -511,7 +500,6 @@ pub fn call_cast_contiguous(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -531,7 +519,6 @@ pub fn call_cast_strided(
     let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let length: usize = shape.iter().product();
@@ -553,7 +540,6 @@ pub fn call_cast_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -573,7 +559,6 @@ pub fn call_reduce_contiguous(
     let elements_to_sum = length / out_length;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -602,7 +587,6 @@ pub fn call_reduce_contiguous(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -624,7 +608,6 @@ pub fn call_reduce_strided(
     let elements_to_sum = length / out_length;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -660,7 +643,6 @@ pub fn call_reduce_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -679,7 +661,6 @@ pub fn call_last_softmax(
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -710,7 +691,6 @@ pub fn call_last_softmax(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -730,7 +710,6 @@ pub fn call_affine(
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (size, mul, add, input, output));
@@ -739,7 +718,6 @@ pub fn call_affine(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -762,7 +740,6 @@ pub fn call_affine_strided(
     let size: usize = shape.iter().product();
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -783,7 +760,6 @@ pub fn call_affine_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -802,7 +778,6 @@ pub fn call_powf(
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (size, mul, input, output));
@@ -811,7 +786,6 @@ pub fn call_powf(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -833,7 +807,6 @@ pub fn call_powf_strided(
     let size: usize = shape.iter().product();
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -853,7 +826,6 @@ pub fn call_powf_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -872,7 +844,6 @@ pub fn call_elu(
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (size, mul, input, output));
@@ -881,7 +852,6 @@ pub fn call_elu(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -903,7 +873,6 @@ pub fn call_elu_strided(
     let size: usize = shape.iter().product();
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -923,7 +892,6 @@ pub fn call_elu_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -945,7 +913,6 @@ pub fn call_where_cond_strided(
     let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let size: usize = shape.iter().product();
@@ -974,7 +941,6 @@ pub fn call_where_cond_strided(
     encoder.use_resource(right, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1001,7 +967,6 @@ pub fn call_index_select(
 
     let encoder = command_buffer.new_compute_command_encoder();
 
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1024,7 +989,6 @@ pub fn call_index_select(
     encoder.use_resource(ids, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1053,7 +1017,6 @@ pub fn call_gather(
 
     let encoder = command_buffer.new_compute_command_encoder();
 
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1076,7 +1039,6 @@ pub fn call_gather(
     encoder.use_resource(ids, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1105,7 +1067,6 @@ pub fn call_scatter_add(
 
     let encoder = command_buffer.new_compute_command_encoder();
 
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1128,7 +1089,6 @@ pub fn call_scatter_add(
     encoder.use_resource(ids, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1158,7 +1118,6 @@ pub fn call_index_add(
     let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
     let encoder = command_buffer.new_compute_command_encoder();
 
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1182,7 +1141,6 @@ pub fn call_index_add(
     encoder.use_resource(ids, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1386,7 +1344,6 @@ pub fn call_gemm(
     let block_bytes = block_elements * bytes;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     encoder.set_threadgroup_memory_length(0, block_bytes.into());
     encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@@ -1430,7 +1387,6 @@ pub fn call_gemm(
     encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(grid_size, group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
 
     Ok(())
@@ -1455,7 +1411,6 @@ pub fn call_im2col1d_strided(
 
     let encoder = command_buffer.new_compute_command_encoder();
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
@@ -1475,7 +1430,6 @@ pub fn call_im2col1d_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
 
     Ok(())
@@ -1505,7 +1459,6 @@ pub fn call_im2col_strided(
 
     let encoder = command_buffer.new_compute_command_encoder();
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
@@ -1527,7 +1480,6 @@ pub fn call_im2col_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
 
     Ok(())
@@ -1553,7 +1505,6 @@ pub fn call_upsample_nearest_2d(
     let scale_h = shape[3] as f32 / out_h as f32;
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
@@ -1571,7 +1522,6 @@ pub fn call_upsample_nearest_2d(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
 
     Ok(())
@@ -1710,7 +1660,6 @@ pub fn call_quantized_matmul_t(
 
     let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1743,7 +1692,6 @@ pub fn call_quantized_matmul_t(
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
 
     encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
 
     Ok(())
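Note: dropping the per-encoder fences is expected to be safe here because the launchers above all encode into a single command buffer, and for hazard-tracked (default) Metal resources dependent encoders in one command buffer are ordered by the driver, with `wait_until_completed` providing the CPU-side sync point (see the `to_cpu` hunks earlier). The sketch below illustrates that reasoning only; it is not candle code, and the hazard-tracking behaviour stated in the comments is an assumption based on Metal's documented defaults.

```rust
// Two dependent blit passes in one command buffer, no fence in between.
use metal::{Device, MTLResourceOptions};

fn chained_blits_without_fence() {
    let device = Device::system_default().expect("no Metal device");
    let queue = device.new_command_queue();
    let command_buffer = queue.new_command_buffer();

    let options = MTLResourceOptions::StorageModeManaged;
    let data = [1u8, 2, 3, 4];
    let a = device.new_buffer_with_data(data.as_ptr() as *const _, 4, options);
    let b = device.new_buffer(4, options);
    let c = device.new_buffer(4, options);

    // a -> b, then b -> c: with hazard-tracked buffers the second pass sees the
    // result of the first even though no fence is encoded between them.
    let blit1 = command_buffer.new_blit_command_encoder();
    blit1.copy_from_buffer(&a, 0, &b, 0, 4);
    blit1.end_encoding();

    let blit2 = command_buffer.new_blit_command_encoder();
    blit2.copy_from_buffer(&b, 0, &c, 0, 4);
    blit2.end_encoding();

    command_buffer.commit();
    command_buffer.wait_until_completed();
}
```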
 
@@ -37,8 +37,7 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
 
 fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -60,8 +59,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
 
 fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
@@ -96,8 +94,7 @@ fn run_strided<T: Clone>(
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
     let output = new_buffer(&device, v);
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     call_unary_strided(
         &device,
         command_buffer,
@@ -278,8 +275,7 @@ fn binary_ops_bf16() {
 
 fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
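Note: all of these test helpers change in the same way, from `Kernels::new(fence)` to `Kernels::new()`. The sketch below mirrors the post-cleanup setup shown in the hunks; `device()` and `new_buffer()` are helpers from the test file itself, so plain `metal` calls are inlined here to keep the snippet standalone, and the kernel invocation is omitted.

```rust
// Post-cleanup test scaffolding: no fence anywhere in the setup.
use candle_metal_kernels::Kernels;
use metal::{Device, MTLResourceOptions};

fn setup(v: &[f32]) -> (Device, Kernels, metal::Buffer) {
    let device = Device::system_default().expect("no Metal device");
    let kernels = Kernels::new(); // no fence argument anymore
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        (v.len() * std::mem::size_of::<f32>()) as u64,
        MTLResourceOptions::StorageModeManaged,
    );
    (device, kernels, input)
}
```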
@@ -409,8 +405,7 @@ fn it_cast_f16_bf16() {
 
 fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
 
@@ -445,8 +440,7 @@ fn run_affine_strided<T: Clone>(
     add: f64,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
 
@@ -595,8 +589,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
     let dst_el = ids.len() * left_size * right_size;
     let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
 
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     call_index_select(
         &device,
         &command_buffer,
@@ -631,8 +624,7 @@ fn cos_f16() {
 
 fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -662,8 +654,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
 
 fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -782,8 +773,7 @@ fn run_where_cond<I: Clone, T: Clone>(
     name: &'static str,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
@@ -859,8 +849,7 @@ fn run_gemm<T: Clone>(
     rhs_offset: usize,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;