Get the comparison operation to work on scalar values. (#780)

* Get the comparison operation to work on scalar values.

* Add some time measurement.
This commit is contained in:
Laurent Mazare
2023-09-08 20:13:29 +01:00
committed by GitHub
parent 0906acab91
commit acf8f10ae1
5 changed files with 49 additions and 12 deletions

23
candle-core/src/scalar.rs Normal file
View File

@ -0,0 +1,23 @@
use crate::{Result, Tensor, WithDType};
pub enum TensorScalar {
Tensor(Tensor),
Scalar(Tensor),
}
pub trait TensorOrScalar {
fn to_tensor_scalar(self) -> Result<TensorScalar>;
}
impl TensorOrScalar for &Tensor {
fn to_tensor_scalar(self) -> Result<TensorScalar> {
Ok(TensorScalar::Tensor(self.clone()))
}
}
impl<T: WithDType> TensorOrScalar for T {
fn to_tensor_scalar(self) -> Result<TensorScalar> {
let scalar = Tensor::new(self, &crate::Device::Cpu)?;
Ok(TensorScalar::Scalar(scalar))
}
}