From dd657397b2cc15d32058d3045f399c2a7acd7e11 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sat, 24 Jun 2023 08:17:35 +0100
Subject: [PATCH] Skeleton implementation for the narrow method and op.

---
 src/op.rs     |  2 +-
 src/tensor.rs | 39 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/src/op.rs b/src/op.rs
index 7c2539f9..523154fb 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -27,9 +27,9 @@ pub(crate) enum Op {
     Sin(Tensor),
     Cos(Tensor),
     Abs(Tensor),
+    Narrow(Tensor, usize, usize, usize),
     Neg(Tensor),
     Reshape(Tensor),
-    #[allow(dead_code)]
     Softmax(Tensor, usize),
     Sqr(Tensor),
     Sqrt(Tensor),
diff --git a/src/tensor.rs b/src/tensor.rs
index 741bef0d..5467584b 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -298,6 +298,36 @@ impl Tensor {
         Ok(from_storage(storage, shape.clone(), op, false))
     }
 
+    /// Returns a new tensor that is a narrowed version of the input; the dimension `dim`
+    /// ranges from `start` to `start + length`.
+    // TODO: Once we've refactored the shape and strides, make this return a view of the
+    // same data rather than copying.
+    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+        let dims = self.shape().dims();
+        if dim >= dims.len() {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: dim + 1,
+                got: dims.len(),
+                shape: self.shape().clone(),
+            });
+        }
+        if start + length > dims[dim] {
+            todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
+        }
+        let mut dims = dims.to_vec();
+        dims[dim] = length;
+        let shape = Shape::from(dims);
+        let storage = self.device().zeros(&shape, self.dtype())?;
+        // TODO: Actually copy the data; compared to copy_strided_src, this requires a src
+        // start offset as well as a way to specify the number of elements to be copied.
+        let op = if self.track_op() {
+            Some(Op::Narrow(self.clone(), dim, start, length))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape, op, false))
+    }
+
     pub fn softmax(&self, dim: usize) -> Result<Self> {
         let shape = self.shape();
         let mut storage = self
@@ -817,6 +847,7 @@ impl Tensor {
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
+                | Op::Narrow(node, _, _, _)
                 | Op::Softmax(node, _)
                 | Op::Sqr(node)
                 | Op::Sqrt(node)
@@ -933,7 +964,10 @@
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
+                    Op::Cat(_args, _dim) => {
+                        // TODO: Use narrow here.
+                        return Err(Error::BackwardNotSupported { op: "cat" });
+                    }
                     Op::ToDType(arg) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
@@ -964,6 +998,9 @@
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.sub(&grad)?
                     }
+                    Op::Narrow(_arg, _, _, _) => {
+                        return Err(Error::BackwardNotSupported { op: "narrow" })
+                    }
                     Op::Softmax(_arg, _) => {
                         return Err(Error::BackwardNotSupported { op: "softmax" })
                     }
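
A note on the copy that `// TODO: Actually copy the data` asks for: on a contiguous layout it reduces to one contiguous copy per leading block, described by a source start offset plus a fixed element count, which are exactly the two parameters the comment says `copy_strided_src` is missing. The following is a standalone sketch over a plain row-major `Vec<f32>` rather than the crate's storage type; the helper name `narrow_contiguous` and the `main` driver are made up for illustration and are not part of this patch.

fn narrow_contiguous(
    src: &[f32],
    dims: &[usize],
    dim: usize,
    start: usize,
    length: usize,
) -> Vec<f32> {
    // Elements per index step along `dim` (product of the trailing dims).
    let inner: usize = dims[dim + 1..].iter().product();
    // Number of independent leading blocks (product of the dims before `dim`).
    let outer: usize = dims[..dim].iter().product();
    let mut dst = Vec::with_capacity(outer * length * inner);
    for block in 0..outer {
        // Source start offset: skip to this block, then `start` steps along `dim`.
        let offset = block * dims[dim] * inner + start * inner;
        // Number of elements to copy: `length` steps along `dim`.
        dst.extend_from_slice(&src[offset..offset + length * inner]);
    }
    dst
}

fn main() {
    // A 2x3 tensor stored row-major: [[0, 1, 2], [3, 4, 5]].
    let src = [0f32, 1., 2., 3., 4., 5.];
    // narrow(dim=1, start=1, length=2) keeps columns 1..3 of each row.
    let dst = narrow_contiguous(&src, &[2, 3], 1, 1, 2);
    assert_eq!(dst, vec![1., 2., 4., 5.]);
}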
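
On the backward side, `Op::Narrow` is left as `BackwardNotSupported` for now; its gradient amounts to scattering the incoming gradient back into a zero buffer of the source shape. A matching sketch under the same assumptions (hypothetical helper, contiguous `Vec<f32>` storage):

fn narrow_backward_contiguous(
    grad: &[f32],
    src_dims: &[usize],
    dim: usize,
    start: usize,
    length: usize,
) -> Vec<f32> {
    let inner: usize = src_dims[dim + 1..].iter().product();
    let outer: usize = src_dims[..dim].iter().product();
    // The gradient w.r.t. the source is zero outside the narrowed region.
    let mut dst = vec![0f32; src_dims.iter().product()];
    for block in 0..outer {
        let grad_offset = block * length * inner;
        let dst_offset = block * src_dims[dim] * inner + start * inner;
        // Scatter this block of the narrowed gradient back into place.
        dst[dst_offset..dst_offset + length * inner]
            .copy_from_slice(&grad[grad_offset..grad_offset + length * inner]);
    }
    dst
}

This is also why `// TODO: Use narrow here.` sits under `Op::Cat`: the backward of `cat` decomposes into one `narrow` of the gradient per concatenated input.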