Boilerplate code for the sum operator.

laurent
2023-06-25 09:35:17 +01:00
parent 7ccf27dda2
commit 3852a85af3
7 changed files with 61 additions and 1 deletion


@@ -102,7 +102,13 @@ macro_rules! broadcast_binary_op {
 }
 
 /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
-fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor {
+fn from_storage<S: Into<Shape>>(
+    storage: Storage,
+    shape: S,
+    op: Option<Op>,
+    is_variable: bool,
+) -> Tensor {
+    let shape = shape.into();
     let stride = shape.stride_contiguous();
     let tensor_ = Tensor_ {
         id: TensorId::new(),
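The signature change above makes from_storage generic over any type that converts into a Shape, so call sites can pass a plain Vec<usize> (as the new sum below does) or a ready-made Shape without an explicit conversion. Below is a minimal, self-contained sketch of the same Into-based pattern; the Shape type and make_shape function here are stand-ins for illustration only, not candle's actual implementation.

// Minimal sketch of the Into<Shape> pattern used by the new from_storage.
// This stand-in Shape is for illustration only; it is not candle's type.
#[derive(Debug, Clone, PartialEq)]
struct Shape(Vec<usize>);

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Shape(dims)
    }
}

impl From<(usize, usize)> for Shape {
    fn from((d0, d1): (usize, usize)) -> Self {
        Shape(vec![d0, d1])
    }
}

// Accept anything convertible into a Shape and normalize it once, up front,
// just as from_storage now does with `let shape = shape.into();`.
fn make_shape<S: Into<Shape>>(shape: S) -> Shape {
    shape.into()
}

fn main() {
    // Callers can pass a Vec<usize> or a tuple; both convert to the same Shape.
    assert_eq!(make_shape(vec![2, 3]), make_shape((2, 3)));
}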
@@ -347,6 +353,20 @@ impl Tensor {
         Ok(from_storage(storage, shape.clone(), op, false))
     }
 
+    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+        let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
+        let op = if self.track_op() {
+            Some(Op::Sum(self.clone(), sum_dims.to_vec()))
+        } else {
+            None
+        };
+        let mut dims = self.dims().to_vec();
+        for &sum_dim in sum_dims.iter() {
+            dims[sum_dim] = 1
+        }
+        Ok(from_storage(storage, dims, op, false))
+    }
+
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
         let a_dims = self.shape().dims();
         let b_dims = rhs.shape().dims();
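As the dims[sum_dim] = 1 loop shows, this version of sum keeps every reduced dimension in the output shape with size 1 rather than dropping it. The following is a hedged usage sketch of that behavior; the crate import path, the Tensor::new constructor, and Device::Cpu are assumptions about the public API around this commit and may not match this exact revision.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // 2x3 input tensor on the CPU (constructor assumed for illustration).
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;

    // Summing over dim 0 keeps it as a size-1 dim: shape (1, 3), values [[5., 7., 9.]].
    let s = t.sum(&[0])?;
    assert_eq!(s.dims(), &[1, 3]);

    // Summing over both dims yields a (1, 1) tensor holding 21.
    let total = t.sum(&[0, 1])?;
    assert_eq!(total.dims(), &[1, 1]);
    Ok(())
}

Note that later candle releases split this behavior: sum drops the reduced dimensions and sum_keepdim retains them as size 1, so this sketch matches the signature added in this commit rather than today's API.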