Make `Tensor::narrow` go through `Layout::narrow` instead of copying.

`Layout::narrow` checks `dim`, shrinks the shape along it to `length`, keeps the strides and
advances `start_offset` by `stride[dim] * start`. `Tensor::narrow` then builds its result from
`self.storage.clone()` and that layout rather than allocating zeros and copying into them.
`Storage::copy_strided_src` takes a single `&Layout` in place of the separate shape, stride
and offset arguments, and the callers in tensor.rs pass `self.layout()` / `arg.layout()`.

Review notes:
- The forwarding calls in storage.rs do not pass `src_offset`: this patch removes that
  parameter, so the name would be unresolved. This assumes the CPU and CUDA backends take
  `(dst, dst_offset, &Layout)` - confirm against their signatures.
- storage.rs must already have `Layout` in scope; no import hunk is part of this patch.
- The bounds check in `Layout::narrow` uses `checked_add` so that an overflowing
  `start + length` cannot wrap around and slip past the check in release builds.
- The out-of-bounds path still ends in `todo!` and panics; a proper error is still needed.
- Indentation was reconstructed following rustfmt conventions and the `index` blob hashes
  predate the edits above; regenerate the patch against the tree if an exact match is needed.

diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index a0f3639b..ac5fca93 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -1,4 +1,4 @@
-use crate::Shape;
+use crate::{Error, Result, Shape};
 
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub struct Layout {
@@ -44,4 +44,27 @@ impl Layout {
     pub fn is_fortran_contiguous(&self) -> bool {
         self.shape.is_fortran_contiguous(&self.stride)
     }
+
+    /// Returns a layout restricted to `length` elements of dimension `dim`, starting at
+    /// index `start`. Strides are kept as is; only the shape and start offset change.
+    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+        let dims = self.shape().dims();
+        if dim >= dims.len() {
+            Err(Error::UnexpectedNumberOfDims {
+                expected: dim + 1,
+                got: dims.len(),
+                shape: self.shape().clone(),
+            })?
+        }
+        if start.checked_add(length).map_or(true, |end| end > dims[dim]) {
+            todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
+        }
+        let mut dims = dims.to_vec();
+        dims[dim] = length;
+        Ok(Self {
+            shape: Shape::from(dims),
+            stride: self.stride.clone(),
+            start_offset: self.start_offset + self.stride[dim] * start,
+        })
+    }
 }
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 3bcb0ff0..b7c94b46 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -227,16 +227,14 @@ impl Storage {
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_shape: &Shape,
-        src_stride: &[usize],
-        src_offset: usize,
+        src_layout: &Layout,
     ) -> Result<()> {
         match (self, dst) {
             (Self::Cpu(src), Self::Cpu(dst)) => {
-                src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)
+                src.copy_strided_src(dst, dst_offset, src_layout)
             }
             (Self::Cuda(src), Self::Cuda(dst)) => {
-                Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?)
+                Ok(src.copy_strided_src(dst, dst_offset, src_layout)?)
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 51eeb9ae..c7862250 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -351,37 +351,20 @@ impl Tensor {
 
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
     /// ranges from `start` to `start + length`.
-    // TODO: Once we've refactored the shape and strides, make this return a view of the same data
-    // rather than copying.
     pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
-        let dims = self.shape().dims();
-        if dim >= dims.len() {
-            return Err(Error::UnexpectedNumberOfDims {
-                expected: dim + 1,
-                got: dims.len(),
-                shape: self.shape().clone(),
-            });
-        }
-        if start + length > dims[dim] {
-            todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
-        }
-        let mut dims = dims.to_vec();
-        dims[dim] = length;
-        let adjusted_shape = Shape::from(dims);
-        let mut storage = self.device().zeros(&adjusted_shape, self.dtype())?;
-        self.storage.copy_strided_src(
-            &mut storage,
-            /* dst_offset= */ 0,
-            &adjusted_shape,
-            &self.stride,
-            /* src_offest= */ self.stride[dim] * start,
-        )?;
         let op = if self.track_op() {
             Some(Op::Narrow(self.clone(), dim, start, length))
         } else {
             None
         };
-        Ok(from_storage(storage, adjusted_shape, op, false))
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            layout: self.layout().narrow(dim, start, length)?,
+            op,
+            is_variable: false,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
@@ -875,7 +858,7 @@ impl Tensor {
             let shape = self.shape();
             let mut storage = self.device().zeros(shape, self.dtype())?;
             self.storage
-                .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
+                .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(
                 storage,
                 shape.clone(),
@@ -918,7 +901,7 @@ impl Tensor {
         } else {
             let mut storage = self.device().zeros(&shape, self.dtype())?;
             self.storage
-                .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
+                .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(storage, shape, op, false))
         }
     }
@@ -1055,7 +1038,7 @@ impl Tensor {
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
             arg.storage
-                .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?;
+                .copy_strided_src(&mut storage, offset, arg.layout())?;
         }
         Ok(from_storage(storage, shape, op, false))
     }