Add some basic test.

This commit is contained in:
laurent
2023-06-19 19:50:17 +01:00
parent 8e2c534d1f
commit 634e0c88ae
4 changed files with 84 additions and 1 deletions

19
src/error.rs Normal file
View File

@ -0,0 +1,19 @@
/// Main library error type.
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("invalid shapes in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
BinaryInvalidShape {
lhs: Vec<usize>,
rhs: Vec<usize>,
op: &'static str,
},
#[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
UnexpectedNumberOfDims {
expected: usize,
got: usize,
shape: Vec<usize>,
},
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@ -1,9 +1,11 @@
mod device;
mod dtype;
mod error;
mod op;
mod storage;
mod tensor;
pub use device::Device;
pub use dtype::DType;
pub use error::{Error, Result};
pub use tensor::Tensor;

View File

@ -1,4 +1,4 @@
use crate::{op::Op, storage::Storage, DType, Device};
use crate::{op::Op, storage::Storage, DType, Device, Error, Result};
use std::sync::Arc;
#[allow(dead_code)]
@ -46,4 +46,56 @@ impl Tensor {
pub fn elem_count(&self) -> usize {
self.0.shape.iter().product()
}
pub fn shape1(&self) -> Result<usize> {
let shape = self.shape();
if shape.len() == 1 {
Ok(shape[0])
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 1,
got: shape.len(),
shape: shape.to_vec(),
})
}
}
pub fn shape2(&self) -> Result<(usize, usize)> {
let shape = self.shape();
if shape.len() == 2 {
Ok((shape[0], shape[1]))
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 2,
got: shape.len(),
shape: shape.to_vec(),
})
}
}
pub fn shape3(&self) -> Result<(usize, usize, usize)> {
let shape = self.shape();
if shape.len() == 3 {
Ok((shape[0], shape[1], shape[2]))
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 3,
got: shape.len(),
shape: shape.to_vec(),
})
}
}
pub fn shape4(&self) -> Result<(usize, usize, usize, usize)> {
let shape = self.shape();
if shape.len() == 4 {
Ok((shape[0], shape[1], shape[2], shape[4]))
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 4,
got: shape.len(),
shape: shape.to_vec(),
})
}
}
}

10
tests/tensor_tests.rs Normal file
View File

@ -0,0 +1,10 @@
use candle::{DType, Device, Result, Tensor};
#[test]
fn add() -> Result<()> {
let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu);
let (dim1, dim2) = tensor.shape2()?;
assert_eq!(dim1, 5);
assert_eq!(dim2, 2);
Ok(())
}