Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-17 11:08:52 +00:00
Add the cat operator (without the storage implementation for now).
@@ -575,6 +575,92 @@ impl Tensor {
         }
     }
 
+    pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
+        if args.is_empty() {
+            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
+        }
+        let rank = args[0].rank();
+        if dim >= rank {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: (dim + 1),
+                got: rank,
+                shape: args[0].shape().clone(),
+            });
+        }
+        let device = args[0].device();
+        let dtype = args[0].dtype();
+        let first_dims = args[0].shape().dims();
+        let mut cat_dims = first_dims.to_vec();
+        cat_dims[dim] = 0;
+        let mut offsets = vec![0usize];
+        for (arg_idx, arg) in args.iter().enumerate() {
+            if arg.dtype() != dtype {
+                // TODO: Improve the error message.
+                return Err(Error::DTypeMismatchBinaryOp {
+                    lhs: dtype,
+                    rhs: arg.dtype(),
+                    op: "cat",
+                });
+            }
+            if arg.device().location() != device.location() {
+                // TODO: Improve the error message.
+                return Err(Error::DeviceMismatchBinaryOp {
+                    lhs: device.location(),
+                    rhs: arg.device().location(),
+                    op: "cat",
+                });
+            }
+            let mut mismatch = arg.rank() != rank;
+            for (dim_idx, (v1, v2)) in args[0]
+                .shape()
+                .dims()
+                .iter()
+                .zip(arg.shape().dims().iter())
+                .enumerate()
+            {
+                if dim == dim_idx {
+                    cat_dims[dim] += v2;
+                }
+                if dim != dim_idx && v1 != v2 {
+                    // TODO: It would probably be good to have a nicer error message here, i.e.
+                    // mention the problematic dimension and the values.
+                    mismatch = true;
+                }
+            }
+            if mismatch {
+                return Err(Error::ShapeMismatchCat {
+                    dim,
+                    first_shape: args[0].shape().clone(),
+                    n: arg_idx + 1,
+                    nth_shape: arg.shape().clone(),
+                });
+            }
+            let next_offset = offsets.last().unwrap() + arg.elem_count();
+            offsets.push(next_offset);
+        }
+        let shape = Shape::from(cat_dims);
+        let stride = shape.stride_contiguous();
+        let op = if args.iter().any(|arg| arg.track_op()) {
+            Some(Op::Cat(args.to_vec(), dim))
+        } else {
+            None
+        };
+        let mut storage = device.zeros(&shape, dtype)?;
+        for (arg, &offset) in args.iter().zip(offsets.iter()) {
+            arg.storage
+                .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)
+        }
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape,
+            stride,
+            op,
+            is_variable: false,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
     /// argument.
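For context, the shape semantics the new operator enforces: all inputs must share dtype, device, rank, and every dimension except dim, whose sizes add up. A quick usage sketch follows; it is hypothetical in that the zeros constructor and its exact signature are assumed here, not shown in this diff.

    // Hypothetical usage sketch; the zeros constructor signature is assumed.
    fn cat_example() -> Result<()> {
        let a = Tensor::zeros(&[2, 3], DType::F32, &Device::Cpu)?;
        let b = Tensor::zeros(&[4, 3], DType::F32, &Device::Cpu)?;
        // Along dim 0 the sizes add: [2, 3] ++ [4, 3] -> [6, 3].
        let c = Tensor::cat(&[a, b], 0)?;
        assert_eq!(c.shape().dims(), &[6, 3]);
        // A [4, 5] input here would instead return Error::ShapeMismatchCat.
        Ok(())
    }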
@@ -608,6 +694,11 @@ impl Tensor {
                         track_grad |= tg;
                         nodes
                     }
+                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
+                        let (tg, nodes) = walk(arg, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }),
                     Op::Affine { arg, mul, .. } => {
                         if *mul == 0. {
                             nodes
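Unlike the binary-op arms above, which walk exactly two children, the new Op::Cat arm handles a variadic op by folding the accumulated node list through every input, or-ing each input's track_grad flag into the running one.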
@@ -697,6 +788,9 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
+                    Op::Cat(_args, _dim) => {
+                        todo!()
+                    }
                    Op::Affine { arg, mul, .. } => {
                         let arg_grad = grad.affine(*mul, 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
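The backward arm for cat is left as todo!() in this commit. For orientation only, a minimal sketch of one way it could go, assuming a narrow(dim, start, len) slicing op that does not exist at this point: each input's gradient is simply the matching slice of the incoming gradient along the concatenation dimension.

    // Hypothetical sketch, not part of this commit; assumes a narrow() slicing op.
    Op::Cat(args, dim) => {
        let mut start = 0;
        for arg in args.iter() {
            // The slice of the output gradient that lines up with this input.
            let len = arg.shape().dims()[*dim];
            let arg_grad = grad.narrow(*dim, start, len)?;
            let sum_grad = grads.or_insert(arg)?;
            *sum_grad = sum_grad.add(&arg_grad)?;
            start += len;
        }
    }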