diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index 8d1f16f6..81ca3c98 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -186,8 +186,8 @@ trait TensorFormatter {
                 .and_then(|t| t.to_vec1::<S>())
             {
                 for v in vs.into_iter() {
-                    self.fmt(v, max_w, f)?;
                     write!(f, ", ")?;
+                    self.fmt(v, max_w, f)?;
                 }
             }
         }
@@ -257,7 +257,7 @@ struct FloatFormatter {
 
 impl<S> FloatFormatter<S>
 where
-    S: WithDType + num_traits::Float,
+    S: WithDType + num_traits::Float + std::fmt::Display,
 {
     fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
         let mut int_mode = true;
@@ -275,6 +275,7 @@ where
         let mut nonzero_finite_min = S::max_value();
         let mut nonzero_finite_max = S::min_value();
         for &v in values.iter() {
+            let v = v.abs();
             if v < nonzero_finite_min {
                 nonzero_finite_min = v
             }
@@ -289,7 +290,6 @@ where
                 break;
             }
         }
-
         if let Some(v1) = S::from(1000.) {
             if let Some(v2) = S::from(1e8) {
                 if let Some(v3) = S::from(1e-4) {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index b64f63e1..95254bab 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -652,18 +652,22 @@ impl Tensor {
     }
 
     pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Self> {
-        let start_dim = start_dim.unwrap_or(0);
-        let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
-        if start_dim < end_dim {
-            let dims = self.dims();
-            let mut dst_dims = dims[..start_dim].to_vec();
-            dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
-            if end_dim + 1 < dims.len() {
-                dst_dims.extend(&dims[end_dim + 1..]);
-            }
-            self.reshape(dst_dims)
+        if self.rank() == 0 {
+            self.reshape(1)
         } else {
-            Ok(self.clone())
+            let start_dim = start_dim.unwrap_or(0);
+            let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
+            if start_dim < end_dim {
+                let dims = self.dims();
+                let mut dst_dims = dims[..start_dim].to_vec();
+                dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
+                if end_dim + 1 < dims.len() {
+                    dst_dims.extend(&dims[end_dim + 1..]);
+                }
+                self.reshape(dst_dims)
+            } else {
+                Ok(self.clone())
+            }
         }
     }
 
diff --git a/candle-core/tests/display_tests.rs b/candle-core/tests/display_tests.rs
new file mode 100644
index 00000000..eaa60180
--- /dev/null
+++ b/candle-core/tests/display_tests.rs
@@ -0,0 +1,84 @@
+use anyhow::Result;
+use candle::{DType, Device::Cpu, Tensor};
+
+#[test]
+fn display_scalar() -> Result<()> {
+    let t = Tensor::new(1234u32, &Cpu)?;
+    let s = format!("{t}");
+    assert_eq!(&s, "[1234]\nTensor[[], u32]");
+    let t = t.to_dtype(DType::F32)?.neg()?;
+    let s = format!("{}", (&t / 10.0)?);
+    assert_eq!(&s, "[-123.4000]\nTensor[[], f32]");
+    let s = format!("{}", (&t / 1e8)?);
+    assert_eq!(&s, "[-1.2340e-5]\nTensor[[], f32]");
+    let s = format!("{}", (&t * 1e8)?);
+    assert_eq!(&s, "[-1.2340e11]\nTensor[[], f32]");
+    let s = format!("{}", (&t * 0.)?);
+    assert_eq!(&s, "[0.]\nTensor[[], f32]");
+    Ok(())
+}
+
+#[test]
+fn display_vector() -> Result<()> {
+    let t = Tensor::new::<&[u32; 0]>(&[], &Cpu)?;
+    let s = format!("{t}");
+    assert_eq!(&s, "[]\nTensor[[0], u32]");
+    let t = Tensor::new(&[0.1234567, 1.0, -1.2, 4.1, f64::NAN], &Cpu)?;
+    let s = format!("{t}");
+    assert_eq!(
+        &s,
+        "[ 0.1235,  1.0000, -1.2000,  4.1000,     NaN]\nTensor[[5], f64]"
+    );
+    let t = (Tensor::ones(50, DType::F32, &Cpu)? * 42.)?;
+    let s = format!("\n{t}");
+    let expected = r#"
+[42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
+ 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
+ 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
+ 42., 42.]
+Tensor[[50], f32]"#;
+    assert_eq!(&s, expected);
+    let t = (Tensor::ones(11000, DType::F32, &Cpu)? * 42.)?;
+    let s = format!("{t}");
+    assert_eq!(
+        &s,
+        "[42., 42., 42., ..., 42., 42., 42.]\nTensor[[11000], f32]"
+    );
+    Ok(())
+}
+
+#[test]
+fn display_multi_dim() -> Result<()> {
+    let t = (Tensor::ones((200, 100), DType::F32, &Cpu)? * 42.)?;
+    let s = format!("\n{t}");
+    let expected = r#"
+[[42., 42., 42., ..., 42., 42., 42.],
+ [42., 42., 42., ..., 42., 42., 42.],
+ [42., 42., 42., ..., 42., 42., 42.],
+ ...
+ [42., 42., 42., ..., 42., 42., 42.],
+ [42., 42., 42., ..., 42., 42., 42.],
+ [42., 42., 42., ..., 42., 42., 42.]]
+Tensor[[200, 100], f32]"#;
+    assert_eq!(&s, expected);
+    let t = t.reshape(&[2, 1, 1, 100, 100])?;
+    let t = format!("\n{t}");
+    let expected = r#"
+[[[[[42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    ...
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.]]]],
+ [[[[42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    ...
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.]]]]]
+Tensor[[2, 1, 1, 100, 100], f32]"#;
+    assert_eq!(&t, expected);
+    Ok(())
+}