mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Scalar support in minimum/maximum. (#832)
* Scalar support in minimum/maximum. * Add a clamp method to tensors.
This commit is contained in:
@ -105,6 +105,28 @@ macro_rules! binary_op {
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! binary_op_scalar {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
let rhs = match rhs.to_tensor_scalar()? {
|
||||
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
|
||||
crate::scalar::TensorScalar::Scalar(rhs) => rhs
|
||||
.to_dtype(self.dtype())?
|
||||
.to_device(self.device())?
|
||||
.broadcast_as(self.shape())?,
|
||||
};
|
||||
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
|
||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||
&*rhs.storage(),
|
||||
self.layout(),
|
||||
rhs.layout(),
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! broadcast_binary_op {
|
||||
($fn_name:ident, $inner_fn_name:ident) => {
|
||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||
@ -447,8 +469,8 @@ impl Tensor {
|
||||
binary_op!(mul, Mul);
|
||||
binary_op!(sub, Sub);
|
||||
binary_op!(div, Div);
|
||||
binary_op!(maximum, Maximum);
|
||||
binary_op!(minimum, Minimum);
|
||||
binary_op_scalar!(maximum, Maximum);
|
||||
binary_op_scalar!(minimum, Minimum);
|
||||
broadcast_binary_op!(broadcast_add, add);
|
||||
broadcast_binary_op!(broadcast_mul, mul);
|
||||
broadcast_binary_op!(broadcast_sub, sub);
|
||||
@ -827,6 +849,11 @@ impl Tensor {
|
||||
self.cmp(rhs, CmpOp::Le)
|
||||
}
|
||||
|
||||
/// Clamp the tensor values to be between `min` and `max`.
|
||||
pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
|
||||
self.maximum(min)?.minimum(max)
|
||||
}
|
||||
|
||||
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
|
||||
/// nearest element.
|
||||
///
|
||||
|
Reference in New Issue
Block a user