Get the addition/multiplication to work.

This commit is contained in:
laurent
2023-06-20 11:07:59 +01:00
parent 6c5fc767a8
commit f5b0aa815a
2 changed files with 62 additions and 15 deletions

View File

@ -1,15 +1,26 @@
use candle::{DType, Device, Result, Tensor};
#[test]
fn add() -> Result<()> {
fn zeros() -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu);
let (dim1, dim2) = tensor.shape().r2()?;
assert_eq!(dim1, 5);
assert_eq!(dim2, 2);
Ok(())
}
#[test]
fn add_mul() -> Result<()> {
let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
let dim1 = tensor.shape().r1()?;
assert_eq!(dim1, 3);
let content: Vec<f32> = tensor.to_vec1()?;
assert_eq!(content, [3., 1., 4.]);
let tensor = Tensor::add(&tensor, &tensor)?;
let content: Vec<f32> = tensor.to_vec1()?;
assert_eq!(content, [6., 2., 8.]);
let tensor = Tensor::mul(&tensor, &tensor)?;
let content: Vec<f32> = tensor.to_vec1()?;
assert_eq!(content, [36., 4., 64.]);
Ok(())
}