mirror of https://github.com/huggingface/candle.git
Boilerplate code for the sum operator.
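The diff below threads a new Sum operator through the codebase: the Op enum gains a Sum variant, the autograd graph walk and the backward pass learn about it, Tensor::sum computes the reduced shape and records the op, and each storage backend gets a sum entry point whose actual kernel is left as todo!().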
@@ -56,6 +56,7 @@ impl Tensor {
                 }
                 Op::Reshape(node)
                 | Op::Broadcast(node)
+                | Op::Sum(node, _)
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
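The added arm slots Sum into the chain of single-input ops that the backward graph walk recurses into. A minimal, self-contained sketch of that pattern, with simplified stand-in types rather than candle's:

enum Op {
    Reshape(usize),
    Sum(usize, Vec<usize>), // input node id plus the dims being reduced
    Transpose(usize, usize, usize),
}

// All of these variants carry exactly one input, so a single arm with
// |-patterns covers them when collecting nodes for the topological walk.
fn visit(op: &Op, visited: &mut Vec<usize>) {
    match op {
        Op::Reshape(node) | Op::Sum(node, _) | Op::Transpose(node, _, _) => {
            visited.push(*node)
        }
    }
}

fn main() {
    let mut visited = Vec::new();
    visit(&Op::Sum(7, vec![0]), &mut visited);
    assert_eq!(visited, vec![7]);
}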
@@ -188,6 +189,9 @@ impl Tensor {
                 Op::Broadcast(_arg) => {
                     return Err(Error::BackwardNotSupported { op: "broadcast" })
                 }
+                Op::Sum(_arg, _sum_dims) => {
+                    return Err(Error::BackwardNotSupported { op: "sum" })
+                }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
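The backward arm is a stub for now, but the rule it will eventually implement is simple: every summed element contributes with weight 1, so the output gradient is broadcast back over the reduced elements. A sketch of that rule on plain Vec<f32> values rather than candle tensors:

// d/dx_i sum(x) = 1, so each input element receives the output gradient.
fn sum_backward(grad_output: f32, input_len: usize) -> Vec<f32> {
    vec![grad_output; input_len]
}

fn main() {
    // A length-4 vector summed to a scalar: the scalar's gradient (2.0)
    // flows unchanged to all 4 inputs.
    assert_eq!(sum_backward(2.0, 4), vec![2.0; 4]);
}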
@@ -171,6 +171,15 @@ impl CpuStorage {
         }
     }
 
+    pub(crate) fn sum(
+        &self,
+        _shape: &Shape,
+        _stride: &[usize],
+        _sum_dims: &[usize],
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         // [self] stores data in a contiguous way.
         let dims = shape.dims();
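The CPU kernel itself is left as todo!(). For orientation, here is one naive way such a strided reduction could be written, using plain f32 slices as a stand-in for CpuStorage; this is a hypothetical sketch, not candle's eventual implementation:

// Walk the (possibly strided) input, map each index to its position in the
// reduced shape (summed dims collapse to size 1), and accumulate.
fn sum_naive(src: &[f32], dims: &[usize], stride: &[usize], sum_dims: &[usize]) -> Vec<f32> {
    let dst_dims: Vec<usize> = (0..dims.len())
        .map(|d| if sum_dims.contains(&d) { 1 } else { dims[d] })
        .collect();
    let dst_len: usize = dst_dims.iter().product();
    let mut dst = vec![0f32; dst_len];
    let n: usize = dims.iter().product();
    for i in 0..n {
        // Decompose the flat index i into multi-dimensional coordinates,
        // then derive the strided source offset and the destination offset.
        let (mut rem, mut src_off, mut dst_off) = (i, 0, 0);
        for d in 0..dims.len() {
            let inner: usize = dims[d + 1..].iter().product();
            let coord = rem / inner;
            rem %= inner;
            src_off += coord * stride[d];
            if !sum_dims.contains(&d) {
                let dst_inner: usize = dst_dims[d + 1..].iter().product();
                dst_off += coord * dst_inner;
            }
        }
        dst[dst_off] += src[src_off];
    }
    dst
}

fn main() {
    // 2x3 contiguous matrix, summed over dim 1 -> shape 2x1.
    let m = [1f32, 2., 3., 4., 5., 6.];
    assert_eq!(sum_naive(&m, &[2, 3], &[3, 1], &[1]), vec![6., 15.]);
}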
@@ -291,6 +291,15 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
+    pub(crate) fn sum(
+        &self,
+        _shape: &Shape,
+        _stride: &[usize],
+        _sum_dims: &[usize],
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
         todo!()
     }
@@ -64,6 +64,10 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
    }
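This second CudaStorage hunk mirrors the new method in the dummy backend, which exists so builds without CUDA support still type-check every call site. A sketch of the feature-gating pattern (illustrative module names and error type, not candle's exact layout):

#[cfg(feature = "cuda")]
mod cuda_backend {
    // The real implementation would live here and expose the same API.
}

#[cfg(not(feature = "cuda"))]
mod cuda_backend {
    // Same signatures as the real backend; every call fails uniformly.
    pub struct CudaStorage;
    impl CudaStorage {
        pub fn sum(&self, _sum_dims: &[usize]) -> Result<Self, &'static str> {
            Err("not compiled with cuda support")
        }
    }
}

fn main() {
    // With the feature off (the default for this standalone sketch), the
    // stub is compiled in: the call type-checks but returns an error.
    let s = cuda_backend::CudaStorage;
    assert!(s.sum(&[0]).is_err());
}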
@@ -21,6 +21,7 @@ pub(crate) enum Op {
         mul: f64,
         add: f64,
     },
+    Sum(Tensor, Vec<usize>),
     ToDType(Tensor),
     Broadcast(Tensor),
     Exp(Tensor),
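The Sum variant keeps an owned copy of the reduced dims next to the input tensor: the backward arm destructured above as _sum_dims will eventually need to know which axes were collapsed in order to broadcast the gradient back over them.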
@@ -72,6 +72,19 @@ impl Storage {
         }
     }
 
+    pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.sum(shape, stride, s)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.sum(shape, stride, s)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         match self {
             Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
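This is the crate's standard enum-based backend dispatch: match on the variant, forward to that backend's method, and re-wrap the result in the same variant. A self-contained sketch of the shape of that pattern, with simplified stand-in types:

enum Storage {
    Cpu(Vec<f32>),
    Cuda(Vec<f32>), // stand-in for a device buffer
}

impl Storage {
    // Forward to the matching backend and keep the result on the same backend.
    fn sum(&self) -> Storage {
        match self {
            Storage::Cpu(data) => Storage::Cpu(vec![data.iter().sum()]),
            Storage::Cuda(data) => Storage::Cuda(vec![data.iter().sum()]),
        }
    }
}

fn main() {
    if let Storage::Cpu(v) = Storage::Cpu(vec![1., 2., 3.]).sum() {
        assert_eq!(v, vec![6.0]);
    }
}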
@@ -102,7 +102,13 @@ macro_rules! broadcast_binary_op {
 }
 
 /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
-fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor {
+fn from_storage<S: Into<Shape>>(
+    storage: Storage,
+    shape: S,
+    op: Option<Op>,
+    is_variable: bool,
+) -> Tensor {
+    let shape = shape.into();
     let stride = shape.stride_contiguous();
     let tensor_ = Tensor_ {
         id: TensorId::new(),
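The signature change makes the helper generic over anything convertible into a Shape, which is what lets Tensor::sum below pass a plain Vec<usize> of reduced dims directly. A sketch of the ergonomics with a simplified Shape stand-in:

#[derive(Debug, Clone)]
struct Shape(Vec<usize>);

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Shape(dims)
    }
}

// Accept anything convertible to Shape and normalize once, at the boundary.
fn from_storage<S: Into<Shape>>(shape: S) -> Shape {
    shape.into()
}

fn main() {
    // Both calls compile: the generic bound erases the difference.
    println!("{:?}", from_storage(vec![2, 3]));
    println!("{:?}", from_storage(Shape(vec![4])));
}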
@@ -347,6 +353,20 @@ impl Tensor {
         Ok(from_storage(storage, shape.clone(), op, false))
     }
 
+    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+        let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
+        let op = if self.track_op() {
+            Some(Op::Sum(self.clone(), sum_dims.to_vec()))
+        } else {
+            None
+        };
+        let mut dims = self.dims().to_vec();
+        for &sum_dim in sum_dims.iter() {
+            dims[sum_dim] = 1
+        }
+        Ok(from_storage(storage, dims, op, false))
+    }
+
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
         let a_dims = self.shape().dims();
         let b_dims = rhs.shape().dims();
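Note the shape bookkeeping in Tensor::sum: each summed dim is kept with size 1 (a keepdim-style reduction) rather than dropped. Extracted as a standalone check:

fn reduced_dims(dims: &[usize], sum_dims: &[usize]) -> Vec<usize> {
    let mut dims = dims.to_vec();
    for &d in sum_dims {
        dims[d] = 1; // reduced axes stay in the shape with size 1
    }
    dims
}

fn main() {
    // A 2x3x4 tensor summed over dims 0 and 2 yields shape 1x3x1.
    assert_eq!(reduced_dims(&[2, 3, 4], &[0, 2]), vec![1, 3, 1]);
}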