mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Adding some half kernels.
This commit is contained in:
@ -153,11 +153,16 @@ impl BackendStorage for MetalStorage {
|
|||||||
let mut buffer = device.new_buffer(el, self.dtype);
|
let mut buffer = device.new_buffer(el, self.dtype);
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
assert_eq!(dtype, DType::F32);
|
let name = match self.dtype {
|
||||||
|
DType::F32 => "affine_float",
|
||||||
|
DType::F16 => "affine_half",
|
||||||
|
dtype => todo!("Affine {dtype:?}"),
|
||||||
|
};
|
||||||
candle_metal_kernels::call_affine(
|
candle_metal_kernels::call_affine(
|
||||||
&device.device,
|
&device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
&device.kernels,
|
&device.kernels,
|
||||||
|
name,
|
||||||
el,
|
el,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
&mut buffer,
|
&mut buffer,
|
||||||
@ -166,11 +171,16 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
} else {
|
} else {
|
||||||
assert_eq!(dtype, DType::F32);
|
let name = match self.dtype {
|
||||||
|
DType::F32 => "affine_float",
|
||||||
|
DType::F16 => "affine_half",
|
||||||
|
dtype => todo!("Affine {dtype:?}"),
|
||||||
|
};
|
||||||
candle_metal_kernels::call_affine_strided(
|
candle_metal_kernels::call_affine_strided(
|
||||||
&device.device,
|
&device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
&device.kernels,
|
&device.kernels,
|
||||||
|
name,
|
||||||
layout.dims(),
|
layout.dims(),
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
@ -273,6 +283,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
if layout.is_contiguous() {
|
if layout.is_contiguous() {
|
||||||
let kernel_name = match (self.dtype, dtype) {
|
let kernel_name = match (self.dtype, dtype) {
|
||||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||||
|
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||||
|
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||||
(left, right) => todo!("to dtype {left:?} - {right:?}"),
|
(left, right) => todo!("to dtype {left:?} - {right:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_cast_contiguous(
|
candle_metal_kernels::call_cast_contiguous(
|
||||||
@ -286,11 +298,24 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
todo!(
|
let kernel_name = match (self.dtype, dtype) {
|
||||||
"TODO Implement the kernel calling cast {:?}-{:?}",
|
(DType::U32, DType::F32) => "cast_u32_f32_strided",
|
||||||
self.dtype,
|
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||||
dtype
|
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||||
);
|
(left, right) => todo!("to dtype {left:?} - {right:?}"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_cast_strided(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
layout.dims(),
|
||||||
|
&self.buffer,
|
||||||
|
layout.stride(),
|
||||||
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// command_buffer.commit();
|
// command_buffer.commit();
|
||||||
@ -327,6 +352,20 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||||
|
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||||
|
("usin", DType::F16) => contiguous::sin::HALF,
|
||||||
|
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||||
|
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
||||||
|
("uneg", DType::F16) => contiguous::neg::HALF,
|
||||||
|
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||||
|
("ulog", DType::F16) => contiguous::log::HALF,
|
||||||
|
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||||
|
// TODO erf does not exist in metal
|
||||||
|
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
|
||||||
|
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
|
||||||
|
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||||
|
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||||
|
("uround", DType::F16) => contiguous::round::HALF,
|
||||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_unary_contiguous(
|
candle_metal_kernels::call_unary_contiguous(
|
||||||
@ -340,7 +379,51 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
use candle_metal_kernels::unary::strided;
|
||||||
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
|
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||||
|
("usin", DType::F32) => strided::sin::FLOAT,
|
||||||
|
("usqr", DType::F32) => strided::sqr::FLOAT,
|
||||||
|
("usqrt", DType::F32) => strided::sqrt::FLOAT,
|
||||||
|
("uneg", DType::F32) => strided::neg::FLOAT,
|
||||||
|
("uexp", DType::F32) => strided::exp::FLOAT,
|
||||||
|
("ulog", DType::F32) => strided::log::FLOAT,
|
||||||
|
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||||
|
// TODO erf does not exist in metal
|
||||||
|
("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"),
|
||||||
|
("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"),
|
||||||
|
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||||
|
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||||
|
("uround", DType::F32) => strided::round::FLOAT,
|
||||||
|
("ucos", DType::F16) => strided::cos::HALF,
|
||||||
|
("usin", DType::F16) => strided::sin::HALF,
|
||||||
|
("usqr", DType::F16) => strided::sqr::HALF,
|
||||||
|
("usqrt", DType::F16) => strided::sqrt::HALF,
|
||||||
|
("uneg", DType::F16) => strided::neg::HALF,
|
||||||
|
("uexp", DType::F16) => strided::exp::HALF,
|
||||||
|
("ulog", DType::F16) => strided::log::HALF,
|
||||||
|
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||||
|
// TODO erf does not exist in metal
|
||||||
|
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
|
||||||
|
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
|
||||||
|
("uceil", DType::F16) => strided::ceil::HALF,
|
||||||
|
("ufloor", DType::F16) => strided::floor::HALF,
|
||||||
|
("uround", DType::F16) => strided::round::HALF,
|
||||||
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_unary_strided(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
layout.dims(),
|
||||||
|
&self.buffer,
|
||||||
|
layout.stride(),
|
||||||
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
|
&mut buffer,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
// command_buffer.commit();
|
// command_buffer.commit();
|
||||||
// command_buffer.wait_until_scheduled();
|
// command_buffer.wait_until_scheduled();
|
||||||
@ -378,6 +461,14 @@ impl BackendStorage for MetalStorage {
|
|||||||
("bmul", DType::F32) => contiguous::mul::FLOAT,
|
("bmul", DType::F32) => contiguous::mul::FLOAT,
|
||||||
("div", DType::F32) => contiguous::div::FLOAT,
|
("div", DType::F32) => contiguous::div::FLOAT,
|
||||||
("bdiv", DType::F32) => contiguous::div::FLOAT,
|
("bdiv", DType::F32) => contiguous::div::FLOAT,
|
||||||
|
("add", DType::F16) => contiguous::add::HALF,
|
||||||
|
("badd", DType::F16) => contiguous::add::HALF,
|
||||||
|
("sub", DType::F16) => contiguous::sub::HALF,
|
||||||
|
("bsub", DType::F16) => contiguous::sub::HALF,
|
||||||
|
("mul", DType::F16) => contiguous::mul::HALF,
|
||||||
|
("bmul", DType::F16) => contiguous::mul::HALF,
|
||||||
|
("div", DType::F16) => contiguous::div::HALF,
|
||||||
|
("bdiv", DType::F16) => contiguous::div::HALF,
|
||||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_binary_contiguous(
|
candle_metal_kernels::call_binary_contiguous(
|
||||||
@ -399,6 +490,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||||
|
("badd", DType::F16) => strided::add::HALF,
|
||||||
|
("bsub", DType::F16) => strided::sub::HALF,
|
||||||
|
("bmul", DType::F16) => strided::mul::HALF,
|
||||||
|
("bdiv", DType::F16) => strided::div::HALF,
|
||||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_binary_strided(
|
candle_metal_kernels::call_binary_strided(
|
||||||
@ -555,6 +650,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
(DType::U32, DType::F32) => "is_u32_f32",
|
(DType::U32, DType::F32) => "is_u32_f32",
|
||||||
|
(DType::U32, DType::F16) => "is_u32_f16",
|
||||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||||
};
|
};
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
@ -601,8 +697,20 @@ impl BackendStorage for MetalStorage {
|
|||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
// Create descriptors
|
// Create descriptors
|
||||||
use metal::mps::matrix::*;
|
use metal::mps::matrix::*;
|
||||||
let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
|
|
||||||
let size = core::mem::size_of::<f32>() as NSUInteger;
|
assert_eq!(self.dtype, rhs.dtype);
|
||||||
|
|
||||||
|
let (type_id, size) = match self.dtype {
|
||||||
|
DType::F32 => (
|
||||||
|
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
||||||
|
core::mem::size_of::<f32>() as NSUInteger,
|
||||||
|
),
|
||||||
|
DType::F16 => (
|
||||||
|
metal::mps::MPS_FLOATBIT_ENCODING | 16,
|
||||||
|
core::mem::size_of::<f16>() as NSUInteger,
|
||||||
|
),
|
||||||
|
dtype => todo!("Implement matmul {dtype:?}"),
|
||||||
|
};
|
||||||
|
|
||||||
let elem_count = b * m * n;
|
let elem_count = b * m * n;
|
||||||
|
|
||||||
|
@ -46,6 +46,8 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
} \
|
} \
|
||||||
|
|
||||||
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
|
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
|
||||||
|
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||||
|
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
#endif
|
#endif
|
||||||
|
@ -75,6 +75,7 @@ kernel void FN_NAME( \
|
|||||||
|
|
||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint, float)
|
INDEX_OP(is_u32_f32, uint, float)
|
||||||
|
INDEX_OP(is_u32_f16, uint, half)
|
||||||
|
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
|
@ -460,6 +460,46 @@ pub fn call_cast_contiguous(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_cast_strided(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
kernels: &Kernels,
|
||||||
|
kernel_name: &'static str,
|
||||||
|
shape: &[usize],
|
||||||
|
input: &Buffer,
|
||||||
|
input_strides: &[usize],
|
||||||
|
input_offset: usize,
|
||||||
|
output: &mut Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
// println!("Kernel {:?}", kernel_name.0);
|
||||||
|
// assert_eq!(input.length(), output.length());
|
||||||
|
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
||||||
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
|
|
||||||
|
let pipeline = device
|
||||||
|
.new_compute_pipeline_state_with_function(
|
||||||
|
pipeline_state_descriptor.compute_function().unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
let length: usize = shape.iter().product();
|
||||||
|
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(length, shape, input_strides, (input, input_offset), output)
|
||||||
|
);
|
||||||
|
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.end_encoding();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn call_reduce_contiguous(
|
pub fn call_reduce_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -565,13 +605,14 @@ pub fn call_affine(
|
|||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Affine, "affine_float")?;
|
let func = kernels.load_function(device, Source::Affine, name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
|
|
||||||
@ -596,6 +637,7 @@ pub fn call_affine_strided(
|
|||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_stride: &[usize],
|
input_stride: &[usize],
|
||||||
@ -604,7 +646,7 @@ pub fn call_affine_strided(
|
|||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Affine, "affine_float_strided")?;
|
let func = kernels.load_function(device, Source::Affine, name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user