Adding some half kernels.

This commit is contained in:
Nicolas Patry
2023-11-11 13:30:21 +01:00
parent e02f1912bb
commit 54355ff997
4 changed files with 165 additions and 12 deletions

View File

@ -153,11 +153,16 @@ impl BackendStorage for MetalStorage {
let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
assert_eq!(dtype, DType::F32);
let name = match self.dtype {
DType::F32 => "affine_float",
DType::F16 => "affine_half",
dtype => todo!("Affine {dtype:?}"),
};
candle_metal_kernels::call_affine(
&device.device,
&command_buffer,
&device.kernels,
name,
el,
&self.buffer,
&mut buffer,
@ -166,11 +171,16 @@ impl BackendStorage for MetalStorage {
)
.unwrap();
} else {
assert_eq!(dtype, DType::F32);
let name = match self.dtype {
DType::F32 => "affine_float",
DType::F16 => "affine_half",
dtype => todo!("Affine {dtype:?}"),
};
candle_metal_kernels::call_affine_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.dims(),
&self.buffer,
layout.stride(),
@ -273,6 +283,8 @@ impl BackendStorage for MetalStorage {
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F16, DType::F32) => "cast_f16_f32",
(left, right) => todo!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
@ -286,11 +298,24 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
} else {
todo!(
"TODO Implement the kernel calling cast {:?}-{:?}",
self.dtype,
dtype
);
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(left, right) => todo!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&mut buffer,
)
.map_err(MetalError::from)?;
}
// command_buffer.commit();
@ -327,6 +352,20 @@ impl BackendStorage for MetalStorage {
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("uneg", DType::F16) => contiguous::neg::HALF,
("uexp", DType::F16) => contiguous::exp::HALF,
("ulog", DType::F16) => contiguous::log::HALF,
("ugelu", DType::F16) => contiguous::gelu::HALF,
// TODO: erf is not available in Metal; the erf-based ops bail until an erf kernel is implemented.
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
@ -340,7 +379,51 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
} else {
todo!("TODO Implement the kernel calling {}", B::KERNEL);
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
("usin", DType::F32) => strided::sin::FLOAT,
("usqr", DType::F32) => strided::sqr::FLOAT,
("usqrt", DType::F32) => strided::sqrt::FLOAT,
("uneg", DType::F32) => strided::neg::FLOAT,
("uexp", DType::F32) => strided::exp::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT,
// TODO: erf is not available in Metal; the erf-based ops bail until an erf kernel is implemented.
("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"),
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
("usqrt", DType::F16) => strided::sqrt::HALF,
("uneg", DType::F16) => strided::neg::HALF,
("uexp", DType::F16) => strided::exp::HALF,
("ulog", DType::F16) => strided::log::HALF,
("ugelu", DType::F16) => strided::gelu::HALF,
// TODO: erf is not available in Metal; the erf-based ops bail until an erf kernel is implemented.
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("uround", DType::F16) => strided::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&mut buffer,
0,
)
.map_err(MetalError::from)?;
}
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
@ -378,6 +461,14 @@ impl BackendStorage for MetalStorage {
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
("add", DType::F16) => contiguous::add::HALF,
("badd", DType::F16) => contiguous::add::HALF,
("sub", DType::F16) => contiguous::sub::HALF,
("bsub", DType::F16) => contiguous::sub::HALF,
("mul", DType::F16) => contiguous::mul::HALF,
("bmul", DType::F16) => contiguous::mul::HALF,
("div", DType::F16) => contiguous::div::HALF,
("bdiv", DType::F16) => contiguous::div::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
@ -399,6 +490,10 @@ impl BackendStorage for MetalStorage {
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
("badd", DType::F16) => strided::add::HALF,
("bsub", DType::F16) => strided::sub::HALF,
("bmul", DType::F16) => strided::mul::HALF,
("bdiv", DType::F16) => strided::div::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
@ -555,6 +650,7 @@ impl BackendStorage for MetalStorage {
let mut buffer = device.new_buffer(dst_el, dtype);
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(left, right) => todo!("index select metal {left:?} {right:?}"),
};
let command_buffer = self.device.command_buffer();
@ -601,8 +697,20 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
// Create descriptors
use metal::mps::matrix::*;
let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
let size = core::mem::size_of::<f32>() as NSUInteger;
assert_eq!(self.dtype, rhs.dtype);
let (type_id, size) = match self.dtype {
DType::F32 => (
metal::mps::MPS_FLOATBIT_ENCODING | 32,
core::mem::size_of::<f32>() as NSUInteger,
),
DType::F16 => (
metal::mps::MPS_FLOATBIT_ENCODING | 16,
core::mem::size_of::<f16>() as NSUInteger,
),
dtype => todo!("Implement matmul {dtype:?}"),
};
let elem_count = b * m * n;