mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d. * Move the cat operations. * Avoid transpositions in cat. * Bugfix. * Bugfix for the cuda kernel. * Add a benchmark. * Add more testing. * Test fix. * Faster kernel. * Add the missing kernel. * Tweak the test. * Add a metal kernel. * Fix for the metal kernel. * Get the tests to pass on metal. * Also use this opportunity to fix the metal kernel for ELU. * Add some bf16 kernels. * Clippy fixes.
This commit is contained in:
@ -422,6 +422,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_f32",
|
||||
DType::F16 => "powf_f16",
|
||||
DType::BF16 => "powf_bf16",
|
||||
dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_powf(
|
||||
@ -439,6 +440,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_f32_strided",
|
||||
DType::F16 => "powf_f16_strided",
|
||||
DType::BF16 => "powf_bf16_strided",
|
||||
dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_powf_strided(
|
||||
@ -471,6 +473,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_f32",
|
||||
DType::F16 => "elu_f16",
|
||||
DType::BF16 => "elu_bf16",
|
||||
dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_elu(
|
||||
@ -488,6 +491,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_f32_strided",
|
||||
DType::F16 => "elu_f16_strided",
|
||||
DType::BF16 => "elu_bf16_strided",
|
||||
dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_elu_strided(
|
||||
@ -1292,6 +1296,67 @@ impl BackendStorage for MetalStorage {
|
||||
))
|
||||
}
|
||||
|
||||
fn copy2d(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
d1: usize,
|
||||
d2: usize,
|
||||
src_s: usize,
|
||||
dst_s: usize,
|
||||
src_o: usize,
|
||||
dst_o: usize,
|
||||
) -> Result<()> {
|
||||
if self.dtype() != dst.dtype() {
|
||||
crate::bail!(
|
||||
"copy2d with inconsistent dtypes {:?} {:?}",
|
||||
self.dtype(),
|
||||
dst.dtype()
|
||||
)
|
||||
}
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
if src_s == d2 && dst_s == d2 {
|
||||
command_buffer.set_label("copy2d_contiguous");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("copy2d_contiguous");
|
||||
let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
|
||||
blit.end_encoding();
|
||||
} else {
|
||||
let el_count = d1 * d2;
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::copy2d::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::copy2d::HALF,
|
||||
DType::BF16 => candle_metal_kernels::copy2d::BFLOAT,
|
||||
DType::I64 => candle_metal_kernels::copy2d::I64,
|
||||
DType::U32 => candle_metal_kernels::copy2d::U32,
|
||||
DType::U8 => candle_metal_kernels::copy2d::U8,
|
||||
dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_copy2d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
kernel_name,
|
||||
&self.buffer,
|
||||
&dst.buffer,
|
||||
d1,
|
||||
d2,
|
||||
src_s,
|
||||
dst_s,
|
||||
src_o * self.dtype.size_in_bytes(),
|
||||
dst_o * self.dtype.size_in_bytes(),
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.set_label("copy2d");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
if src_l.is_contiguous() && self.dtype == dst.dtype() {
|
||||
|
Reference in New Issue
Block a user