mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some basic test.
This commit is contained in:
19
src/error.rs
Normal file
19
src/error.rs
Normal file
@ -0,0 +1,19 @@
|
||||
/// Main library error type.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("invalid shapes in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
BinaryInvalidShape {
|
||||
lhs: Vec<usize>,
|
||||
rhs: Vec<usize>,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
|
||||
UnexpectedNumberOfDims {
|
||||
expected: usize,
|
||||
got: usize,
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
@ -1,9 +1,11 @@
|
||||
mod device;
|
||||
mod dtype;
|
||||
mod error;
|
||||
mod op;
|
||||
mod storage;
|
||||
mod tensor;
|
||||
|
||||
pub use device::Device;
|
||||
pub use dtype::DType;
|
||||
pub use error::{Error, Result};
|
||||
pub use tensor::Tensor;
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::{op::Op, storage::Storage, DType, Device};
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
@ -46,4 +46,56 @@ impl Tensor {
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.0.shape.iter().product()
|
||||
}
|
||||
|
||||
pub fn shape1(&self) -> Result<usize> {
|
||||
let shape = self.shape();
|
||||
if shape.len() == 1 {
|
||||
Ok(shape[0])
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 1,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape2(&self) -> Result<(usize, usize)> {
|
||||
let shape = self.shape();
|
||||
if shape.len() == 2 {
|
||||
Ok((shape[0], shape[1]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 2,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape3(&self) -> Result<(usize, usize, usize)> {
|
||||
let shape = self.shape();
|
||||
if shape.len() == 3 {
|
||||
Ok((shape[0], shape[1], shape[2]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 3,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape4(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let shape = self.shape();
|
||||
if shape.len() == 4 {
|
||||
Ok((shape[0], shape[1], shape[2], shape[4]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 4,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
10
tests/tensor_tests.rs
Normal file
10
tests/tensor_tests.rs
Normal file
@ -0,0 +1,10 @@
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
#[test]
|
||||
fn add() -> Result<()> {
|
||||
let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu);
|
||||
let (dim1, dim2) = tensor.shape2()?;
|
||||
assert_eq!(dim1, 5);
|
||||
assert_eq!(dim2, 2);
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user