mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some display tests + bugfixes.
This commit is contained in:
@ -186,8 +186,8 @@ trait TensorFormatter {
|
||||
.and_then(|t| t.to_vec1::<Self::Elem>())
|
||||
{
|
||||
for v in vs.into_iter() {
|
||||
self.fmt(v, max_w, f)?;
|
||||
write!(f, ", ")?;
|
||||
self.fmt(v, max_w, f)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -257,7 +257,7 @@ struct FloatFormatter<S: WithDType> {
|
||||
|
||||
impl<S> FloatFormatter<S>
|
||||
where
|
||||
S: WithDType + num_traits::Float,
|
||||
S: WithDType + num_traits::Float + std::fmt::Display,
|
||||
{
|
||||
fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
|
||||
let mut int_mode = true;
|
||||
@ -275,6 +275,7 @@ where
|
||||
let mut nonzero_finite_min = S::max_value();
|
||||
let mut nonzero_finite_max = S::min_value();
|
||||
for &v in values.iter() {
|
||||
let v = v.abs();
|
||||
if v < nonzero_finite_min {
|
||||
nonzero_finite_min = v
|
||||
}
|
||||
@ -289,7 +290,6 @@ where
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(v1) = S::from(1000.) {
|
||||
if let Some(v2) = S::from(1e8) {
|
||||
if let Some(v3) = S::from(1e-4) {
|
||||
|
@ -652,18 +652,22 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
|
||||
let start_dim = start_dim.unwrap_or(0);
|
||||
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
|
||||
if start_dim < end_dim {
|
||||
let dims = self.dims();
|
||||
let mut dst_dims = dims[..start_dim].to_vec();
|
||||
dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
|
||||
if end_dim + 1 < dims.len() {
|
||||
dst_dims.extend(&dims[end_dim + 1..]);
|
||||
}
|
||||
self.reshape(dst_dims)
|
||||
if self.rank() == 0 {
|
||||
self.reshape(1)
|
||||
} else {
|
||||
Ok(self.clone())
|
||||
let start_dim = start_dim.unwrap_or(0);
|
||||
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
|
||||
if start_dim < end_dim {
|
||||
let dims = self.dims();
|
||||
let mut dst_dims = dims[..start_dim].to_vec();
|
||||
dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
|
||||
if end_dim + 1 < dims.len() {
|
||||
dst_dims.extend(&dims[end_dim + 1..]);
|
||||
}
|
||||
self.reshape(dst_dims)
|
||||
} else {
|
||||
Ok(self.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
84
candle-core/tests/display_tests.rs
Normal file
84
candle-core/tests/display_tests.rs
Normal file
@ -0,0 +1,84 @@
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device::Cpu, Tensor};
|
||||
|
||||
#[test]
|
||||
fn display_scalar() -> Result<()> {
|
||||
let t = Tensor::new(1234u32, &Cpu)?;
|
||||
let s = format!("{t}");
|
||||
assert_eq!(&s, "[1234]\nTensor[[], u32]");
|
||||
let t = t.to_dtype(DType::F32)?.neg()?;
|
||||
let s = format!("{}", (&t / 10.0)?);
|
||||
assert_eq!(&s, "[-123.4000]\nTensor[[], f32]");
|
||||
let s = format!("{}", (&t / 1e8)?);
|
||||
assert_eq!(&s, "[-1.2340e-5]\nTensor[[], f32]");
|
||||
let s = format!("{}", (&t * 1e8)?);
|
||||
assert_eq!(&s, "[-1.2340e11]\nTensor[[], f32]");
|
||||
let s = format!("{}", (&t * 0.)?);
|
||||
assert_eq!(&s, "[0.]\nTensor[[], f32]");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_vector() -> Result<()> {
|
||||
let t = Tensor::new::<&[u32; 0]>(&[], &Cpu)?;
|
||||
let s = format!("{t}");
|
||||
assert_eq!(&s, "[]\nTensor[[0], u32]");
|
||||
let t = Tensor::new(&[0.1234567, 1.0, -1.2, 4.1, f64::NAN], &Cpu)?;
|
||||
let s = format!("{t}");
|
||||
assert_eq!(
|
||||
&s,
|
||||
"[ 0.1235, 1.0000, -1.2000, 4.1000, NaN]\nTensor[[5], f64]"
|
||||
);
|
||||
let t = (Tensor::ones(50, DType::F32, &Cpu)? * 42.)?;
|
||||
let s = format!("\n{t}");
|
||||
let expected = r#"
|
||||
[42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
|
||||
42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
|
||||
42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
|
||||
42., 42.]
|
||||
Tensor[[50], f32]"#;
|
||||
assert_eq!(&s, expected);
|
||||
let t = (Tensor::ones(11000, DType::F32, &Cpu)? * 42.)?;
|
||||
let s = format!("{t}");
|
||||
assert_eq!(
|
||||
&s,
|
||||
"[42., 42., 42., ..., 42., 42., 42.]\nTensor[[11000], f32]"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_multi_dim() -> Result<()> {
|
||||
let t = (Tensor::ones((200, 100), DType::F32, &Cpu)? * 42.)?;
|
||||
let s = format!("\n{t}");
|
||||
let expected = r#"
|
||||
[[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
...
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.]]
|
||||
Tensor[[200, 100], f32]"#;
|
||||
assert_eq!(&s, expected);
|
||||
let t = t.reshape(&[2, 1, 1, 100, 100])?;
|
||||
let t = format!("\n{t}");
|
||||
let expected = r#"
|
||||
[[[[[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
...
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.]]]],
|
||||
[[[[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
...
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.],
|
||||
[42., 42., 42., ..., 42., 42., 42.]]]]]
|
||||
Tensor[[2, 1, 1, 100, 100], f32]"#;
|
||||
assert_eq!(&t, expected);
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user