Add the cat operator (without the storage implementation for now).

laurent
2023-06-23 10:13:37 +01:00
parent bf9e1d1c23
commit 6110db31c9
4 changed files with 118 additions and 0 deletions
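For context, a minimal usage sketch of the new operator (illustrative only: the `Tensor::zeros` constructor, `DType::F32`, and `Device::Cpu` are assumed from the surrounding crate and do not appear in this diff):

    // Concatenate along dim 0: (2, 3) ++ (4, 3) gives (6, 3).
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
    let c = Tensor::cat(&[a, b], 0)?;
    assert_eq!(c.shape().dims(), &[6, 3]);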

src/error.rs

@@ -9,6 +9,9 @@ pub enum Error {
    #[error("{op} only supports contiguous tensors")]
    RequiresContiguous { op: &'static str },
    #[error("{op} expects at least one tensor")]
    OpRequiresAtLeastOneTensor { op: &'static str },
    #[error("the candle crate has not been built with cuda support")]
    NotCompiledWithCudaSupport,
@@ -24,6 +27,14 @@ pub enum Error {
        op: &'static str,
    },
    #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
    ShapeMismatchCat {
        dim: usize,
        first_shape: Shape,
        n: usize,
        nth_shape: Shape,
    },
    #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    DeviceMismatchBinaryOp {
        lhs: DeviceLocation,

src/op.rs

@@ -8,6 +8,8 @@ pub(crate) enum Op {
    Div(Tensor, Tensor),
    Matmul(Tensor, Tensor),
    Cat(Vec<Tensor>, usize),
    #[allow(dead_code)] // the `add` field is currently unused.
    Affine {
        arg: Tensor,

src/storage.rs

@@ -147,4 +147,15 @@ impl Storage {
            }),
        }
    }

    // `self` (the source) can be strided, whereas `dst` must be contiguous.
    pub(crate) fn copy_strided_src(
        &self,
        _dst: &mut Self,
        _shape: &Shape,
        _stride: &[usize],
        _offset: usize,
    ) {
        todo!()
    }
}
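The storage-level copy is deliberately left as a stub, as the commit message notes. For intuition, here is a rough sketch of the intended semantics over plain `f32` slices (hypothetical code, nothing below exists in the crate):

    // Copy a strided source into a contiguous destination, starting at
    // `offset` (in elements), traversing the last dimension fastest.
    fn copy_strided(src: &[f32], dims: &[usize], stride: &[usize], dst: &mut [f32], offset: usize) {
        let mut index = vec![0usize; dims.len()];
        let mut dst_i = offset;
        loop {
            // Gather the source element for the current multi-dimensional index.
            let src_i: usize = index.iter().zip(stride).map(|(i, s)| i * s).sum();
            dst[dst_i] = src[src_i];
            dst_i += 1;
            // Increment the multi-dimensional index; return once it wraps around.
            let mut d = dims.len();
            loop {
                if d == 0 {
                    return;
                }
                d -= 1;
                index[d] += 1;
                if index[d] < dims[d] {
                    break;
                }
                index[d] = 0;
            }
        }
    }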

src/tensor.rs

@@ -575,6 +575,92 @@ impl Tensor {
        }
    }

    pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
        if args.is_empty() {
            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
        }
        let rank = args[0].rank();
        if dim >= rank {
            return Err(Error::UnexpectedNumberOfDims {
                expected: (dim + 1),
                got: rank,
                shape: args[0].shape().clone(),
            });
        }
        let device = args[0].device();
        let dtype = args[0].dtype();
        let first_dims = args[0].shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[dim] = 0;
        let mut offsets = vec![0usize];
        for (arg_idx, arg) in args.iter().enumerate() {
            if arg.dtype() != dtype {
                // TODO: Improve the error message.
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                });
            }
            if arg.device().location() != device.location() {
                // TODO: Improve the error message.
                return Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                });
            }
            let mut mismatch = arg.rank() != rank;
            for (dim_idx, (v1, v2)) in args[0]
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim == dim_idx {
                    cat_dims[dim] += v2;
                }
                if dim != dim_idx && v1 != v2 {
                    // TODO: It would probably be good to have a nicer error message here, i.e.
                    // mention the problematic dimension and the values.
                    mismatch = true;
                }
            }
            if mismatch {
                return Err(Error::ShapeMismatchCat {
                    dim,
                    first_shape: args[0].shape().clone(),
                    n: arg_idx + 1,
                    nth_shape: arg.shape().clone(),
                });
            }
            // Element offset of each argument in the contiguous destination storage.
            let next_offset = offsets.last().unwrap() + arg.elem_count();
            offsets.push(next_offset);
        }
        let shape = Shape::from(cat_dims);
        let stride = shape.stride_contiguous();
        let op = if args.iter().any(|arg| arg.track_op()) {
            Some(Op::Cat(args.to_vec(), dim))
        } else {
            None
        };
        let mut storage = device.zeros(&shape, dtype)?;
        for (arg, &offset) in args.iter().zip(offsets.iter()) {
            arg.storage
                .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)
        }
        let tensor_ = Tensor_ {
            id: TensorId::new(),
            storage,
            shape,
            stride,
            op,
            is_variable: false,
        };
        Ok(Tensor(Arc::new(tensor_)))
    }
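    // Worked example of the bookkeeping above (illustrative, dim = 0):
    //   args[0]: shape (2, 3), elem_count 6
    //   args[1]: shape (4, 3), elem_count 12
    //   cat_dims = [2 + 4, 3], so the result shape is (6, 3)
    //   offsets  = [0, 6, 18], element offsets into the contiguous destination:
    //   each argument occupies its own contiguous block of `storage`.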
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
/// argument.
@@ -608,6 +694,11 @@ impl Tensor {
                    track_grad |= tg;
                    nodes
                }
                Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
                    let (tg, nodes) = walk(arg, nodes, already_seen);
                    track_grad |= tg;
                    nodes
                }),
                Op::Affine { arg, mul, .. } => {
                    if *mul == 0. {
                        nodes
@@ -697,6 +788,9 @@ impl Tensor {
                    let rhs_sum_grad = grads.or_insert(rhs)?;
                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                }
                Op::Cat(_args, _dim) => {
                    // Conceptually the gradient of cat splits `grad` back along
                    // `dim`, one slice per input; left unimplemented in this commit.
                    todo!()
                }
                Op::Affine { arg, mul, .. } => {
                    let arg_grad = grad.affine(*mul, 0.)?;
                    let sum_grad = grads.or_insert(arg)?;