Cleanup the tensor creation code.

Author: laurent
Date:   2023-06-23 19:52:21 +01:00
parent 88187b784b
commit fe75a01188
2 changed files with 37 additions and 116 deletions


@@ -23,6 +23,8 @@ pub(crate) enum Op {
     },
     Neg(Tensor),
     Reshape(Tensor),
+    #[allow(dead_code)]
+    Softmax(Tensor, usize),
     Sqr(Tensor),
     Sqrt(Tensor),
     ToDevice(Tensor),

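The new Op::Softmax(Tensor, usize) variant records the softmax input together with the dimension it was applied over, so a later backward pass has everything it needs; the #[allow(dead_code)] marks that nothing constructs it yet. As a reference point, a minimal sketch of the forward computation such a variant describes, for the 1-D case (a hypothetical standalone helper, not code from this commit):

// Hypothetical 1-D softmax, only to illustrate what the recorded
// (input, dim) pair stands for; this is not the kernel candle uses.
fn softmax_1d(xs: &[f32]) -> Vec<f32> {
    // Subtract the max before exponentiating to avoid overflow.
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let p = softmax_1d(&[1.0, 2.0, 3.0]);
    assert!((p.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}

The recorded dim generalizes this to one softmax per slice along that dimension of the stored tensor.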

@@ -54,15 +54,7 @@ macro_rules! unary_op {
             } else {
                 None
             };
-            let tensor_ = Tensor_ {
-                id: TensorId::new(),
-                storage,
-                shape: shape.clone(),
-                stride: shape.stride_contiguous(),
-                op,
-                is_variable: false,
-            };
-            Ok(Self(Arc::new(tensor_)))
+            Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
 }
@@ -82,15 +74,7 @@ macro_rules! binary_op {
             } else {
                 None
             };
-            let tensor_ = Tensor_ {
-                id: TensorId::new(),
-                storage,
-                shape: shape.clone(),
-                stride: shape.stride_contiguous(),
-                op,
-                is_variable: false,
-            };
-            Ok(Self(Arc::new(tensor_)))
+            Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
 }
@@ -110,17 +94,23 @@ macro_rules! broadcast_binary_op {
             } else {
                 None
             };
-            let tensor_ = Tensor_ {
-                id: TensorId::new(),
-                storage,
-                shape: shape.clone(),
-                stride: shape.stride_contiguous(),
-                op,
-                is_variable: false,
-            };
-            Ok(Self(Arc::new(tensor_)))
+            Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
 }
 
+/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
+fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor {
+    let stride = shape.stride_contiguous();
+    let tensor_ = Tensor_ {
+        id: TensorId::new(),
+        storage,
+        shape,
+        stride,
+        op,
+        is_variable,
+    };
+    Tensor(Arc::new(tensor_))
+}
+
 impl Tensor {
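This from_storage helper is the heart of the commit: the nine-line Tensor_ construction that was copy-pasted across every creation site collapses into a single call, and the contiguous strides are derived from the shape in exactly one place. A self-contained sketch of the same pattern, with simplified stand-in types (Storage as a plain Vec<f32>, no op field, no TensorId; all names here are illustrative rather than candle's real definitions):

use std::sync::Arc;

// Stand-in types so the sketch compiles on its own; candle's real Storage,
// TensorId and op tracking are omitted.
type Storage = Vec<f32>;

#[derive(Clone, Debug)]
struct Shape(Vec<usize>);

impl Shape {
    // Row-major contiguous strides: each dim's stride is the product of the
    // dim sizes to its right, e.g. [2, 3, 4] -> [12, 4, 1].
    fn stride_contiguous(&self) -> Vec<usize> {
        let mut stride = vec![1; self.0.len()];
        for i in (0..self.0.len().saturating_sub(1)).rev() {
            stride[i] = stride[i + 1] * self.0[i + 1];
        }
        stride
    }
}

#[derive(Debug)]
struct Tensor_ {
    storage: Storage,
    shape: Shape,
    stride: Vec<usize>,
    is_variable: bool,
}

#[derive(Debug)]
struct Tensor(Arc<Tensor_>);

// Same shape as the new helper: one place derives the strides and boxes
// the inner struct behind an Arc.
fn from_storage(storage: Storage, shape: Shape, is_variable: bool) -> Tensor {
    let stride = shape.stride_contiguous();
    Tensor(Arc::new(Tensor_ { storage, shape, stride, is_variable }))
}

fn main() {
    let t = from_storage(vec![0.0; 24], Shape(vec![2, 3, 4]), false);
    assert_eq!(t.0.stride, vec![12, 4, 1]);
}

Deriving the stride inside the helper is what lets the hunks below drop their separate `let stride = shape.stride_contiguous();` lines.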
@@ -132,16 +122,7 @@ impl Tensor {
     ) -> Result<Self> {
         let shape = shape.into();
         let storage = device.ones(&shape, dtype)?;
-        let stride = shape.stride_contiguous();
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op: None,
-            is_variable,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, None, is_variable))
     }
 
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -164,16 +145,7 @@ impl Tensor {
     ) -> Result<Self> {
         let shape = shape.into();
         let storage = device.zeros(&shape, dtype)?;
-        let stride = shape.stride_contiguous();
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op: None,
-            is_variable,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, None, is_variable))
     }
 
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
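After this change, ones_impl and zeros_impl differ only in which storage allocator they call. A hypothetical call site for the two public constructors above; it assumes this crate's Tensor, DType, Device and Result are in scope, and DType::F32 / Device::Cpu are assumed variant names rather than something this diff confirms:

// Hypothetical usage; the signatures are the ones shown in the diff,
// the variant names are assumptions.
fn demo() -> Result<()> {
    let ones = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    let zeros = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // Both tensors come out of from_storage with contiguous strides [3, 1].
    let _ = (ones, zeros);
    Ok(())
}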
@@ -200,16 +172,7 @@ impl Tensor {
             return Err(Error::ShapeMismatch { buffer_size, shape });
         }
         let storage = device.storage(array)?;
-        let stride = shape.stride_contiguous();
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op: None,
-            is_variable,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, None, is_variable))
     }
 
     pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
@@ -314,9 +277,7 @@ impl Tensor {
 
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
         let shape = self.shape();
-        let storage = self
-            .storage
-            .affine_impl(self.shape(), self.stride(), mul, add)?;
+        let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?;
         let op = if self.track_op() {
             Some(Op::Affine {
                 arg: self.clone(),
@@ -326,15 +287,7 @@ impl Tensor {
         } else {
             None
         };
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op,
-            is_variable: false,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape.clone(), op, false))
     }
 
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
@@ -373,7 +326,6 @@ impl Tensor {
         }
 
         let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
-        let c_stride = c_shape.stride_contiguous();
         let batching: usize = a_dims[..dim - 2].iter().product();
 
         let storage = self.storage.matmul_impl(
@@ -387,15 +339,7 @@ impl Tensor {
         } else {
             None
         };
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape: c_shape,
-            stride: c_stride,
-            op,
-            is_variable: false,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, c_shape, op, false))
     }
 
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
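In the matmul hunks above, dropping c_stride is safe because from_storage recomputes contiguous strides for c_shape. What survives is the shape arithmetic; a self-contained sketch with a hypothetical helper name: for inputs shaped [..batch, m, k] and [..batch, k, n], the output is [..batch, m, n] and the number of individual matrix products is the product of the batch dimensions:

// Sketch of the batched-matmul bookkeeping; names are illustrative only.
fn matmul_output_dims(a_dims: &[usize], b_dims: &[usize]) -> (Vec<usize>, usize) {
    let dim = a_dims.len();
    let (m, n) = (a_dims[dim - 2], b_dims[dim - 1]);
    // Mirrors Shape::from(&a_dims[..dim - 2]).extend(&[m, n]).
    let mut c_dims = a_dims[..dim - 2].to_vec();
    c_dims.extend([m, n]);
    // Mirrors the `batching` product over the leading dims.
    let batching: usize = a_dims[..dim - 2].iter().product();
    (c_dims, batching)
}

fn main() {
    let (c_dims, batching) = matmul_output_dims(&[5, 2, 3], &[5, 3, 4]);
    assert_eq!(c_dims, vec![5, 2, 4]);
    assert_eq!(batching, 5);
}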
@@ -419,15 +363,7 @@ impl Tensor {
         } else {
             None
         };
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op,
-            is_variable: false,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, op, false))
     }
 
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
@@ -669,15 +605,12 @@ impl Tensor {
         let mut storage = self.device().zeros(shape, self.dtype())?;
         self.storage
             .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op: self.op.clone(),
-            is_variable: self.is_variable,
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+        Ok(from_storage(
+            storage,
+            shape.clone(),
+            self.op.clone(),
+            self.is_variable,
+        ))
     }
 }
 
@@ -702,16 +635,7 @@ impl Tensor {
         } else {
             None
         };
-        let stride = shape.stride_contiguous();
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op,
-            is_variable: false,
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, op, false))
     }
 
     pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
@@ -778,7 +702,6 @@ impl Tensor {
             offsets.push(next_offset);
         }
         let shape = Shape::from(cat_dims);
-        let stride = shape.stride_contiguous();
         let op = if args.iter().any(|arg| arg.track_op()) {
             Some(Op::Cat(args.to_vec(), dim))
         } else {
@@ -789,15 +712,7 @@ impl Tensor {
             arg.storage
                 .copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
         }
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op,
-            is_variable: false,
-        };
-        Ok(Tensor(Arc::new(tensor_)))
+        Ok(from_storage(storage, shape, op, false))
     }
 
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
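The concatenation loop itself is untouched: each argument is copied into a zeroed destination at a running offset, and only the final tensor construction changes. A simplified, self-contained picture of that loop, assuming plain contiguous buffers where the real code copies strided sources:

// Each input lands at a running offset in a zeroed destination buffer.
// Contiguous-only; the real code goes through copy_strided_src instead.
fn cat_contiguous(args: &[Vec<f32>]) -> Vec<f32> {
    let total: usize = args.iter().map(|a| a.len()).sum();
    let mut storage = vec![0.0f32; total]; // mirrors device.zeros(..)
    let mut offset = 0;
    for arg in args {
        storage[offset..offset + arg.len()].copy_from_slice(arg);
        offset += arg.len(); // the offsets.push(next_offset) bookkeeping
    }
    storage
}

fn main() {
    assert_eq!(
        cat_contiguous(&[vec![1.0, 2.0], vec![3.0]]),
        vec![1.0, 2.0, 3.0]
    );
}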
@@ -855,6 +770,7 @@ impl Tensor {
             Op::Reshape(node)
             | Op::ToDevice(node)
             | Op::Transpose(node, _, _)
+            | Op::Softmax(node, _)
             | Op::Sqr(node)
             | Op::Sqrt(node)
             | Op::Gelu(node)
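The new variant joins an existing or-pattern arm used when collecting child nodes for the topological sort. Rust requires every alternative in such an arm to bind the same names with the same types, which is why Softmax exposes its child tensor in the first position like its neighbors. A tiny self-contained illustration of the idiom (stand-in types, not this crate's):

// Each variant carries its child in the first position so one or-pattern
// arm can bind it under a single name.
enum Op {
    Neg(i32),
    Softmax(i32, usize),
    Sqr(i32),
}

fn child(op: &Op) -> &i32 {
    match op {
        Op::Neg(node) | Op::Softmax(node, _) | Op::Sqr(node) => node,
    }
}

fn main() {
    assert_eq!(*child(&Op::Softmax(7, 1)), 7);
}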
@@ -975,6 +891,9 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.sub(&grad)?
                 }
+                Op::Softmax(_arg, _) => {
+                    return Err(Error::BackwardNotSupported { op: "softmax" })
+                }
                 Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
                 Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
                 Op::Sqr(arg) => {