diff --git a/examples/basics.rs b/examples/basics.rs new file mode 100644 index 00000000..f01f7871 --- /dev/null +++ b/examples/basics.rs @@ -0,0 +1,9 @@ +use anyhow::Result; +use candle::{Device, Tensor}; + +fn main() -> Result<()> { + let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?; + let y = (((&x * &x)? + &x * 5f64)? + 4f64)?; + println!("{:?}", y.to_vec1::()?); + Ok(()) +} diff --git a/src/dtype.rs b/src/dtype.rs index d66d046c..fd0eaa1b 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -27,38 +27,26 @@ pub trait WithDType: Sized + Copy { fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; } -impl WithDType for f32 { - const DTYPE: DType = DType::F32; +macro_rules! with_dtype { + ($ty:ty, $dtype:ident) => { + impl WithDType for $ty { + const DTYPE: DType = DType::$dtype; - fn to_cpu_storage_owned(data: Vec) -> CpuStorage { - CpuStorage::F32(data) - } + fn to_cpu_storage_owned(data: Vec) -> CpuStorage { + CpuStorage::$dtype(data) + } - fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { - match s { - CpuStorage::F32(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::F32, - got: s.dtype(), - }), + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { + match s { + CpuStorage::$dtype(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + }), + } + } } - } -} - -impl WithDType for f64 { - const DTYPE: DType = DType::F64; - - fn to_cpu_storage_owned(data: Vec) -> CpuStorage { - CpuStorage::F64(data) - } - - fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { - match s { - CpuStorage::F64(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::F64, - got: s.dtype(), - }), - } - } + }; } +with_dtype!(f32, F32); +with_dtype!(f64, F64);