mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Fix comments.
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
#![allow(clippy::too_many_arguments)]
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
|
||||
ComputePipelineState, Device, Function, Library, MTLSize,
|
||||
@ -156,14 +155,6 @@ pub mod binary {
|
||||
ops!(add, sub, mul, div);
|
||||
}
|
||||
|
||||
// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
||||
// let mut l = HashMap::new();
|
||||
// l.insert("affine", AFFINE);
|
||||
// l.insert("indexing", INDEXING);
|
||||
// l.insert("unary", UNARY);
|
||||
// l
|
||||
// });
|
||||
//
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MetalKernelError {
|
||||
#[error("Could not lock kernel map: {0}")]
|
||||
@ -197,21 +188,7 @@ impl Kernels {
|
||||
Self { libraries, funcs }
|
||||
}
|
||||
|
||||
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
||||
// let kernels = Self::new();
|
||||
// kernels.load_libraries(device)?;
|
||||
// Ok(kernels)
|
||||
// }
|
||||
|
||||
// fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
|
||||
// for name in LIBRARY_SOURCES.keys() {
|
||||
// self.load_library(device, name)?;
|
||||
// }
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
fn get_library_source(&self, source: Source) -> &'static str {
|
||||
// LIBRARY_SOURCES.get(name).cloned()
|
||||
match source {
|
||||
Source::Affine => AFFINE,
|
||||
Source::Unary => UNARY,
|
||||
@ -261,6 +238,7 @@ impl Kernels {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -270,8 +248,6 @@ pub fn call_unary_contiguous(
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
@ -292,6 +268,8 @@ pub fn call_unary_contiguous(
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -339,6 +317,7 @@ pub fn call_unary_strided(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_binary_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -349,8 +328,6 @@ pub fn call_binary_contiguous(
|
||||
right: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
@ -373,6 +350,7 @@ pub fn call_binary_contiguous(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_binary_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -425,6 +403,7 @@ pub fn call_binary_strided(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_cast_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -434,8 +413,6 @@ pub fn call_cast_contiguous(
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
@ -458,6 +435,7 @@ pub fn call_cast_contiguous(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_reduce_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -508,6 +486,7 @@ pub fn call_reduce_contiguous(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_last_softmax(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -543,7 +522,6 @@ pub fn call_last_softmax(
|
||||
|
||||
let width = std::cmp::min(
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
// (elements_to_sum as u64 + 2 - 1) / 2,
|
||||
elements_to_sum as u64,
|
||||
)
|
||||
.next_power_of_two();
|
||||
@ -559,6 +537,7 @@ pub fn call_last_softmax(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_affine(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -590,6 +569,7 @@ pub fn call_affine(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_where_cond_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -643,6 +623,7 @@ pub fn call_where_cond_strided(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_index_select(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -813,7 +794,6 @@ mod tests {
|
||||
#[test]
|
||||
fn cos_f32_strided() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
// Shape = [6], strides = [1];
|
||||
let shape = vec![6];
|
||||
let strides = vec![1];
|
||||
let offset = 0;
|
||||
|
Reference in New Issue
Block a user