Adding some half kernels.

This commit is contained in:
Nicolas Patry
2023-11-11 13:30:21 +01:00
parent e02f1912bb
commit 54355ff997
4 changed files with 165 additions and 12 deletions

View File

@ -153,11 +153,16 @@ impl BackendStorage for MetalStorage {
let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
assert_eq!(dtype, DType::F32);
let name = match self.dtype {
DType::F32 => "affine_float",
DType::F16 => "affine_half",
dtype => todo!("Affine {dtype:?}"),
};
candle_metal_kernels::call_affine(
&device.device,
&command_buffer,
&device.kernels,
name,
el,
&self.buffer,
&mut buffer,
@ -166,11 +171,16 @@ impl BackendStorage for MetalStorage {
)
.unwrap();
} else {
assert_eq!(dtype, DType::F32);
let name = match self.dtype {
DType::F32 => "affine_float",
DType::F16 => "affine_half",
dtype => todo!("Affine {dtype:?}"),
};
candle_metal_kernels::call_affine_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.dims(),
&self.buffer,
layout.stride(),
@ -273,6 +283,8 @@ impl BackendStorage for MetalStorage {
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F16, DType::F32) => "cast_f16_f32",
(left, right) => todo!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
@ -286,11 +298,24 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
} else {
todo!(
"TODO Implement the kernel calling cast {:?}-{:?}",
self.dtype,
dtype
);
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(left, right) => todo!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&mut buffer,
)
.map_err(MetalError::from)?;
}
// command_buffer.commit();
@ -327,6 +352,20 @@ impl BackendStorage for MetalStorage {
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("uneg", DType::F16) => contiguous::neg::HALF,
("uexp", DType::F16) => contiguous::exp::HALF,
("ulog", DType::F16) => contiguous::log::HALF,
("ugelu", DType::F16) => contiguous::gelu::HALF,
// TODO erf does not exist in metal
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
@ -340,7 +379,51 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
} else {
todo!("TODO Implement the kernel calling {}", B::KERNEL);
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
("usin", DType::F32) => strided::sin::FLOAT,
("usqr", DType::F32) => strided::sqr::FLOAT,
("usqrt", DType::F32) => strided::sqrt::FLOAT,
("uneg", DType::F32) => strided::neg::FLOAT,
("uexp", DType::F32) => strided::exp::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT,
// TODO erf does not exist in metal
("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"),
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
("usqrt", DType::F16) => strided::sqrt::HALF,
("uneg", DType::F16) => strided::neg::HALF,
("uexp", DType::F16) => strided::exp::HALF,
("ulog", DType::F16) => strided::log::HALF,
("ugelu", DType::F16) => strided::gelu::HALF,
// TODO erf does not exist in metal
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("uround", DType::F16) => strided::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&mut buffer,
0,
)
.map_err(MetalError::from)?;
}
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
@ -378,6 +461,14 @@ impl BackendStorage for MetalStorage {
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
("add", DType::F16) => contiguous::add::HALF,
("badd", DType::F16) => contiguous::add::HALF,
("sub", DType::F16) => contiguous::sub::HALF,
("bsub", DType::F16) => contiguous::sub::HALF,
("mul", DType::F16) => contiguous::mul::HALF,
("bmul", DType::F16) => contiguous::mul::HALF,
("div", DType::F16) => contiguous::div::HALF,
("bdiv", DType::F16) => contiguous::div::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
@ -399,6 +490,10 @@ impl BackendStorage for MetalStorage {
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
("badd", DType::F16) => strided::add::HALF,
("bsub", DType::F16) => strided::sub::HALF,
("bmul", DType::F16) => strided::mul::HALF,
("bdiv", DType::F16) => strided::div::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
@ -555,6 +650,7 @@ impl BackendStorage for MetalStorage {
let mut buffer = device.new_buffer(dst_el, dtype);
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(left, right) => todo!("index select metal {left:?} {right:?}"),
};
let command_buffer = self.device.command_buffer();
@ -601,8 +697,20 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
// Create descriptors
use metal::mps::matrix::*;
let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
let size = core::mem::size_of::<f32>() as NSUInteger;
assert_eq!(self.dtype, rhs.dtype);
let (type_id, size) = match self.dtype {
DType::F32 => (
metal::mps::MPS_FLOATBIT_ENCODING | 32,
core::mem::size_of::<f32>() as NSUInteger,
),
DType::F16 => (
metal::mps::MPS_FLOATBIT_ENCODING | 16,
core::mem::size_of::<f16>() as NSUInteger,
),
dtype => todo!("Implement matmul {dtype:?}"),
};
let elem_count = b * m * n;

View File

@ -46,6 +46,8 @@ kernel void FN_NAME_STRIDED( \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -75,6 +75,7 @@ kernel void FN_NAME( \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310

View File

@ -460,6 +460,46 @@ pub fn call_cast_contiguous(
Ok(())
}
pub fn call_cast_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
input: &Buffer,
input_strides: &[usize],
input_offset: usize,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
set_params!(
encoder,
(length, shape, input_strides, (input, input_offset), output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@ -565,13 +605,14 @@ pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
output: &mut Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, "affine_float")?;
let func = kernels.load_function(device, Source::Affine, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@ -596,6 +637,7 @@ pub fn call_affine_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input_stride: &[usize],
@ -604,7 +646,7 @@ pub fn call_affine_strided(
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, "affine_float_strided")?;
let func = kernels.load_function(device, Source::Affine, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));