From fe8777822390584ac0c080d8de1f51c7b3d0d091 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 23 Jul 2023 19:06:47 +0200
Subject: [PATCH] Add the copy op. (#227)

* Add the copy op.

* Tweak some cat error messages.

* Handle the contiguous case in to_vec1.

* Fast variant for to_vec2.

* Add a faster to_vec3 variant.
---
 candle-core/src/backprop.rs |   5 ++
 candle-core/src/op.rs       |   1 +
 candle-core/src/tensor.rs   | 106 ++++++++++++++++++++++--------
 3 files changed, 72 insertions(+), 40 deletions(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 38898b7b..24da23a2 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -82,6 +82,7 @@ impl Tensor {
                     }
                 }
                 Op::Reshape(node)
+                | Op::Copy(node)
                 | Op::Broadcast(node)
                 | Op::Cmp(node, _)
                 | Op::Reduce(node, _, _)
@@ -246,6 +247,10 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
                 }
+                Op::Copy(arg) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad)?
+                }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
                     let sum_grad = grads.or_insert(arg)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index de5094bd..144f1e98 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -89,6 +89,7 @@ pub(crate) enum Op {
         add: f64,
     },
     ToDType(Tensor),
+    Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
     Reshape(Tensor),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 561f1863..5257219d 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1128,17 +1128,17 @@ impl Tensor {
             }
             .bt())?
         }
+        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+            let data = S::cpu_storage_as_slice(cpu_storage)?;
+            let data = match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => data[o1..o2].to_vec(),
+                None => self.strided_index().map(|i| data[i]).collect(),
+            };
+            Ok::<Vec<_>, Error>(data)
+        };
         match &*self.storage() {
-            Storage::Cpu(cpu_storage) => {
-                let data = S::cpu_storage_as_slice(cpu_storage)?;
-                Ok(self.strided_index().map(|i| data[i]).collect())
-            }
-            Storage::Cuda(slice) => {
-                // TODO: Would it be possible to only fetch the necessary data?
-                let cpu_storage = slice.to_cpu_storage()?;
-                let data = S::cpu_storage_as_slice(&cpu_storage)?;
-                Ok(self.strided_index().map(|i| data[i]).collect())
-            }
+            Storage::Cpu(storage) => from_cpu_storage(storage),
+            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
 
@@ -1148,12 +1148,22 @@ impl Tensor {
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut rows = vec![];
-            let mut src_index = self.strided_index();
-            for _idx_row in 0..dim1 {
-                let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
-                rows.push(row)
+            match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    let data = &data[o1..o2];
+                    for idx_row in 0..dim1 {
+                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
+                    }
+                }
+                None => {
+                    let mut src_index = self.strided_index();
+                    for _idx_row in 0..dim1 {
+                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+                        rows.push(row)
+                    }
+                    assert!(src_index.next().is_none());
+                }
             }
-            assert!(src_index.next().is_none());
             Ok(rows)
         };
         match &*self.storage() {
@@ -1168,16 +1178,32 @@ impl Tensor {
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut top_rows = vec![];
-            let mut src_index = self.strided_index();
-            for _idx in 0..dim1 {
-                let mut rows = vec![];
-                for _jdx in 0..dim2 {
-                    let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
-                    rows.push(row)
+            match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    let data = &data[o1..o2];
+                    let dim23 = dim2 * dim3;
+                    for idx1 in 0..dim1 {
+                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
+                        let mut rows = vec![];
+                        for idx2 in 0..dim2 {
+                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
+                        }
+                        top_rows.push(rows);
+                    }
+                }
+                None => {
+                    let mut src_index = self.strided_index();
+                    for _idx in 0..dim1 {
+                        let mut rows = vec![];
+                        for _jdx in 0..dim2 {
+                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+                            rows.push(row)
+                        }
+                        top_rows.push(rows);
+                    }
+                    assert!(src_index.next().is_none());
                 }
-                top_rows.push(rows);
             }
-            assert!(src_index.next().is_none());
             Ok(top_rows)
         };
         match &*self.storage() {
@@ -1404,7 +1430,7 @@ impl Tensor {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
             layout: self.layout.clone(),
-            op: None, // TODO
+            op: Some(Op::Copy(self.clone())),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1540,7 +1566,7 @@ impl Tensor {
             Ok(from_storage(
                 storage,
                 shape.clone(),
-                None, // TODO
+                Some(Op::Copy(self.clone())),
                 false,
             ))
         }
@@ -1734,7 +1760,6 @@ impl Tensor {
         for (arg_idx, arg) in args.iter().enumerate() {
             let arg = arg.as_ref();
             if arg.dtype() != dtype {
-                // TODO: Improve the error message.
                 Err(Error::DTypeMismatchBinaryOp {
                     lhs: dtype,
                     rhs: arg.dtype(),
@@ -1743,7 +1768,6 @@ impl Tensor {
                 .bt())?
             }
             if arg.device().location() != device.location() {
-                // TODO: Improve the error message.
                 Err(Error::DeviceMismatchBinaryOp {
                     lhs: device.location(),
                     rhs: arg.device().location(),
@@ -1751,32 +1775,34 @@ impl Tensor {
                 }
                 .bt())?
             }
-            let mut mismatch = arg.rank() != rank;
+            if rank != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: rank,
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
             for (dim_idx, (v1, v2)) in arg0
                 .shape()
                 .dims()
                 .iter()
                 .zip(arg.shape().dims().iter())
                 .enumerate()
             {
                 if dim_idx == 0 {
                     cat_dims[0] += v2;
                 }
                 if dim_idx != 0 && v1 != v2 {
-                    // TODO: It would probably be good to have a nicer error message here, i.e.
-                    // mention the problematic dimension and the values.
-                    mismatch = true;
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
                 }
             }
-            if mismatch {
-                Err(Error::ShapeMismatchCat {
-                    dim: 0, // TODO: not the appropriate error message
-                    first_shape: arg0.shape().clone(),
-                    n: arg_idx + 1,
-                    nth_shape: arg.shape().clone(),
-                }
-                .bt())?
-            }
             let next_offset = offsets.last().unwrap() + arg.elem_count();
             offsets.push(next_offset);
         }
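
For illustration, a minimal sketch of what this patch affects from the user side, assuming the crate is imported as candle_core (the import path and exact public API surface at this commit are assumptions, not part of the patch). With a contiguous layout, to_vec1/to_vec2/to_vec3 now slice whole rows out of the CPU buffer instead of walking a strided index element by element, and copy()/the device-transfer path record Op::Copy so backprop treats them as identity:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;

    // Contiguous layout: to_vec2 hits the contiguous_offsets fast path and
    // copies each row out of the storage slice in one go.
    println!("{:?}", a.to_vec2::<f32>()?);

    // A transposed view is not contiguous, so the strided_index fallback
    // runs instead and yields the same values.
    println!("{:?}", a.t()?.to_vec2::<f32>()?);

    // copy() now records Op::Copy rather than op: None, so gradients can
    // flow through the copied tensor back to `a` during backprop.
    let _b = a.copy()?;
    Ok(())
}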