mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values. * Add some time measurement.
This commit is contained in:
@ -59,6 +59,7 @@ mod op;
|
|||||||
pub mod pickle;
|
pub mod pickle;
|
||||||
pub mod quantized;
|
pub mod quantized;
|
||||||
pub mod safetensors;
|
pub mod safetensors;
|
||||||
|
pub mod scalar;
|
||||||
pub mod shape;
|
pub mod shape;
|
||||||
mod storage;
|
mod storage;
|
||||||
mod strided_index;
|
mod strided_index;
|
||||||
|
23
candle-core/src/scalar.rs
Normal file
23
candle-core/src/scalar.rs
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
use crate::{Result, Tensor, WithDType};
|
||||||
|
|
||||||
|
pub enum TensorScalar {
|
||||||
|
Tensor(Tensor),
|
||||||
|
Scalar(Tensor),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait TensorOrScalar {
|
||||||
|
fn to_tensor_scalar(self) -> Result<TensorScalar>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TensorOrScalar for &Tensor {
|
||||||
|
fn to_tensor_scalar(self) -> Result<TensorScalar> {
|
||||||
|
Ok(TensorScalar::Tensor(self.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: WithDType> TensorOrScalar for T {
|
||||||
|
fn to_tensor_scalar(self) -> Result<TensorScalar> {
|
||||||
|
let scalar = Tensor::new(self, &crate::Device::Cpu)?;
|
||||||
|
Ok(TensorScalar::Scalar(scalar))
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,7 @@ use crate::backend::{BackendDevice, BackendStorage};
|
|||||||
use crate::op::{
|
use crate::op::{
|
||||||
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
||||||
};
|
};
|
||||||
|
use crate::scalar::TensorOrScalar;
|
||||||
use crate::shape::{Dim, Dims};
|
use crate::shape::{Dim, Dims};
|
||||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
@ -776,8 +777,15 @@ impl Tensor {
|
|||||||
/// comparison operation is specified by the `op` argument.
|
/// comparison operation is specified by the `op` argument.
|
||||||
///
|
///
|
||||||
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
|
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
|
||||||
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
|
pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
|
||||||
let shape = self.same_shape_binary_op(rhs, "cmp")?;
|
let rhs = match rhs.to_tensor_scalar()? {
|
||||||
|
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
|
||||||
|
crate::scalar::TensorScalar::Scalar(rhs) => rhs
|
||||||
|
.to_dtype(self.dtype())?
|
||||||
|
.to_device(self.device())?
|
||||||
|
.broadcast_as(self.shape())?,
|
||||||
|
};
|
||||||
|
let shape = self.same_shape_binary_op(&rhs, "cmp")?;
|
||||||
let storage = self
|
let storage = self
|
||||||
.storage()
|
.storage()
|
||||||
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
|
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
|
||||||
@ -786,36 +794,36 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise equality.
|
/// Element-wise equality.
|
||||||
pub fn eq(&self, rhs: &Self) -> Result<Self> {
|
pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||||
self.cmp(rhs, CmpOp::Eq)
|
self.cmp(rhs, CmpOp::Eq)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise non-equality.
|
/// Element-wise non-equality.
|
||||||
pub fn ne(&self, rhs: &Self) -> Result<Self> {
|
pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||||
self.cmp(rhs, CmpOp::Ne)
|
self.cmp(rhs, CmpOp::Ne)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
|
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
|
||||||
/// rhs` and 0 otherwise.
|
/// rhs` and 0 otherwise.
|
||||||
pub fn lt(&self, rhs: &Self) -> Result<Self> {
|
pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||||
self.cmp(rhs, CmpOp::Lt)
|
self.cmp(rhs, CmpOp::Lt)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
|
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
|
||||||
/// rhs` and 0 otherwise.
|
/// rhs` and 0 otherwise.
|
||||||
pub fn gt(&self, rhs: &Self) -> Result<Self> {
|
pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||||
self.cmp(rhs, CmpOp::Gt)
|
self.cmp(rhs, CmpOp::Gt)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
|
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
|
||||||
/// rhs` and 0 otherwise.
|
/// rhs` and 0 otherwise.
|
||||||
pub fn ge(&self, rhs: &Self) -> Result<Self> {
|
pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||||
self.cmp(rhs, CmpOp::Ge)
|
self.cmp(rhs, CmpOp::Ge)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
|
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
|
||||||
/// rhs` and 0 otherwise.
|
/// rhs` and 0 otherwise.
|
||||||
pub fn le(&self, rhs: &Self) -> Result<Self> {
|
pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||||
self.cmp(rhs, CmpOp::Le)
|
self.cmp(rhs, CmpOp::Le)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,12 +209,17 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let point = Some((args.point_x, args.point_y));
|
let point = Some((args.point_x, args.point_y));
|
||||||
|
let start_time = std::time::Instant::now();
|
||||||
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||||
|
println!(
|
||||||
|
"mask generated in {:.2}s",
|
||||||
|
start_time.elapsed().as_secs_f32()
|
||||||
|
);
|
||||||
println!("mask:\n{mask}");
|
println!("mask:\n{mask}");
|
||||||
println!("iou_predictions: {iou_predictions:?}");
|
println!("iou_predictions: {iou_predictions:?}");
|
||||||
|
|
||||||
// Save the mask as an image.
|
// Save the mask as an image.
|
||||||
let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
|
let mask = (mask.ge(0f32)? * 255.)?;
|
||||||
let (_one, h, w) = mask.dims3()?;
|
let (_one, h, w) = mask.dims3()?;
|
||||||
let mask = mask.expand((3, h, w))?;
|
let mask = mask.expand((3, h, w))?;
|
||||||
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
||||||
|
@ -161,21 +161,21 @@ impl PromptEncoder {
|
|||||||
.forward_with_coords(&points, self.input_image_size)?;
|
.forward_with_coords(&points, self.input_image_size)?;
|
||||||
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
||||||
let zeros = point_embedding.zeros_like()?;
|
let zeros = point_embedding.zeros_like()?;
|
||||||
let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
|
let point_embedding = labels.lt(0f32)?.where_cond(
|
||||||
&self
|
&self
|
||||||
.not_a_point_embed
|
.not_a_point_embed
|
||||||
.embeddings()
|
.embeddings()
|
||||||
.broadcast_as(zeros.shape())?,
|
.broadcast_as(zeros.shape())?,
|
||||||
&point_embedding,
|
&point_embedding,
|
||||||
)?;
|
)?;
|
||||||
let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
|
let labels0 = labels.eq(0f32)?.where_cond(
|
||||||
&self.point_embeddings[0]
|
&self.point_embeddings[0]
|
||||||
.embeddings()
|
.embeddings()
|
||||||
.broadcast_as(zeros.shape())?,
|
.broadcast_as(zeros.shape())?,
|
||||||
&zeros,
|
&zeros,
|
||||||
)?;
|
)?;
|
||||||
let point_embedding = (point_embedding + labels0)?;
|
let point_embedding = (point_embedding + labels0)?;
|
||||||
let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
|
let labels1 = labels.eq(1f32)?.where_cond(
|
||||||
&self.point_embeddings[1]
|
&self.point_embeddings[1]
|
||||||
.embeddings()
|
.embeddings()
|
||||||
.broadcast_as(zeros.shape())?,
|
.broadcast_as(zeros.shape())?,
|
||||||
|
Reference in New Issue
Block a user