From c4c61679498e2229d14161952c1b0b801397c8c1 Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 23 Jun 2023 10:45:20 +0100
Subject: [PATCH] Add the contiguous method.

---
 src/error.rs  |  3 +++
 src/op.rs     |  2 ++
 src/tensor.rs | 28 ++++++++++++++++++++++++----
 3 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/src/error.rs b/src/error.rs
index cb302abd..c8c338ea 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -12,6 +12,9 @@ pub enum Error {
     #[error("{op} expects at least one tensor")]
     OpRequiresAtLeastOneTensor { op: &'static str },
 
+    #[error("backward is not supported for {op}")]
+    BackwardNotSupported { op: &'static str },
+
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,
 
diff --git a/src/op.rs b/src/op.rs
index 6642ba2d..22b28299 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -17,6 +17,8 @@ pub(crate) enum Op {
         add: f64,
     },
     Neg(Tensor),
+    #[allow(dead_code)]
+    Reshape(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
     ToDevice(Tensor),
diff --git a/src/tensor.rs b/src/tensor.rs
index 5872a244..99cd6064 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -575,6 +575,26 @@ impl Tensor {
         }
     }
 
+    pub fn contiguous(&self) -> Result<Tensor> {
+        if self.is_contiguous() {
+            Ok(self.clone())
+        } else {
+            let shape = self.shape();
+            let mut storage = self.device().zeros(shape, self.dtype())?;
+            self.storage
+                .copy_strided_src(&mut storage, shape, &self.stride, 0)?;
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: self.op.clone(),
+                is_variable: self.is_variable,
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
+    }
+
     pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
@@ -708,7 +728,8 @@ impl Tensor {
                         nodes
                     }
                 }
-                Op::ToDevice(node)
+                Op::Reshape(node)
+                | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
                 | Op::Sqr(node)
                 | Op::Sqrt(node)
@@ -788,9 +809,7 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::Cat(_args, _dim) => {
-                    todo!()
-                }
+                Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
                     let sum_grad = grads.or_insert(arg)?;
@@ -801,6 +820,7 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
+                Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
                 Op::Sqr(arg) => {
                     let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                     let sum_grad = grads.or_insert(arg)?;
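Note: the core of the new `contiguous` method is the strided copy in the non-contiguous branch. Below is a minimal, self-contained sketch of that technique for the 2D case — it is not candle's `copy_strided_src`; the function name and signature are illustrative only. It walks a view through its strides and writes each element into a freshly allocated row-major buffer.

```rust
// Illustrative sketch (not candle internals): copy a strided 2D view
// into contiguous, row-major storage.
fn copy_strided_to_contiguous(
    src: &[f64],
    shape: &[usize; 2],
    stride: &[usize; 2],
) -> Vec<f64> {
    let mut dst = vec![0.0; shape[0] * shape[1]];
    for i in 0..shape[0] {
        for j in 0..shape[1] {
            // Source index follows the view's strides; destination index
            // follows the standard row-major ("contiguous") layout.
            dst[i * shape[1] + j] = src[i * stride[0] + j * stride[1]];
        }
    }
    dst
}

fn main() {
    // A 2x3 row-major buffer viewed as its 3x2 transpose via swapped strides.
    let data = [1., 2., 3., 4., 5., 6.];
    let transposed = copy_strided_to_contiguous(&data, &[3, 2], &[1, 3]);
    assert_eq!(transposed, vec![1., 4., 2., 5., 3., 6.]);
}
```

The fast path in the patch (`Ok(self.clone())` when `is_contiguous()` holds) is cheap because `Tensor` wraps its state in an `Arc` (see `Ok(Tensor(Arc::new(tensor_)))`), so cloning only bumps a reference count; the slow path allocates zeroed storage, fills it via the strided copy, and records fresh row-major strides with `shape.stride_contiguous()`.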