mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Handle zero dims in some simple operations. (#2064)
* Handle zero dims in some simple operations.
* Handle zero-dims in matmul.
* More testing.
This commit is contained in:
@ -79,6 +79,9 @@ macro_rules! unary_op {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name(&self) -> Result<Self> {
|
||||
let shape = self.shape();
|
||||
if shape.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self
|
||||
.storage()
|
||||
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||
@ -92,6 +95,9 @@ macro_rules! binary_op {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||
if shape.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||
&*rhs.storage(),
|
||||
self.layout(),
|
||||
@ -114,6 +120,9 @@ macro_rules! binary_op_scalar {
|
||||
.broadcast_as(self.shape())?,
|
||||
};
|
||||
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
|
||||
if self.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||
&*rhs.storage(),
|
||||
self.layout(),
|
||||
@ -646,6 +655,9 @@ impl Tensor {
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||
if self.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().affine(self.layout(), mul, add)?;
|
||||
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
@ -653,6 +665,9 @@ impl Tensor {
|
||||
|
||||
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
||||
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
||||
if self.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().elu(self.layout(), alpha)?;
|
||||
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
@ -660,6 +675,9 @@ impl Tensor {
|
||||
|
||||
/// Raise the tensor to some float exponent `e`.
|
||||
pub fn powf(&self, e: f64) -> Result<Self> {
|
||||
if self.elem_count() == 0 {
|
||||
return Ok(self.clone());
|
||||
}
|
||||
let storage = self.storage().powf(self.layout(), e)?;
|
||||
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
@ -1154,6 +1172,9 @@ impl Tensor {
|
||||
let n = b_dims[dim - 1];
|
||||
|
||||
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
||||
if c_shape.elem_count() == 0 || k == 0 {
|
||||
return Tensor::zeros(c_shape, self.dtype(), self.device());
|
||||
}
|
||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||
let batching_b: usize = b_dims[..dim - 2].iter().product();
|
||||
if k != k2 || batching != batching_b {
|
||||
|
@ -1083,6 +1083,27 @@ fn randn(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn zero_dim(device: &Device) -> Result<()> {
|
||||
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
|
||||
assert_eq!(t.dims3()?, (4, 0, 1));
|
||||
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
|
||||
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
|
||||
assert_eq!(t_cat.dims3()?, (4, 3, 1));
|
||||
let t_cat = Tensor::cat(&[&t, &t], 1)?;
|
||||
assert_eq!(t_cat.dims3()?, (4, 0, 1));
|
||||
let t_unary = t.sqrt()?;
|
||||
assert_eq!(t_unary.dims3()?, (4, 0, 1));
|
||||
let t_plus = (&t + 1.)?;
|
||||
assert_eq!(t_plus.dims3()?, (4, 0, 1));
|
||||
let t_mm = t2.matmul(&t.t()?)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 3, 0));
|
||||
let t_mm = t.matmul(&t2.t()?)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 0, 3));
|
||||
let t_mm = t.t()?.matmul(&t)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 1, 1));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||
test_device!(full, full_cpu, full_gpu, full_metal);
|
||||
@ -1131,6 +1152,7 @@ test_device!(
|
||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
|
||||
|
||||
// There was originally a bug on the CPU implementation for randn
|
||||
// https://github.com/huggingface/candle/issues/381
|
||||
|
Reference in New Issue
Block a user