mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Adding some half kernels.
This commit is contained in:
@ -460,6 +460,46 @@ pub fn call_cast_contiguous(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_cast_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
shape: &[usize],
|
||||
input: &Buffer,
|
||||
input_strides: &[usize],
|
||||
input_offset: usize,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(
|
||||
pipeline_state_descriptor.compute_function().unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(length, shape, input_strides, (input, input_offset), output)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_reduce_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
@ -565,13 +605,14 @@ pub fn call_affine(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
size: usize,
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Affine, "affine_float")?;
|
||||
let func = kernels.load_function(device, Source::Affine, name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
@ -596,6 +637,7 @@ pub fn call_affine_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
input: &Buffer,
|
||||
input_stride: &[usize],
|
||||
@ -604,7 +646,7 @@ pub fn call_affine_strided(
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let func = kernels.load_function(device, Source::Affine, "affine_float_strided")?;
|
||||
let func = kernels.load_function(device, Source::Affine, name)?;
|
||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||
|
||||
|
Reference in New Issue
Block a user