mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Scalar support in minimum/maximum. (#832)
* Scalar support in minimum/maximum. * Add a clamp method to tensors.
This commit is contained in:
@ -105,6 +105,28 @@ macro_rules! binary_op {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generates an elementwise binary method (e.g. `maximum`, `minimum`) whose
// right-hand side may be either another tensor or a plain scalar, via the
// `TensorOrScalar` trait. A scalar rhs is promoted to a tensor matching
// `self`'s dtype, device, and shape before the usual binary-op path runs.
macro_rules! binary_op_scalar {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
            // Normalize the rhs to a tensor: a tensor passes through unchanged,
            // a scalar is converted to self's dtype/device and broadcast to
            // self's shape so the shapes line up for the elementwise op.
            let rhs = match rhs.to_tensor_scalar()? {
                crate::scalar::TensorScalar::Tensor(rhs) => rhs,
                crate::scalar::TensorScalar::Scalar(rhs) => rhs
                    .to_dtype(self.dtype())?
                    .to_device(self.device())?
                    .broadcast_as(self.shape())?,
            };
            // Validates that both operands now share a shape (errors with the
            // op name on mismatch).
            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
            // Dispatch the actual elementwise kernel for `$op_name`.
            let storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
                rhs.layout(),
            )?;
            // Record the op so gradients can flow through both operands.
            let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
            Ok(from_storage(storage, shape.clone(), op, false))
        }
    };
}
|
||||||
|
|
||||||
macro_rules! broadcast_binary_op {
|
macro_rules! broadcast_binary_op {
|
||||||
($fn_name:ident, $inner_fn_name:ident) => {
|
($fn_name:ident, $inner_fn_name:ident) => {
|
||||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||||
@ -447,8 +469,8 @@ impl Tensor {
|
|||||||
// Elementwise binary ops taking a tensor rhs.
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
// maximum/minimum additionally accept a scalar rhs (see `binary_op_scalar!`).
binary_op_scalar!(maximum, Maximum);
binary_op_scalar!(minimum, Minimum);
// Broadcasting variants; bodies generated by `broadcast_binary_op!`
// (presumably aligning shapes before delegating — definition not fully
// visible here, confirm in the macro itself).
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
||||||
@ -827,6 +849,11 @@ impl Tensor {
|
|||||||
self.cmp(rhs, CmpOp::Le)
|
self.cmp(rhs, CmpOp::Le)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Clamp the tensor values to be between `min` and `max`.
|
||||||
|
pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
|
||||||
|
self.maximum(min)?.minimum(max)
|
||||||
|
}
|
||||||
|
|
||||||
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
|
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
|
||||||
/// nearest element.
|
/// nearest element.
|
||||||
///
|
///
|
||||||
|
@ -33,6 +33,17 @@ fn tensor_2d(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clamp(device: &Device) -> Result<()> {
|
||||||
|
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||||
|
let tensor = Tensor::new(data, device)?;
|
||||||
|
let tensor = tensor.clamp(1.5, 6.2)?;
|
||||||
|
assert_eq!(
|
||||||
|
tensor.to_vec2::<f32>()?,
|
||||||
|
[[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn binary_op(device: &Device) -> Result<()> {
|
fn binary_op(device: &Device) -> Result<()> {
|
||||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||||
let tensor1 = Tensor::new(data, device)?;
|
let tensor1 = Tensor::new(data, device)?;
|
||||||
@ -908,6 +919,7 @@ test_device!(index_add, index_add_cpu, index_add_gpu);
|
|||||||
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(randn, randn_cpu, randn_gpu);
// Run the new clamp test on both backends.
test_device!(clamp, clamp_cpu, clamp_gpu);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
||||||
|
Reference in New Issue
Block a user