mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
24 lines
548 B
Rust
24 lines
548 B
Rust
use crate::{Result, Tensor, WithDType};
|
|
|
|
pub enum TensorScalar {
|
|
Tensor(Tensor),
|
|
Scalar(Tensor),
|
|
}
|
|
|
|
pub trait TensorOrScalar {
|
|
fn to_tensor_scalar(self) -> Result<TensorScalar>;
|
|
}
|
|
|
|
impl TensorOrScalar for &Tensor {
|
|
fn to_tensor_scalar(self) -> Result<TensorScalar> {
|
|
Ok(TensorScalar::Tensor(self.clone()))
|
|
}
|
|
}
|
|
|
|
impl<T: WithDType> TensorOrScalar for T {
|
|
fn to_tensor_scalar(self) -> Result<TensorScalar> {
|
|
let scalar = Tensor::new(self, &crate::Device::Cpu)?;
|
|
Ok(TensorScalar::Scalar(scalar))
|
|
}
|
|
}
|