Fix comments.

This commit is contained in:
Nicolas Patry
2023-11-20 14:00:39 +01:00
parent bd3b243725
commit c66e5d4716
5 changed files with 66 additions and 104 deletions

View File

@ -1,4 +1,3 @@
#![allow(clippy::too_many_arguments)]
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
ComputePipelineState, Device, Function, Library, MTLSize,
@ -156,14 +155,6 @@ pub mod binary {
ops!(add, sub, mul, div);
}
// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
// let mut l = HashMap::new();
// l.insert("affine", AFFINE);
// l.insert("indexing", INDEXING);
// l.insert("unary", UNARY);
// l
// });
//
#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
#[error("Could not lock kernel map: {0}")]
@ -197,21 +188,7 @@ impl Kernels {
Self { libraries, funcs }
}
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
// let kernels = Self::new();
// kernels.load_libraries(device)?;
// Ok(kernels)
// }
// fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
// for name in LIBRARY_SOURCES.keys() {
// self.load_library(device, name)?;
// }
// Ok(())
// }
fn get_library_source(&self, source: Source) -> &'static str {
// LIBRARY_SOURCES.get(name).cloned()
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
@ -261,6 +238,7 @@ impl Kernels {
}
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@ -270,8 +248,6 @@ pub fn call_unary_contiguous(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@ -292,6 +268,8 @@ pub fn call_unary_contiguous(
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@ -339,6 +317,7 @@ pub fn call_unary_strided(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_binary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@ -349,8 +328,6 @@ pub fn call_binary_contiguous(
right: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@ -373,6 +350,7 @@ pub fn call_binary_contiguous(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_binary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@ -425,6 +403,7 @@ pub fn call_binary_strided(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_cast_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@ -434,8 +413,6 @@ pub fn call_cast_contiguous(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@ -458,6 +435,7 @@ pub fn call_cast_contiguous(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@ -508,6 +486,7 @@ pub fn call_reduce_contiguous(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
device: &Device,
command_buffer: &CommandBufferRef,
@ -543,7 +522,6 @@ pub fn call_last_softmax(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
// (elements_to_sum as u64 + 2 - 1) / 2,
elements_to_sum as u64,
)
.next_power_of_two();
@ -559,6 +537,7 @@ pub fn call_last_softmax(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
@ -590,6 +569,7 @@ pub fn call_affine(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@ -643,6 +623,7 @@ pub fn call_where_cond_strided(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_index_select(
device: &Device,
command_buffer: &CommandBufferRef,
@ -813,7 +794,6 @@ mod tests {
#[test]
fn cos_f32_strided() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
// Shape = [6], strides = [1];
let shape = vec![6];
let strides = vec![1];
let offset = 0;