mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d. * Move the cat operations. * Avoid transpositions in cat. * Bugfix. * Bugfix for the cuda kernel. * Add a benchmark. * Add more testing. * Test fix. * Faster kernel. * Add the missing kernel. * Tweak the test. * Add a metal kernel. * Fix for the metal kernel. * Get the tests to pass on metal. * Also use this opportunity to fix the metal kernel for ELU. * Add some bf16 kernels. * Clippy fixes.
This commit is contained in:
@ -701,4 +701,32 @@ impl Storage {
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn copy2d(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
d1: usize,
|
||||
d2: usize,
|
||||
src_s: usize,
|
||||
dst_s: usize,
|
||||
src_o: usize,
|
||||
dst_o: usize,
|
||||
) -> Result<()> {
|
||||
match (self, dst) {
|
||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
|
||||
(Self::Cuda(src), Self::Cuda(dst)) => {
|
||||
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
||||
}
|
||||
(Self::Metal(src), Self::Metal(dst)) => {
|
||||
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "copy2d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user