mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values. * Add some time measurement.
This commit is contained in:
23
candle-core/src/scalar.rs
Normal file
23
candle-core/src/scalar.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use crate::{Result, Tensor, WithDType};
|
||||
|
||||
pub enum TensorScalar {
|
||||
Tensor(Tensor),
|
||||
Scalar(Tensor),
|
||||
}
|
||||
|
||||
pub trait TensorOrScalar {
|
||||
fn to_tensor_scalar(self) -> Result<TensorScalar>;
|
||||
}
|
||||
|
||||
impl TensorOrScalar for &Tensor {
|
||||
fn to_tensor_scalar(self) -> Result<TensorScalar> {
|
||||
Ok(TensorScalar::Tensor(self.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: WithDType> TensorOrScalar for T {
|
||||
fn to_tensor_scalar(self) -> Result<TensorScalar> {
|
||||
let scalar = Tensor::new(self, &crate::Device::Cpu)?;
|
||||
Ok(TensorScalar::Scalar(scalar))
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user