From fe75a01188a03ee3cea8b0784080c8443cc2404b Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 23 Jun 2023 19:52:21 +0100
Subject: [PATCH] Cleanup the tensor creation code.

---
 src/op.rs     |   2 +
 src/tensor.rs | 151 ++++++++++++--------------------------------------
 2 files changed, 37 insertions(+), 116 deletions(-)

diff --git a/src/op.rs b/src/op.rs
index b8198ca8..77e87140 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -23,6 +23,8 @@ pub(crate) enum Op {
     },
     Neg(Tensor),
     Reshape(Tensor),
+    #[allow(dead_code)]
+    Softmax(Tensor, usize),
     Sqr(Tensor),
     Sqrt(Tensor),
     ToDevice(Tensor),
diff --git a/src/tensor.rs b/src/tensor.rs
index 34d62741..af084740 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -54,15 +54,7 @@ macro_rules! unary_op {
             } else {
                 None
             };
-            let tensor_ = Tensor_ {
-                id: TensorId::new(),
-                storage,
-                shape: shape.clone(),
-                stride: shape.stride_contiguous(),
-                op,
-                is_variable: false,
-            };
-            Ok(Self(Arc::new(tensor_)))
+            Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
 }
@@ -82,15 +74,7 @@ macro_rules! binary_op {
             } else {
                 None
             };
-            let tensor_ = Tensor_ {
-                id: TensorId::new(),
-                storage,
-                shape: shape.clone(),
-                stride: shape.stride_contiguous(),
-                op,
-                is_variable: false,
-            };
-            Ok(Self(Arc::new(tensor_)))
+            Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
 }
@@ -110,19 +94,25 @@ macro_rules! broadcast_binary_op {
             } else {
                 None
             };
-            let tensor_ = Tensor_ {
-                id: TensorId::new(),
-                storage,
-                shape: shape.clone(),
-                stride: shape.stride_contiguous(),
-                op,
-                is_variable: false,
-            };
-            Ok(Self(Arc::new(tensor_)))
+            Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
 }
 
+/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
+fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor {
+    let stride = shape.stride_contiguous();
+    let tensor_ = Tensor_ {
+        id: TensorId::new(),
+        storage,
+        shape,
+        stride,
+        op,
+        is_variable,
+    };
+    Tensor(Arc::new(tensor_))
+}
+
 impl Tensor {
     fn ones_impl<S: Into<Shape>>(
         shape: S,
@@ -132,16 +122,7 @@ impl Tensor {
     ) -> Result<Self> {
         let shape = shape.into();
         let storage = device.ones(&shape, dtype)?;
-        let stride = shape.stride_contiguous();
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op: None,
-            is_variable,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, None, is_variable))
     }
 
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -164,16 +145,7 @@ impl Tensor {
     ) -> Result<Self> {
         let shape = shape.into();
         let storage = device.zeros(&shape, dtype)?;
-        let stride = shape.stride_contiguous();
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op: None,
-            is_variable,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, None, is_variable))
    }

     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -200,16 +172,7 @@ impl Tensor {
             return Err(Error::ShapeMismatch { buffer_size, shape });
         }
         let storage = device.storage(array)?;
-        let stride = shape.stride_contiguous();
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op: None,
-            is_variable,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, None, is_variable))
     }
 
     pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
@@ -314,9 +277,7 @@ impl Tensor {
 
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
         let shape = self.shape();
-        let storage = self
-            .storage
-            .affine_impl(self.shape(), self.stride(), mul, add)?;
+        let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?;
         let op = if self.track_op() {
             Some(Op::Affine {
                 arg: self.clone(),
@@ -326,15 +287,7 @@ impl Tensor {
         } else {
             None
         };
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op,
-            is_variable: false,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape.clone(), op, false))
     }
 
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
@@ -373,7 +326,6 @@ impl Tensor {
         }
 
         let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
-        let c_stride = c_shape.stride_contiguous();
         let batching: usize = a_dims[..dim - 2].iter().product();
 
         let storage = self.storage.matmul_impl(
@@ -387,15 +339,7 @@ impl Tensor {
         } else {
             None
         };
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape: c_shape,
-            stride: c_stride,
-            op,
-            is_variable: false,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, c_shape, op, false))
     }
 
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
@@ -419,15 +363,7 @@ impl Tensor {
         } else {
             None
         };
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op,
-            is_variable: false,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, op, false))
     }
 
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
@@ -669,15 +605,12 @@ impl Tensor {
         let mut storage = self.device().zeros(shape, self.dtype())?;
         self.storage
             .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
+        Ok(from_storage(
             storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op: self.op.clone(),
-            is_variable: self.is_variable,
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+            shape.clone(),
+            self.op.clone(),
+            self.is_variable,
+        ))
     }
 }
 
@@ -702,16 +635,7 @@ impl Tensor {
         } else {
             None
         };
-        let stride = shape.stride_contiguous();
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op,
-            is_variable: false,
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, op, false))
     }
 
     pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
@@ -778,7 +702,6 @@ impl Tensor {
             offsets.push(next_offset);
         }
         let shape = Shape::from(cat_dims);
-        let stride = shape.stride_contiguous();
         let op = if args.iter().any(|arg| arg.track_op()) {
             Some(Op::Cat(args.to_vec(), dim))
         } else {
@@ -789,15 +712,7 @@ impl Tensor {
             arg.storage
                 .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
         }
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op,
-            is_variable: false,
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, op, false))
     }
 
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
@@ -855,6 +770,7 @@ impl Tensor {
                 Op::Reshape(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
+                | Op::Softmax(node, _)
                 | Op::Sqr(node)
                 | Op::Sqrt(node)
                 | Op::Gelu(node)
@@ -975,6 +891,9 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.sub(&grad)?
                 }
+                Op::Softmax(_arg, _) => {
+                    return Err(Error::BackwardNotSupported { op: "softmax" })
+                }
                 Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
                 Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
                 Op::Sqr(arg) => {
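
Commentary (outside the diff): every constructor touched by this patch builds a contiguous
tensor, so callers never need to supply a stride; it is fully determined by the shape as the
row-major suffix products of the dimensions, which is what `Shape::stride_contiguous` computes
and what `from_storage` now applies in one place. Below is a minimal standalone sketch of that
invariant. The names and types here (`FakeTensor`, the free functions, plain `Vec`s) are
illustrative stand-ins, not the crate's actual `Storage`/`Shape`/`Tensor_` API.

// Sketch: contiguous (row-major) strides derived from a shape, and a
// `from_storage`-style helper that centralizes tensor construction.
fn stride_contiguous(dims: &[usize]) -> Vec<usize> {
    // Walk the dims right to left; the running product of the dimensions
    // already visited is the stride of the current dimension.
    let mut stride = vec![0; dims.len()];
    let mut prod = 1;
    for (s, &d) in stride.iter_mut().zip(dims.iter()).rev() {
        *s = prod;
        prod *= d;
    }
    stride
}

struct FakeTensor {
    data: Vec<f32>, // stand-in for the crate's `Storage`
    shape: Vec<usize>,
    stride: Vec<usize>,
}

// Analogue of the patch's `from_storage`: the caller never passes a stride,
// so a freshly built tensor cannot carry a mismatched shape/stride pair.
fn from_storage(data: Vec<f32>, shape: Vec<usize>) -> FakeTensor {
    let stride = stride_contiguous(&shape);
    FakeTensor { data, shape, stride }
}

fn main() {
    let t = from_storage(vec![0.0; 24], vec![2, 3, 4]);
    // A contiguous (2, 3, 4) tensor has strides (12, 4, 1).
    assert_eq!(t.stride, vec![12, 4, 1]);
    assert_eq!(t.data.len(), t.shape.iter().product::<usize>());
    println!("shape {:?} -> stride {:?}", t.shape, t.stride);
}

The payoff visible at every call site in the diff is the same: a nine-line `Tensor_ { .. }`
literal collapses into a single `from_storage(storage, shape, op, is_variable)` call, and the
stride field can no longer drift out of sync with the shape.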