mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Boilerplate code for the sum operator.
This commit is contained in:
@ -102,7 +102,13 @@ macro_rules! broadcast_binary_op {
|
||||
}
|
||||
|
||||
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
|
||||
fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor {
|
||||
fn from_storage<S: Into<Shape>>(
|
||||
storage: Storage,
|
||||
shape: S,
|
||||
op: Option<Op>,
|
||||
is_variable: bool,
|
||||
) -> Tensor {
|
||||
let shape = shape.into();
|
||||
let stride = shape.stride_contiguous();
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
@ -347,6 +353,20 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
|
||||
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
||||
let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut dims = self.dims().to_vec();
|
||||
for &sum_dim in sum_dims.iter() {
|
||||
dims[sum_dim] = 1
|
||||
}
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
||||
let a_dims = self.shape().dims();
|
||||
let b_dims = rhs.shape().dims();
|
||||
|
Reference in New Issue
Block a user