diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index a0347416..3504b0a6 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -59,6 +59,7 @@ mod op; pub mod pickle; pub mod quantized; pub mod safetensors; +pub mod scalar; pub mod shape; mod storage; mod strided_index; diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs new file mode 100644 index 00000000..43e1f4c8 --- /dev/null +++ b/candle-core/src/scalar.rs @@ -0,0 +1,23 @@ +use crate::{Result, Tensor, WithDType}; + +pub enum TensorScalar { + Tensor(Tensor), + Scalar(Tensor), +} + +pub trait TensorOrScalar { + fn to_tensor_scalar(self) -> Result; +} + +impl TensorOrScalar for &Tensor { + fn to_tensor_scalar(self) -> Result { + Ok(TensorScalar::Tensor(self.clone())) + } +} + +impl TensorOrScalar for T { + fn to_tensor_scalar(self) -> Result { + let scalar = Tensor::new(self, &crate::Device::Cpu)?; + Ok(TensorScalar::Scalar(scalar)) + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 6bb3d740..8ad9322b 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -4,6 +4,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{ BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp, }; +use crate::scalar::TensorOrScalar; use crate::shape::{Dim, Dims}; use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock}; @@ -776,8 +777,15 @@ impl Tensor { /// comparison operation is specified by the `op` argument. /// /// The returned tensor has the same shape as the original tensors and uses `u8` elements. - pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result { - let shape = self.same_shape_binary_op(rhs, "cmp")?; + pub fn cmp(&self, rhs: T, op: CmpOp) -> Result { + let rhs = match rhs.to_tensor_scalar()? { + crate::scalar::TensorScalar::Tensor(rhs) => rhs, + crate::scalar::TensorScalar::Scalar(rhs) => rhs + .to_dtype(self.dtype())? + .to_device(self.device())? + .broadcast_as(self.shape())?, + }; + let shape = self.same_shape_binary_op(&rhs, "cmp")?; let storage = self .storage() .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?; @@ -786,36 +794,36 @@ impl Tensor { } /// Element-wise equality. - pub fn eq(&self, rhs: &Self) -> Result { + pub fn eq(&self, rhs: T) -> Result { self.cmp(rhs, CmpOp::Eq) } /// Element-wise non-equality. - pub fn ne(&self, rhs: &Self) -> Result { + pub fn ne(&self, rhs: T) -> Result { self.cmp(rhs, CmpOp::Ne) } /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self < /// rhs` and 0 otherwise. - pub fn lt(&self, rhs: &Self) -> Result { + pub fn lt(&self, rhs: T) -> Result { self.cmp(rhs, CmpOp::Lt) } /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self > /// rhs` and 0 otherwise. - pub fn gt(&self, rhs: &Self) -> Result { + pub fn gt(&self, rhs: T) -> Result { self.cmp(rhs, CmpOp::Gt) } /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >= /// rhs` and 0 otherwise. - pub fn ge(&self, rhs: &Self) -> Result { + pub fn ge(&self, rhs: T) -> Result { self.cmp(rhs, CmpOp::Ge) } /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <= /// rhs` and 0 otherwise. - pub fn le(&self, rhs: &Self) -> Result { + pub fn le(&self, rhs: T) -> Result { self.cmp(rhs, CmpOp::Le) } diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 4627248c..ce8e3bb4 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -209,12 +209,17 @@ pub fn main() -> anyhow::Result<()> { } } else { let point = Some((args.point_x, args.point_y)); + let start_time = std::time::Instant::now(); let (mask, iou_predictions) = sam.forward(&image, point, false)?; + println!( + "mask generated in {:.2}s", + start_time.elapsed().as_secs_f32() + ); println!("mask:\n{mask}"); println!("iou_predictions: {iou_predictions:?}"); // Save the mask as an image. - let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?; + let mask = (mask.ge(0f32)? * 255.)?; let (_one, h, w) = mask.dims3()?; let mask = mask.expand((3, h, w))?; candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?; diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index 40cc6e36..7bbe8419 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -161,21 +161,21 @@ impl PromptEncoder { .forward_with_coords(&points, self.input_image_size)?; let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?; let zeros = point_embedding.zeros_like()?; - let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond( + let point_embedding = labels.lt(0f32)?.where_cond( &self .not_a_point_embed .embeddings() .broadcast_as(zeros.shape())?, &point_embedding, )?; - let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond( + let labels0 = labels.eq(0f32)?.where_cond( &self.point_embeddings[0] .embeddings() .broadcast_as(zeros.shape())?, &zeros, )?; let point_embedding = (point_embedding + labels0)?; - let labels1 = labels.eq(&labels.ones_like()?)?.where_cond( + let labels1 = labels.eq(1f32)?.where_cond( &self.point_embeddings[1] .embeddings() .broadcast_as(zeros.shape())?,