mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d. * Move the cat operations. * Avoid transpositions in cat. * Bugfix. * Bugfix for the cuda kernel. * Add a benchmark. * Add more testing. * Test fix. * Faster kernel. * Add the missing kernel. * Tweak the test. * Add a metal kernel. * Fix for the metal kernel. * Get the tests to pass on metal. * Also use this opportunity to fix the metal kernel for ELU. * Add some bf16 kernels. * Clippy fixes.
This commit is contained in:
@ -666,7 +666,7 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||
pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||
if dim >= self.dims().len() {
|
||||
Err(Error::DimOutOfRange {
|
||||
shape: self.shape().clone(),
|
||||
@ -2149,152 +2149,6 @@ impl Tensor {
|
||||
Self::cat(&args, dim)
|
||||
}
|
||||
|
||||
/// Concatenates two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must of the same rank, and the output will have
|
||||
/// the same rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg0.rank() != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: arg0.rank(),
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
if dim == 0 {
|
||||
Self::cat0(args)
|
||||
} else {
|
||||
// TODO: Avoid these transpositions and have an implementation that works
|
||||
// for dim != 0...
|
||||
let args: Vec<Tensor> = args
|
||||
.iter()
|
||||
.map(|a| a.as_ref().transpose(0, dim))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cat = Self::cat0(&args)?;
|
||||
cat.transpose(0, dim)
|
||||
}
|
||||
}
|
||||
|
||||
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let rank = arg0.rank();
|
||||
let device = arg0.device();
|
||||
let dtype = arg0.dtype();
|
||||
let first_dims = arg0.shape().dims();
|
||||
let mut cat_dims = first_dims.to_vec();
|
||||
cat_dims[0] = 0;
|
||||
let mut offsets = vec![0usize];
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg.dtype() != dtype {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: dtype,
|
||||
rhs: arg.dtype(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if arg.device().location() != device.location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: device.location(),
|
||||
rhs: arg.device().location(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if rank != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: rank,
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx == 0 {
|
||||
cat_dims[0] += v2;
|
||||
}
|
||||
if dim_idx != 0 && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
let next_offset = offsets.last().unwrap() + arg.elem_count();
|
||||
offsets.push(next_offset);
|
||||
}
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
let arg = arg.as_ref();
|
||||
arg.storage()
|
||||
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
||||
}
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
|
||||
/// input tensor values and `right` elements after.
|
||||
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||
|
Reference in New Issue
Block a user