Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)
Boilerplate code for the sum operator.
@@ -56,6 +56,7 @@ impl Tensor {
             }
             Op::Reshape(node)
             | Op::Broadcast(node)
+            | Op::Sum(node, _)
             | Op::ToDType(node)
             | Op::ToDevice(node)
             | Op::Transpose(node, _, _)
@@ -188,6 +189,9 @@ impl Tensor {
             Op::Broadcast(_arg) => {
                 return Err(Error::BackwardNotSupported { op: "broadcast" })
             }
+            Op::Sum(_arg, _sum_dims) => {
+                return Err(Error::BackwardNotSupported { op: "sum" })
+            }
             Op::ToDType(arg) => {
                 let sum_grad = grads.or_insert(arg)?;
                 *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
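The backward rule for `sum` is stubbed out with `BackwardNotSupported` for now. Conceptually, since the forward pass keeps the reduced dimensions around with size 1 (see the `Tensor::sum` hunk further down), its gradient is just the output gradient broadcast back across the summed dimensions. A minimal standalone sketch of that rule for a single reduced dimension, using plain `Vec<f32>` buffers rather than candle's `Tensor`/`CpuStorage` types:

// Sketch of the eventual sum backward rule: the gradient of `x.sum(&[1])`
// for a (rows, cols) input is the (rows, 1) output gradient broadcast back
// along the summed dimension. Plain vectors, not candle types.
fn sum_dim1_backward(grad_output: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut grad_input = vec![0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            // Every column of row r receives the same upstream gradient.
            grad_input[r * cols + c] = grad_output[r];
        }
    }
    grad_input
}

fn main() {
    let grad_output = vec![1.0, 2.0];
    // Prints [1.0, 1.0, 1.0, 2.0, 2.0, 2.0] for a 2x3 input.
    println!("{:?}", sum_dim1_backward(&grad_output, 2, 3));
}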
@@ -171,6 +171,15 @@ impl CpuStorage {
         }
     }
 
+    pub(crate) fn sum(
+        &self,
+        _shape: &Shape,
+        _stride: &[usize],
+        _sum_dims: &[usize],
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         // [self] stores data in a contiguous way.
         let dims = shape.dims();
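The CPU kernel itself is only a `todo!()` at this point. One way such a reduction could look, sketched on a plain `f32` slice rather than on `CpuStorage` (the function name, signature, and helpers below are illustrative assumptions, not the crate's API): walk every source element through its strides and accumulate it into an output whose summed dimensions are collapsed to size 1.

// Illustrative naive reduction over arbitrary dims for a strided f32 buffer.
// `dims`/`stride` describe the source layout; the result is contiguous with
// every dim listed in `sum_dims` collapsed to size 1. Not candle's kernel.
fn sum_impl(src: &[f32], dims: &[usize], stride: &[usize], sum_dims: &[usize]) -> Vec<f32> {
    let mut dst_dims = dims.to_vec();
    for &d in sum_dims {
        dst_dims[d] = 1;
    }
    // Contiguous (row-major) strides for the destination shape.
    let mut dst_stride = vec![1usize; dst_dims.len()];
    for i in (0..dst_dims.len().saturating_sub(1)).rev() {
        dst_stride[i] = dst_stride[i + 1] * dst_dims[i + 1];
    }
    let dst_len: usize = dst_dims.iter().product();
    let mut dst = vec![0f32; dst_len];

    let src_len: usize = dims.iter().product();
    let mut index = vec![0usize; dims.len()];
    for _ in 0..src_len {
        // Source offset through the (possibly non-contiguous) strides.
        let src_off: usize = index.iter().zip(stride).map(|(i, s)| i * s).sum();
        // Destination offset: summed dims always map to coordinate 0.
        let dst_off: usize = index
            .iter()
            .zip(&dst_dims)
            .zip(&dst_stride)
            .map(|((&i, &d), &s)| if d == 1 { 0 } else { i * s })
            .sum();
        dst[dst_off] += src[src_off];
        // Advance the multi-dimensional index in row-major order.
        for d in (0..index.len()).rev() {
            index[d] += 1;
            if index[d] < dims[d] {
                break;
            }
            index[d] = 0;
        }
    }
    dst
}

fn main() {
    // 2x3 row-major matrix with contiguous strides [3, 1]; summing over dim 1
    // yields shape (2, 1) with values [6.0, 15.0].
    let src = [1f32, 2., 3., 4., 5., 6.];
    println!("{:?}", sum_impl(&src, &[2, 3], &[3, 1], &[1]));
}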
@@ -291,6 +291,15 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
+    pub(crate) fn sum(
+        &self,
+        _shape: &Shape,
+        _stride: &[usize],
+        _sum_dims: &[usize],
+    ) -> Result<Self> {
+        todo!()
+    }
+
     pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
         todo!()
     }
@@ -64,6 +64,10 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -21,6 +21,7 @@ pub(crate) enum Op {
         mul: f64,
         add: f64,
     },
+    Sum(Tensor, Vec<usize>),
     ToDType(Tensor),
     Broadcast(Tensor),
     Exp(Tensor),
@@ -72,6 +72,19 @@ impl Storage {
         }
     }
 
+    pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.sum(shape, stride, s)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.sum(shape, stride, s)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         match self {
             Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
@@ -102,7 +102,13 @@ macro_rules! broadcast_binary_op {
 }
 
 /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
-fn from_storage(storage: Storage, shape: Shape, op: Option<Op>, is_variable: bool) -> Tensor {
+fn from_storage<S: Into<Shape>>(
+    storage: Storage,
+    shape: S,
+    op: Option<Op>,
+    is_variable: bool,
+) -> Tensor {
+    let shape = shape.into();
     let stride = shape.stride_contiguous();
     let tensor_ = Tensor_ {
         id: TensorId::new(),
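Making `from_storage` generic over `S: Into<Shape>` lets call sites hand over anything convertible into a `Shape` (the new `Tensor::sum` below passes a `Vec<usize>` of output dims directly) instead of building the `Shape` by hand first. A small self-contained illustration of the pattern, using a stand-in `Shape` type rather than candle's:

// Stand-in Shape type to show the `Into<Shape>` call-site ergonomics; the
// real candle Shape and its From impls are not reproduced here.
#[derive(Debug)]
struct Shape(Vec<usize>);

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Shape(dims)
    }
}

impl From<(usize, usize)> for Shape {
    fn from((a, b): (usize, usize)) -> Self {
        Shape(vec![a, b])
    }
}

// Mirrors the shape-handling part of the new `from_storage` signature.
fn make_shape<S: Into<Shape>>(shape: S) -> Shape {
    shape.into()
}

fn main() {
    // Both call sites compile without an explicit Shape construction.
    println!("{:?}", make_shape(vec![2usize, 3]));
    println!("{:?}", make_shape((2, 3)));
}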
@@ -347,6 +353,20 @@ impl Tensor {
         Ok(from_storage(storage, shape.clone(), op, false))
     }
 
+    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+        let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
+        let op = if self.track_op() {
+            Some(Op::Sum(self.clone(), sum_dims.to_vec()))
+        } else {
+            None
+        };
+        let mut dims = self.dims().to_vec();
+        for &sum_dim in sum_dims.iter() {
+            dims[sum_dim] = 1
+        }
+        Ok(from_storage(storage, dims, op, false))
+    }
+
     pub fn matmul(&self, rhs: &Self) -> Result<Self> {
         let a_dims = self.shape().dims();
         let b_dims = rhs.shape().dims();
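With this boilerplate in place, the public entry point reduces over the listed dimensions while keeping them with size 1 (the `dims[sum_dim] = 1` bookkeeping above). Both backend kernels are still `todo!()`, so a call like the one below only runs once they are filled in; the `candle` crate name, `Tensor::new`, `Device::Cpu`, and `candle::Result` are assumptions about the surrounding API, not something this diff shows.

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // Assumed constructor and device API; the exact names at this commit may differ.
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // Sum over dimension 1: the reduced dim is kept with size 1, so the result
    // has shape (2, 1) and values [[6.], [15.]] once the CPU kernel is implemented.
    let s = t.sum(&[1])?;
    println!("{:?}", s.shape());
    Ok(())
}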