use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape}; use std::sync::Arc; /// Unique identifier for tensors. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct TensorId(usize); impl TensorId { fn new() -> Self { // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 use std::sync::atomic; static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) } } pub struct Tensor_ { id: TensorId, storage: Arc, shape: Shape, // The strides are given in number of elements and not in bytes. stride: Vec, op: Option, is_variable: bool, } impl AsRef for Tensor { fn as_ref(&self) -> &Tensor { self } } // Tensors are refcounted so that cloning is cheap when building the op graph. // Storages are also refcounted independently so that its possible to avoid // copying the storage for operations that only modify the shape or stride. #[derive(Clone)] pub struct Tensor(Arc); impl std::ops::Deref for Tensor { type Target = Tensor_; fn deref(&self) -> &Self::Target { self.0.as_ref() } } impl std::fmt::Debug for Tensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[{:?}, {:?}]", &self.shape().dims(), self.device()) } } macro_rules! unary_op { ($fn_name:ident, $op_name:ident) => { pub fn $fn_name(&self) -> Result { let shape = self.shape(); let storage = self .storage .unary_impl::(self.shape(), self.stride())?; let op = if self.track_op() { Some(Op::$op_name(self.clone())) } else { None }; Ok(from_storage(storage, shape.clone(), op, false)) } }; } macro_rules! binary_op { ($fn_name:ident, $op_name:ident) => { pub fn $fn_name(&self, rhs: &Self) -> Result { let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?; let storage = self.storage.binary_impl::( &rhs.storage, shape, self.stride(), rhs.stride(), )?; let op = if self.track_op() || rhs.track_op() { Some(Op::$op_name(self.clone(), rhs.clone())) } else { None }; Ok(from_storage(storage, shape.clone(), op, false)) } }; } macro_rules! broadcast_binary_op { ($fn_name:ident, $inner_fn_name:ident) => { pub fn $fn_name(&self, rhs: &Self) -> Result { let lhs = self; let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?; let l_broadcast = shape != *lhs.shape(); let r_broadcast = shape != *rhs.shape(); match (l_broadcast, r_broadcast) { (true, true) => lhs .broadcast_as(&shape)? .$inner_fn_name(&rhs.broadcast_as(&shape)?), (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?), (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs), (false, false) => lhs.$inner_fn_name(rhs), } } }; } /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. fn from_storage>( storage: Storage, shape: S, op: Option, is_variable: bool, ) -> Tensor { let shape = shape.into(); let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(storage), shape, stride, op, is_variable, }; Tensor(Arc::new(tensor_)) } impl Tensor { fn ones_impl>( shape: S, dtype: DType, device: &Device, is_variable: bool, ) -> Result { let shape = shape.into(); let storage = device.ones(&shape, dtype)?; Ok(from_storage(storage, shape, None, is_variable)) } pub fn ones>(shape: S, dtype: DType, device: &Device) -> Result { Self::ones_impl(shape, dtype, device, false) } pub fn ones_var>(shape: S, dtype: DType, device: &Device) -> Result { Self::ones_impl(shape, dtype, device, true) } pub fn ones_like(&self) -> Result { Tensor::ones(self.shape(), self.dtype(), &self.device()) } fn zeros_impl>( shape: S, dtype: DType, device: &Device, is_variable: bool, ) -> Result { let shape = shape.into(); let storage = device.zeros(&shape, dtype)?; Ok(from_storage(storage, shape, None, is_variable)) } pub fn zeros>(shape: S, dtype: DType, device: &Device) -> Result { Self::zeros_impl(shape, dtype, device, false) } pub fn zeros_var>(shape: S, dtype: DType, device: &Device) -> Result { Self::zeros_impl(shape, dtype, device, true) } pub fn zeros_like(&self) -> Result { Tensor::zeros(self.shape(), self.dtype(), &self.device()) } pub fn new_impl( array: A, shape: Shape, device: &Device, is_variable: bool, ) -> Result { let n: usize = shape.elem_count(); let buffer_size: usize = array.shape()?.elem_count(); if buffer_size != n { return Err(Error::ShapeMismatch { buffer_size, shape }); } let storage = device.storage(array)?; Ok(from_storage(storage, shape, None, is_variable)) } pub fn new(array: A, device: &Device) -> Result { let shape = array.shape()?; Self::new_impl(array, shape, device, false) } pub fn var(array: A, device: &Device) -> Result { let shape = array.shape()?; Self::new_impl(array, shape, device, true) } pub fn from_vec_impl, D: crate::WithDType>( data: Vec, shape: S, device: &Device, is_variable: bool, ) -> Result { let shape = shape.into(); let buffer_size = data.len(); if buffer_size != shape.elem_count() { return Err(Error::ShapeMismatch { buffer_size, shape }); } let storage = device.storage_owned(data)?; Ok(from_storage(storage, shape, None, is_variable)) } pub fn from_vec, D: crate::WithDType>( data: Vec, shape: S, device: &Device, ) -> Result { Self::from_vec_impl(data, shape, device, false) } pub fn var_from_vec, D: crate::WithDType>( data: Vec, shape: S, device: &Device, ) -> Result { Self::from_vec_impl(data, shape, device, true) } pub fn from_slice, D: crate::WithDType>( array: &[D], shape: S, device: &Device, ) -> Result { Self::new_impl(array, shape.into(), device, false) } pub fn var_from_slice, D: crate::WithDType>( array: &[D], shape: S, device: &Device, ) -> Result { Self::new_impl(array, shape.into(), device, true) } pub(crate) fn broadcast_shape_binary_op<'a>( &'a self, rhs: &'a Self, op: &'static str, ) -> Result { let lhs = self; let lhs_dims = lhs.shape().dims(); let rhs_dims = rhs.shape().dims(); let lhs_ndims = lhs_dims.len(); let rhs_ndims = rhs_dims.len(); let bcast_ndims = usize::max(lhs_ndims, rhs_ndims); let mut bcast_dims = vec![0; bcast_ndims]; for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() { let rev_idx = bcast_ndims - idx; let l_value = if lhs_ndims < rev_idx { 1 } else { lhs_dims[lhs_ndims - rev_idx] }; let r_value = if rhs_ndims < rev_idx { 1 } else { rhs_dims[rhs_ndims - rev_idx] }; *bcast_value = if l_value == r_value { l_value } else if l_value == 1 { r_value } else if r_value == 1 { l_value } else { Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), rhs: rhs.shape().clone(), op, })? } } Ok(Shape::from(bcast_dims)) } pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { let lhs = self.shape(); let rhs = rhs.shape(); if lhs != rhs { Err(Error::ShapeMismatchBinaryOp { lhs: lhs.clone(), rhs: rhs.clone(), op, }) } else { Ok(lhs) } } /// Returns true if the computation graph should track this op, that is if it is /// a variable or if it has some variable as dependencies. pub(crate) fn track_op(&self) -> bool { self.is_variable || self.op.is_some() } // TODO: Also make an inplace version or a pre-allocated? This could be tricky // if this can create cycles in the compute graph. binary_op!(add, Add); binary_op!(mul, Mul); binary_op!(sub, Sub); binary_op!(div, Div); broadcast_binary_op!(broadcast_add, add); broadcast_binary_op!(broadcast_mul, mul); broadcast_binary_op!(broadcast_sub, sub); broadcast_binary_op!(broadcast_div, div); unary_op!(neg, Neg); unary_op!(exp, Exp); unary_op!(log, Log); unary_op!(sin, Sin); unary_op!(cos, Cos); unary_op!(abs, Abs); unary_op!(sqr, Sqr); unary_op!(sqrt, Sqrt); unary_op!(gelu, Gelu); pub fn to_scalar(&self) -> Result { if self.rank() != 0 { return Err(Error::UnexpectedNumberOfDims { expected: 0, got: self.rank(), shape: self.shape().clone(), }); } let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; Ok::<_, Error>(data[0]) }; match self.storage.as_ref() { Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } pub fn affine(&self, mul: f64, add: f64) -> Result { let shape = self.shape(); let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?; let op = if self.track_op() { Some(Op::Affine { arg: self.clone(), mul, add, }) } else { None }; Ok(from_storage(storage, shape.clone(), op, false)) } /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// ranges from `start` to `start + length`. // TODO: Once we've refactored the shape and strides, make this return a view of the same data // rather than copying. pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result { let dims = self.shape().dims(); if dim >= dims.len() { return Err(Error::UnexpectedNumberOfDims { expected: dim + 1, got: dims.len(), shape: self.shape().clone(), }); } if start + length > dims[dim] { todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}") } let mut dims = dims.to_vec(); dims[dim] = length; let adjusted_shape = Shape::from(dims); let mut storage = self.device().zeros(&adjusted_shape, self.dtype())?; self.storage.copy_strided_src( &mut storage, /* dst_offset= */ 0, &adjusted_shape, &self.stride, /* src_offest= */ self.stride[dim] * start, )?; let op = if self.track_op() { Some(Op::Narrow(self.clone(), dim, start, length)) } else { None }; Ok(from_storage(storage, adjusted_shape, op, false)) } pub fn softmax(&self, dim: usize) -> Result { // TODO: unify the two branches. if self.device().is_cuda() { // We do not have a cuda kernel for divide_by_sum_over_dim so split // the operation. let exp = self.exp()?; let sum_exp = exp.sum(&[dim])?; exp.broadcast_div(&sum_exp) } else { let shape = self.shape(); let mut storage = self .storage .unary_impl::(shape, self.stride())?; // The resulting storage is contiguous. storage.divide_by_sum_over_dim(shape, dim)?; let op = if self.track_op() { Some(Op::Softmax(self.clone(), dim)) } else { None }; Ok(from_storage(storage, shape.clone(), op, false)) } } pub fn sum(&self, sum_dims: &[usize]) -> Result { let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?; let op = if self.track_op() { Some(Op::Sum(self.clone(), sum_dims.to_vec())) } else { None }; let mut dims = self.dims().to_vec(); for &sum_dim in sum_dims.iter() { dims[sum_dim] = 1 } Ok(from_storage(storage, dims, op, false)) } pub fn matmul(&self, rhs: &Self) -> Result { let a_dims = self.shape().dims(); let b_dims = rhs.shape().dims(); let dim = a_dims.len(); if dim < 2 || b_dims.len() != dim { return Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), rhs: rhs.shape().clone(), op: "matmul", }); } let m = a_dims[dim - 2]; let k = a_dims[dim - 1]; let k2 = b_dims[dim - 2]; let n = b_dims[dim - 1]; if k != k2 { return Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), rhs: rhs.shape().clone(), op: "matmul", }); } let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); let batching: usize = a_dims[..dim - 2].iter().product(); let storage = self.storage.matmul_impl( &rhs.storage, (batching, m, n, k), self.stride(), rhs.stride(), )?; let op = if self.track_op() || rhs.track_op() { Some(Op::Matmul(self.clone(), rhs.clone())) } else { None }; Ok(from_storage(storage, c_shape, op, false)) } pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result { let _shap = self.same_shape_binary_op(on_true, "where_cond")?; let shape = self.same_shape_binary_op(on_false, "where_cond")?; let storage = self.storage.where_cond( shape, self.stride(), &on_true.storage, on_true.stride(), &on_false.storage, on_false.stride(), )?; let op = if self.track_op() || on_true.track_op() || on_false.track_op() { Some(Op::WhereCond( self.clone(), on_true.clone(), on_false.clone(), )) } else { None }; Ok(from_storage(storage, shape, op, false)) } pub fn embedding(ids: &Self, rhs: &Self) -> Result { if !rhs.is_contiguous() { return Err(Error::RequiresContiguous { op: "embedding" }); } else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 { return Err(Error::ShapeMismatchBinaryOp { lhs: ids.shape.clone(), rhs: rhs.shape.clone(), op: "embedding", }); } let ids_shape = ids.shape(); let seq_len = ids_shape.r1()?; let (vocab_size, hidden_size) = rhs.shape().r2()?; let storage = ids.storage.embedding_impl( ids_shape, &ids.stride, &rhs.storage, hidden_size, vocab_size, )?; let shape: Shape = (seq_len, hidden_size).into(); let op = if ids.track_op() || rhs.track_op() { Some(Op::Embedding(ids.clone(), rhs.clone())) } else { None }; Ok(from_storage(storage, shape, op, false)) } pub(crate) fn strided_index(&self) -> crate::StridedIndex { crate::StridedIndex::new(self.dims(), self.stride()) } /// Returns data from the underlying storage, this does not take the strides /// into account so the size of the resulting buffer might be larger than the /// tensor number of elements. pub fn storage_data(&self) -> Result> { match self.storage.as_ref() { Storage::Cpu(cpu_storage) => { let slice = S::cpu_storage_as_slice(cpu_storage)?; Ok(std::borrow::Cow::Borrowed(slice)) } Storage::Cuda(slice) => { let cpu_storage = slice.to_cpu_storage()?; let storage_data = S::cpu_storage_data(cpu_storage)?; Ok(std::borrow::Cow::Owned(storage_data)) } } } pub fn to_vec1(&self) -> Result> { if self.rank() != 1 { return Err(Error::UnexpectedNumberOfDims { expected: 1, got: self.rank(), shape: self.shape().clone(), }); } match self.storage.as_ref() { Storage::Cpu(cpu_storage) => { let data = S::cpu_storage_as_slice(cpu_storage)?; Ok(self.strided_index().map(|i| data[i]).collect()) } Storage::Cuda(slice) => { // TODO: Would it be possible to only fetch the necessary data? let cpu_storage = slice.to_cpu_storage()?; let data = S::cpu_storage_as_slice(&cpu_storage)?; Ok(self.strided_index().map(|i| data[i]).collect()) } } } pub fn to_vec2(&self) -> Result>> { let (dim1, dim2) = self.shape().r2()?; let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; let mut rows = vec![]; let mut src_index = self.strided_index(); for _idx_row in 0..dim1 { let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect(); rows.push(row) } assert!(src_index.next().is_none()); Ok(rows) }; match self.storage.as_ref() { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } pub fn to_vec3(&self) -> Result>>> { let (dim1, dim2, dim3) = self.shape().r3()?; let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; let mut top_rows = vec![]; let mut src_index = self.strided_index(); for _idx in 0..dim1 { let mut rows = vec![]; for _jdx in 0..dim2 { let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect(); rows.push(row) } top_rows.push(rows); } assert!(src_index.next().is_none()); Ok(top_rows) }; match self.storage.as_ref() { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } pub fn dtype(&self) -> DType { self.storage.dtype() } pub fn device(&self) -> Device { self.storage.device() } pub fn shape(&self) -> &Shape { &self.shape } pub fn dims(&self) -> &[usize] { self.shape().dims() } pub fn stride(&self) -> &[usize] { &self.stride } pub fn rank(&self) -> usize { self.shape().rank() } pub fn elem_count(&self) -> usize { self.shape().elem_count() } pub fn id(&self) -> TensorId { self.id } pub fn is_variable(&self) -> bool { self.is_variable } pub(crate) fn op(&self) -> &Option { &self.op } /// Returns a tensor that is a transposed version of the input, the two last dimensions of the /// input are swapped. pub fn t(&self) -> Result { let rank = self.rank(); if rank < 2 { return Err(Error::UnexpectedNumberOfDims { expected: 2, got: rank, shape: self.shape().clone(), }); } self.transpose(rank - 2, rank - 1) } /// Returns a tensor that is a transposed version of the input, the given dimensions are /// swapped. pub fn transpose(&self, dim1: usize, dim2: usize) -> Result { let rank = self.rank(); if rank <= dim1 || rank <= dim2 { return Err(Error::UnexpectedNumberOfDims { expected: usize::max(dim1, dim2), got: rank, shape: self.shape().clone(), }); } let mut stride = self.stride().to_vec(); let mut dims = self.shape().dims().to_vec(); dims.swap(dim1, dim2); stride.swap(dim1, dim2); let op = if self.track_op() { Some(Op::Transpose(self.clone(), dim1, dim2)) } else { None }; let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), shape: Shape::from(dims), stride, op, is_variable: false, }; Ok(Tensor(Arc::new(tensor_))) } /// Returns true if the data is stored in a C contiguous (aka row major) way. pub fn is_contiguous(&self) -> bool { self.shape.is_contiguous(&self.stride) } /// Returns true if the data is stored in a Fortran contiguous (aka column major) way. pub fn is_fortran_contiguous(&self) -> bool { self.shape.is_fortran_contiguous(&self.stride) } /// Compared to clone, this copies the actual storage but may fail because of running out of /// memory. pub fn copy(&self) -> Result { let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(self.storage.try_clone()?), shape: self.shape.clone(), stride: self.stride.clone(), op: None, // TODO is_variable: false, }; Ok(Tensor(Arc::new(tensor_))) } /// Returns a new tensor detached from the current graph, gradient are not propagated through /// this new node. pub fn detach(&self) -> Result { let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), shape: self.shape.clone(), stride: self.stride.clone(), op: None, is_variable: false, }; Ok(Tensor(Arc::new(tensor_))) } /// If the target device is the same as the tensor device, only a shallow copy is performed. pub fn to_device(&self, device: &Device) -> Result { if self.device().same_id(device) { Ok(self.clone()) } else { let storage = match (self.storage.as_ref(), device) { (Storage::Cpu(storage), Device::Cuda(cuda)) => { Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?) } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cuda(cuda)) => { // TODO: Avoid passing through the cpu storage here, especially if the gpu ids // are the same. let cpu_storage = storage.to_cpu_storage()?; Storage::Cuda(cuda.cuda_from_cpu_storage(&cpu_storage)?) } (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()), }; let op = if self.track_op() { Some(Op::ToDevice(self.clone())) } else { None }; let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(storage), shape: self.shape.clone(), stride: self.stride.clone(), op, is_variable: false, }; Ok(Tensor(Arc::new(tensor_))) } } /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted /// on the left. pub fn broadcast_left>(&self, left_shape: S) -> Result { let left_shape = left_shape.into(); let mut dims = left_shape.into_dims(); dims.extend(self.shape.dims()); self.broadcast_as(dims) } pub fn broadcast_as>(&self, shape: S) -> Result { let op = if self.track_op() { Some(Op::Broadcast(self.clone())) } else { None }; let shape = shape.into(); if shape.rank() < self.rank() { return Err(Error::BroadcastIncompatibleShapes { src_shape: self.shape().clone(), dst_shape: shape, }); } let added_dims = shape.rank() - self.rank(); let mut stride = vec![0; added_dims]; for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..] .iter() .zip(self.dims().iter().zip(self.stride())) { let s = if dst_dim == src_dim { src_stride } else if src_dim != 1 { return Err(Error::BroadcastIncompatibleShapes { src_shape: self.shape().clone(), dst_shape: shape, }); } else { 0 }; stride.push(s) } let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), shape, stride, op, is_variable: false, }; Ok(Tensor(Arc::new(tensor_))) } /// An alias for broadcast_as. pub fn expand>(&self, shape: S) -> Result { self.broadcast_as(shape) } pub fn to_dtype(&self, dtype: DType) -> Result { if self.dtype() == dtype { Ok(self.clone()) } else { let shape = self.shape(); let storage = self.storage.to_dtype(shape, self.stride(), dtype)?; let op = if self.track_op() { Some(Op::ToDType(self.clone())) } else { None }; Ok(from_storage(storage, shape.clone(), op, false)) } } pub fn contiguous(&self) -> Result { if self.is_contiguous() { Ok(self.clone()) } else { let shape = self.shape(); let mut storage = self.device().zeros(shape, self.dtype())?; self.storage .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?; Ok(from_storage( storage, shape.clone(), None, // TODO false, )) } } // TODO: Do we want to allow target shape using -1 on some dimensions? /// Reshape returns a tensor with the target shape provided that the number of elements of the /// original tensor is the same. /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses /// a new storage and copies the data over, the returned tensor is always contiguous. pub fn reshape>(&self, shape: S) -> Result { let shape = shape.into(); if shape.elem_count() != self.elem_count() { return Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), rhs: shape, op: "reshape", }); } let op = if self.track_op() { Some(Op::Reshape(self.clone())) } else { None }; if self.is_contiguous() { let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), shape, stride, op, is_variable: false, }; Ok(Tensor(Arc::new(tensor_))) } else { let mut storage = self.device().zeros(&shape, self.dtype())?; self.storage .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?; Ok(from_storage(storage, shape, op, false)) } } pub fn cat>(args: &[A], dim: usize) -> Result { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); } let arg0 = args[0].as_ref(); if args.len() == 1 { return Ok(arg0.clone()); } let rank = arg0.rank(); if dim >= rank { return Err(Error::UnexpectedNumberOfDims { expected: (dim + 1), got: rank, shape: arg0.shape().clone(), }); } if dim == 0 { Self::cat0(args) } else { // TODO: Avoid these transpositions and have an implementation that works // for dim != 0... let args: Vec = args .iter() .map(|a| a.as_ref().transpose(0, dim)) .collect::>>()?; let cat = Self::cat0(&args)?; cat.transpose(0, dim) } } pub fn cat0>(args: &[A]) -> Result { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); } let arg0 = args[0].as_ref(); if args.len() == 1 { return Ok(arg0.clone()); } let rank = arg0.rank(); let device = arg0.device(); let dtype = arg0.dtype(); let first_dims = arg0.shape().dims(); let mut cat_dims = first_dims.to_vec(); cat_dims[0] = 0; let mut offsets = vec![0usize]; for (arg_idx, arg) in args.iter().enumerate() { let arg = arg.as_ref(); if arg.dtype() != dtype { // TODO: Improve the error message. return Err(Error::DTypeMismatchBinaryOp { lhs: dtype, rhs: arg.dtype(), op: "cat", }); } if arg.device().location() != device.location() { // TODO: Improve the error message. return Err(Error::DeviceMismatchBinaryOp { lhs: device.location(), rhs: arg.device().location(), op: "cat", }); } let mut mismatch = arg.rank() != rank; for (dim_idx, (v1, v2)) in arg0 .shape() .dims() .iter() .zip(arg.shape().dims().iter()) .enumerate() { if dim_idx == 0 { cat_dims[0] += v2; } if dim_idx != 0 && v1 != v2 { // TODO: It would probably be good to have a nicer error message here, i.e. // mention the problematic dimension and the values. mismatch = true; } } if mismatch { return Err(Error::ShapeMismatchCat { dim: 0, // TODO: not the appropriate error message first_shape: arg0.shape().clone(), n: arg_idx + 1, nth_shape: arg.shape().clone(), }); } let next_offset = offsets.last().unwrap() + arg.elem_count(); offsets.push(next_offset); } let shape = Shape::from(cat_dims); let op = if args.iter().any(|arg| arg.as_ref().track_op()) { let args: Vec = args.iter().map(|arg| arg.as_ref().clone()).collect(); Some(Op::Cat(args, 0)) } else { None }; let mut storage = device.zeros(&shape, dtype)?; for (arg, &offset) in args.iter().zip(offsets.iter()) { let arg = arg.as_ref(); arg.storage .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?; } Ok(from_storage(storage, shape, op, false)) } } macro_rules! bin_trait { ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => { impl> std::ops::$trait for Tensor { type Output = Result; fn $fn1(self, rhs: B) -> Self::Output { Tensor::$fn1(&self, rhs.borrow()) } } impl> std::ops::$trait for &Tensor { type Output = Result; fn $fn1(self, rhs: B) -> Self::Output { Tensor::$fn1(&self, rhs.borrow()) } } impl> std::ops::$trait> for Tensor { type Output = Result; fn $fn1(self, rhs: Result) -> Self::Output { Tensor::$fn1(&self, rhs?.borrow()) } } impl> std::ops::$trait> for &Tensor { type Output = Result; fn $fn1(self, rhs: Result) -> Self::Output { Tensor::$fn1(&self, rhs?.borrow()) } } impl std::ops::$trait for Tensor { type Output = Result; fn $fn1(self, rhs: f64) -> Self::Output { self.affine($mul(rhs), $add(rhs)) } } impl std::ops::$trait for &Tensor { type Output = Result; fn $fn1(self, rhs: f64) -> Self::Output { self.affine($mul(rhs), $add(rhs)) } } }; } bin_trait!(Add, add, |_| 1., |v| v); bin_trait!(Sub, sub, |_| 1., |v: f64| -v); bin_trait!(Mul, mul, |v| v, |_| 0.); bin_trait!(Div, div, |v| 1. / v, |_| 0.);