diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 8ad9322b..59a23c39 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -105,6 +105,28 @@ macro_rules! binary_op {
     };
 }
 
+macro_rules! binary_op_scalar {
+    ($fn_name:ident, $op_name:ident) => {
+        pub fn $fn_name<T: crate::scalar::TensorOrScalar>(&self, rhs: T) -> Result<Self> {
+            let rhs = match rhs.to_tensor_scalar()? {
+                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+                crate::scalar::TensorScalar::Scalar(rhs) => rhs
+                    .to_dtype(self.dtype())?
+                    .to_device(self.device())?
+                    .broadcast_as(self.shape())?,
+            };
+            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
+            let storage = self.storage().binary_impl::<crate::op::$op_name>(
+                &*rhs.storage(),
+                self.layout(),
+                rhs.layout(),
+            )?;
+            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
+            Ok(from_storage(storage, shape.clone(), op, false))
+        }
+    };
+}
+
 macro_rules! broadcast_binary_op {
     ($fn_name:ident, $inner_fn_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
@@ -447,8 +469,8 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
     binary_op!(div, Div);
-    binary_op!(maximum, Maximum);
-    binary_op!(minimum, Minimum);
+    binary_op_scalar!(maximum, Maximum);
+    binary_op_scalar!(minimum, Minimum);
     broadcast_binary_op!(broadcast_add, add);
     broadcast_binary_op!(broadcast_mul, mul);
     broadcast_binary_op!(broadcast_sub, sub);
@@ -827,6 +849,11 @@ impl Tensor {
         self.cmp(rhs, CmpOp::Le)
     }
 
+    /// Clamp the tensor values to be between `min` and `max`.
+    pub fn clamp<T1: crate::scalar::TensorOrScalar, T2: crate::scalar::TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
+        self.maximum(min)?.minimum(max)
+    }
+
     /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
     /// nearest element.
     ///
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index cd68908f..f1c204ea 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -33,6 +33,17 @@ fn tensor_2d(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn clamp(device: &Device) -> Result<()> {
+    let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+    let tensor = Tensor::new(data, device)?;
+    let tensor = tensor.clamp(1.5, 6.2)?;
+    assert_eq!(
+        tensor.to_vec2::<f32>()?,
+        [[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
+    );
+    Ok(())
+}
+
 fn binary_op(device: &Device) -> Result<()> {
     let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
     let tensor1 = Tensor::new(data, device)?;
@@ -908,6 +919,7 @@ test_device!(index_add, index_add_cpu, index_add_gpu);
 test_device!(gather, gather_cpu, gather_gpu);
 test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
 test_device!(randn, randn_cpu, randn_gpu);
+test_device!(clamp, clamp_cpu, clamp_gpu);
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
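Usage sketch (not part of the patch): a minimal example of the API this diff introduces, where `maximum`/`minimum` accept either a scalar or another tensor and `clamp(min, max)` is `maximum(min)` followed by `minimum(max)`. It assumes `candle_core` re-exports `Tensor`, `Device`, and `Result` as in the test file above; the variable names (`t`, `bound`) are illustrative only.

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let t = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &device)?;

        // Scalar right-hand side: element-wise max against 2.0.
        let floored = t.maximum(2f32)?;

        // Tensor right-hand side (same shape) still works as before.
        let bound = Tensor::new(&[[2f32, 2., 2.], [8., 8., 8.]], &device)?;
        let capped = t.minimum(&bound)?;

        // Clamp every element into the range [2.0, 4.0].
        let clamped = t.clamp(2f32, 4f32)?;

        println!("{:?}", floored.to_vec2::<f32>()?);
        println!("{:?}", capped.to_vec2::<f32>()?);
        println!("{:?}", clamped.to_vec2::<f32>()?);
        Ok(())
    }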