mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Move the test-utils bits to a shared place. (#619)
This commit is contained in:
@ -63,6 +63,7 @@ pub mod shape;
|
||||
mod storage;
|
||||
mod strided_index;
|
||||
mod tensor;
|
||||
pub mod test_utils;
|
||||
pub mod utils;
|
||||
mod variable;
|
||||
|
||||
|
56
candle-core/src/test_utils.rs
Normal file
56
candle-core/src/test_utils.rs
Normal file
@ -0,0 +1,56 @@
|
||||
use crate::{Result, Tensor};
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! test_device {
|
||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
|
||||
#[test]
|
||||
fn $test_cpu() -> Result<()> {
|
||||
$fn_name(&Device::Cpu)
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
#[test]
|
||||
fn $test_cuda() -> Result<()> {
|
||||
$fn_name(&Device::new_cuda(0)?)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec0::<f32>()?;
|
||||
Ok(f32::round(t * b) / b)
|
||||
}
|
||||
|
||||
pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec1::<f32>()?;
|
||||
let t = t.iter().map(|t| f32::round(t * b) / b).collect();
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec2::<f32>()?;
|
||||
let t = t
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
|
||||
.collect();
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec3::<f32>()?;
|
||||
let t = t
|
||||
.iter()
|
||||
.map(|t| {
|
||||
t.iter()
|
||||
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
Ok(t)
|
||||
}
|
Reference in New Issue
Block a user