diff --git a/src/error.rs b/src/error.rs
index 723edaa1..cb302abd 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -9,6 +9,9 @@ pub enum Error {
     #[error("{op} only supports contiguous tensors")]
     RequiresContiguous { op: &'static str },
 
+    #[error("{op} expects at least one tensor")]
+    OpRequiresAtLeastOneTensor { op: &'static str },
+
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,
 
@@ -24,6 +27,14 @@ pub enum Error {
         op: &'static str,
     },
 
+    #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
+    ShapeMismatchCat {
+        dim: usize,
+        first_shape: Shape,
+        n: usize,
+        nth_shape: Shape,
+    },
+
     #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
     DeviceMismatchBinaryOp {
         lhs: DeviceLocation,
diff --git a/src/op.rs b/src/op.rs
index 6e909a35..6642ba2d 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -8,6 +8,8 @@ pub(crate) enum Op {
     Div(Tensor, Tensor),
     Matmul(Tensor, Tensor),
 
+    Cat(Vec<Tensor>, usize),
+
     #[allow(dead_code)] // add is currently unused.
     Affine {
         arg: Tensor,
diff --git a/src/storage.rs b/src/storage.rs
index 9f8cd2d5..55934064 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -147,4 +147,15 @@ impl Storage {
             }),
         }
     }
+
+    // self is the source; it can be strided, whereas dst is contiguous.
+    pub(crate) fn copy_strided_src(
+        &self,
+        _dst: &mut Self,
+        _shape: &Shape,
+        _stride: &[usize],
+        _offset: usize,
+    ) {
+        todo!()
+    }
 }
diff --git a/src/tensor.rs b/src/tensor.rs
index 53665ced..8467f099 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -575,6 +575,92 @@ impl Tensor {
         }
     }
 
+    pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
+        if args.is_empty() {
+            return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
+        }
+        let rank = args[0].rank();
+        if dim >= rank {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: (dim + 1),
+                got: rank,
+                shape: args[0].shape().clone(),
+            });
+        }
+        let device = args[0].device();
+        let dtype = args[0].dtype();
+        let first_dims = args[0].shape().dims();
+        let mut cat_dims = first_dims.to_vec();
+        cat_dims[dim] = 0;
+        let mut offsets = vec![0usize];
+        for (arg_idx, arg) in args.iter().enumerate() {
+            if arg.dtype() != dtype {
+                // TODO: Improve the error message.
+                return Err(Error::DTypeMismatchBinaryOp {
+                    lhs: dtype,
+                    rhs: arg.dtype(),
+                    op: "cat",
+                });
+            }
+            if arg.device().location() != device.location() {
+                // TODO: Improve the error message.
+                return Err(Error::DeviceMismatchBinaryOp {
+                    lhs: device.location(),
+                    rhs: arg.device().location(),
+                    op: "cat",
+                });
+            }
+            let mut mismatch = arg.rank() != rank;
+            for (dim_idx, (v1, v2)) in args[0]
+                .shape()
+                .dims()
+                .iter()
+                .zip(arg.shape().dims().iter())
+                .enumerate()
+            {
+                if dim == dim_idx {
+                    cat_dims[dim] += v2;
+                }
+                if dim != dim_idx && v1 != v2 {
+                    // TODO: It would probably be good to have a nicer error message here,
+                    // i.e. mention the problematic dimension and the values.
+                    mismatch = true;
+                }
+            }
+            if mismatch {
+                return Err(Error::ShapeMismatchCat {
+                    dim,
+                    first_shape: args[0].shape().clone(),
+                    n: arg_idx + 1,
+                    nth_shape: arg.shape().clone(),
+                });
+            }
+            let next_offset = offsets.last().unwrap() + arg.elem_count();
+            offsets.push(next_offset);
+        }
+        let shape = Shape::from(cat_dims);
+        let stride = shape.stride_contiguous();
+        let op = if args.iter().any(|arg| arg.track_op()) {
+            Some(Op::Cat(args.to_vec(), dim))
+        } else {
+            None
+        };
+        let mut storage = device.zeros(&shape, dtype)?;
+        for (arg, &offset) in args.iter().zip(offsets.iter()) {
+            arg.storage
+                .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)
+        }
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape,
+            stride,
+            op,
+            is_variable: false,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
     /// argument.
@@ -608,6 +694,11 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
+                Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
+                    let (tg, nodes) = walk(arg, nodes, already_seen);
+                    track_grad |= tg;
+                    nodes
+                }),
                 Op::Affine { arg, mul, .. } => {
                     if *mul == 0. {
                         nodes
@@ -697,6 +788,9 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
+                Op::Cat(_args, _dim) => {
+                    todo!()
+                }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
                     let sum_grad = grads.or_insert(arg)?;
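
For context on the `copy_strided_src` stub above (currently `todo!()`): the source can be an arbitrarily strided view, while the destination allocated by `cat` is contiguous, and the cumulative `elem_count` offsets give each argument one contiguous block in the destination, which lines up with concatenation along dim 0. A minimal sketch of what a CPU-side copy could look like, written as a free function over `f32` slices; the name `copy_strided_src_f32` and its signature are illustrative only, not the crate's API, which would dispatch on the storage backend and dtype:

```rust
// Illustrative sketch: copy a strided source buffer into a contiguous
// destination, starting at `dst_offset`.
fn copy_strided_src_f32(
    src: &[f32],
    dst: &mut [f32],
    dims: &[usize],
    stride: &[usize],
    dst_offset: usize,
) {
    let elem_count: usize = dims.iter().product();
    for logical in 0..elem_count {
        // Decompose the row-major logical index into per-dimension indices,
        // starting from the innermost dimension, and apply the strides to
        // recover the physical source offset.
        let mut remainder = logical;
        let mut src_idx = 0;
        for (&dim, &s) in dims.iter().zip(stride.iter()).rev() {
            src_idx += (remainder % dim) * s;
            remainder /= dim;
        }
        dst[dst_offset + logical] = src[src_idx];
    }
}
```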
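The backward rule for `Op::Cat` (also left as `todo!()` above) splits the output gradient back along `dim`: each argument receives the slice of `grad` covering its extent. A sketch of the match arm, assuming a `Tensor::narrow(dim, start, len)` slicing method is available; no such method is introduced by this diff:

```rust
// Sketch: assumes `Tensor::narrow(dim, start, len)` exists.
Op::Cat(args, dim) => {
    let mut start = 0;
    for arg in args {
        let len = arg.shape().dims()[*dim];
        // Each argument's gradient is the matching slice of `grad`.
        let arg_grad = grad.narrow(*dim, start, len)?;
        let sum_grad = grads.or_insert(arg)?;
        *sum_grad = sum_grad.add(&arg_grad)?;
        start += len;
    }
}
```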
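For reference, a hypothetical usage sketch: concatenating two 2x3 tensors along dim 0 yields a 4x3 result, and along dim 1 a 2x6 result. The `Tensor::zeros` constructor and its exact signature are assumptions here, not part of this diff:

```rust
// Hypothetical usage; the zeros-style constructor is assumed.
fn cat_demo() -> Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let ab0 = Tensor::cat(&[a.clone(), b.clone()], 0)?; // shape [4, 3]
    let ab1 = Tensor::cat(&[a, b], 1)?; // shape [2, 6]
    assert_eq!(ab0.shape().dims(), &[4, 3]);
    assert_eq!(ab1.shape().dims(), &[2, 6]);
    Ok(())
}
```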