mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Tmp state.
This commit is contained in:

committed by
Nicolas Patry

parent
f710fab02e
commit
d46670f7c0
@ -112,7 +112,13 @@ macro_rules! ops{
|
||||
($($name:ident),+) => {
|
||||
|
||||
pub mod contiguous {
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Kernel(pub(crate) &'static str);
|
||||
impl std::fmt::Display for Kernel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
@ -124,7 +130,13 @@ macro_rules! ops{
|
||||
}
|
||||
|
||||
pub mod strided {
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Kernel(pub(crate) &'static str);
|
||||
impl std::fmt::Display for Kernel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
@ -859,6 +871,30 @@ mod tests {
|
||||
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos_strided_random() {
|
||||
let v: Vec<_> = (0..10_000).map(|i| rand::random::<f32>()).collect();
|
||||
let shape = vec![5_000, 2];
|
||||
let strides = vec![1, 5_000];
|
||||
let offset = 0;
|
||||
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
|
||||
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
|
||||
assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
|
||||
assert_eq!(
|
||||
approx(vec![results[1]], 4),
|
||||
approx(vec![expected[5_000]], 4)
|
||||
);
|
||||
assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
|
||||
assert_eq!(
|
||||
approx(vec![results[3]], 4),
|
||||
approx(vec![expected[5_001]], 4)
|
||||
);
|
||||
assert_eq!(
|
||||
approx(vec![results[5_000]], 4),
|
||||
approx(vec![expected[2_500]], 4)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_add_f32() {
|
||||
let left = vec![1.0f32, 2.0, 3.0];
|
||||
|
Reference in New Issue
Block a user