diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
new file mode 100644
index 00000000..81ca3c98
--- /dev/null
+++ b/candle-core/src/display.rs
@@ -0,0 +1,455 @@
+/// Pretty printing of tensors
+/// This implementation should be in line with the PyTorch version.
+/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
+use crate::{DType, Result, Tensor, WithDType};
+use half::{bf16, f16};
+
+impl Tensor {
+    fn fmt_dt<T: WithDType + std::fmt::Display>(
+        &self,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        write!(f, "Tensor[")?;
+        match self.dims() {
+            [] => {
+                if let Ok(v) = self.to_scalar::<T>() {
+                    write!(f, "{v}")?
+                }
+            }
+            [s] if *s < 10 => {
+                if let Ok(vs) = self.to_vec1::<T>() {
+                    for (i, v) in vs.iter().enumerate() {
+                        if i > 0 {
+                            write!(f, ", ")?;
+                        }
+                        write!(f, "{v}")?;
+                    }
+                }
+            }
+            dims => {
+                write!(f, "dims ")?;
+                for (i, d) in dims.iter().enumerate() {
+                    if i > 0 {
+                        write!(f, ", ")?;
+                    }
+                    write!(f, "{d}")?;
+                }
+            }
+        }
+        write!(f, "; {}]", self.dtype().as_str())
+    }
+}
+
+impl std::fmt::Debug for Tensor {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self.dtype() {
+            DType::U32 => self.fmt_dt::<u32>(f),
+            DType::BF16 => self.fmt_dt::<bf16>(f),
+            DType::F16 => self.fmt_dt::<f16>(f),
+            DType::F32 => self.fmt_dt::<f32>(f),
+            DType::F64 => self.fmt_dt::<f64>(f),
+        }
+    }
+}
+
+#[allow(dead_code)]
+/// Options for Tensor pretty printing
+pub struct PrinterOptions {
+    precision: usize,
+    threshold: usize,
+    edge_items: usize,
+    line_width: usize,
+    sci_mode: Option<bool>,
+}
+
+static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
+    std::sync::Mutex::new(PrinterOptions::const_default());
+
+impl PrinterOptions {
+    // We cannot use the default trait as it's not const.
+    const fn const_default() -> Self {
+        Self {
+            precision: 4,
+            threshold: 1000,
+            edge_items: 3,
+            line_width: 80,
+            sci_mode: None,
+        }
+    }
+}
+
+pub fn set_print_options(options: PrinterOptions) {
+    *PRINT_OPTS.lock().unwrap() = options
+}
+
+pub fn set_print_options_default() {
+    *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
+}
+
+pub fn set_print_options_short() {
+    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
+        precision: 2,
+        threshold: 1000,
+        edge_items: 2,
+        line_width: 80,
+        sci_mode: None,
+    }
+}
+
+pub fn set_print_options_full() {
+    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
+        precision: 4,
+        threshold: usize::MAX,
+        edge_items: 3,
+        line_width: 80,
+        sci_mode: None,
+    }
+}
+
+struct FmtSize {
+    current_size: usize,
+}
+
+impl FmtSize {
+    fn new() -> Self {
+        Self { current_size: 0 }
+    }
+
+    fn final_size(self) -> usize {
+        self.current_size
+    }
+}
+
+impl std::fmt::Write for FmtSize {
+    fn write_str(&mut self, s: &str) -> std::fmt::Result {
+        self.current_size += s.len();
+        Ok(())
+    }
+}
+
+trait TensorFormatter {
+    type Elem: WithDType;
+
+    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
+
+    fn max_width(&self, to_display: &Tensor) -> usize {
+        let mut max_width = 1;
+        if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
+            for &v in vs.iter() {
+                let mut fmt_size = FmtSize::new();
+                let _res = self.fmt(v, 1, &mut fmt_size);
+                max_width = usize::max(max_width, fmt_size.final_size())
+            }
+        }
+        max_width
+    }
+
+    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        writeln!(f)?;
+        for _ in 0..i {
+            write!(f, " ")?
+        }
+        Ok(())
+    }
+
+    fn fmt_tensor(
+        &self,
+        t: &Tensor,
+        indent: usize,
+        max_w: usize,
+        summarize: bool,
+        po: &PrinterOptions,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        let dims = t.dims();
+        let edge_items = po.edge_items;
+        write!(f, "[")?;
+        match dims {
+            [] => {
+                if let Ok(v) = t.to_scalar::<Self::Elem>() {
+                    self.fmt(v, max_w, f)?
+                }
+            }
+            [v] if summarize && *v > 2 * edge_items => {
+                if let Ok(vs) = t
+                    .narrow(0, 0, edge_items)
+                    .and_then(|t| t.to_vec1::<Self::Elem>())
+                {
+                    for v in vs.into_iter() {
+                        self.fmt(v, max_w, f)?;
+                        write!(f, ", ")?;
+                    }
+                }
+                write!(f, "...")?;
+                if let Ok(vs) = t
+                    .narrow(0, v - edge_items, edge_items)
+                    .and_then(|t| t.to_vec1::<Self::Elem>())
+                {
+                    for v in vs.into_iter() {
+                        write!(f, ", ")?;
+                        self.fmt(v, max_w, f)?;
+                    }
+                }
+            }
+            [_] => {
+                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
+                if let Ok(vs) = t.to_vec1::<Self::Elem>() {
+                    for (i, v) in vs.into_iter().enumerate() {
+                        if i > 0 {
+                            if i % elements_per_line == 0 {
+                                write!(f, ",")?;
+                                Self::write_newline_indent(indent, f)?
+                            } else {
+                                write!(f, ", ")?;
+                            }
+                        }
+                        self.fmt(v, max_w, f)?
+                    }
+                }
+            }
+            _ => {
+                if summarize && dims[0] > 2 * edge_items {
+                    for i in 0..edge_items {
+                        match t.get(i) {
+                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
+                            Err(e) => write!(f, "{e:?}")?,
+                        }
+                        write!(f, ",")?;
+                        Self::write_newline_indent(indent, f)?
+                    }
+                    write!(f, "...")?;
+                    Self::write_newline_indent(indent, f)?;
+                    for i in dims[0] - edge_items..dims[0] {
+                        match t.get(i) {
+                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
+                            Err(e) => write!(f, "{e:?}")?,
+                        }
+                        if i + 1 != dims[0] {
+                            write!(f, ",")?;
+                            Self::write_newline_indent(indent, f)?
+                        }
+                    }
+                } else {
+                    for i in 0..dims[0] {
+                        match t.get(i) {
+                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
+                            Err(e) => write!(f, "{e:?}")?,
+                        }
+                        if i + 1 != dims[0] {
+                            write!(f, ",")?;
+                            Self::write_newline_indent(indent, f)?
+                        }
+                    }
+                }
+            }
+        }
+        write!(f, "]")?;
+        Ok(())
+    }
+}
+
+struct FloatFormatter<S: WithDType> {
+    int_mode: bool,
+    sci_mode: bool,
+    precision: usize,
+    _phantom: std::marker::PhantomData<S>,
+}
+
+impl<S> FloatFormatter<S>
+where
+    S: WithDType + num_traits::Float + std::fmt::Display,
+{
+    fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
+        let mut int_mode = true;
+        let mut sci_mode = false;
+
+        // Rather than containing all values, this should only include
+        // values that end up being displayed according to [threshold].
+        let values = t
+            .flatten_all()?
+            .to_vec1()?
+            .into_iter()
+            .filter(|v: &S| v.is_finite() && !v.is_zero())
+            .collect::<Vec<_>>();
+        if !values.is_empty() {
+            let mut nonzero_finite_min = S::max_value();
+            let mut nonzero_finite_max = S::min_value();
+            for &v in values.iter() {
+                let v = v.abs();
+                if v < nonzero_finite_min {
+                    nonzero_finite_min = v
+                }
+                if v > nonzero_finite_max {
+                    nonzero_finite_max = v
+                }
+            }
+
+            for &value in values.iter() {
+                if value.ceil() != value {
+                    int_mode = false;
+                    break;
+                }
+            }
+            if let Some(v1) = S::from(1000.) {
+                if let Some(v2) = S::from(1e8) {
+                    if let Some(v3) = S::from(1e-4) {
+                        sci_mode = nonzero_finite_max / nonzero_finite_min > v1
+                            || nonzero_finite_max > v2
+                            || nonzero_finite_min < v3
+                    }
+                }
+            }
+        }
+
+        match po.sci_mode {
+            None => {}
+            Some(v) => sci_mode = v,
+        }
+        Ok(Self {
+            int_mode,
+            sci_mode,
+            precision: po.precision,
+            _phantom: std::marker::PhantomData,
+        })
+    }
+}
+
+impl<S> TensorFormatter for FloatFormatter<S>
+where
+    S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
+{
+    type Elem = S;
+
+    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
+        if self.sci_mode {
+            write!(
+                f,
+                "{v:width$.prec$e}",
+                v = v,
+                width = max_w,
+                prec = self.precision
+            )
+        } else if self.int_mode {
+            if v.is_finite() {
+                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
+            } else {
+                write!(f, "{v:max_w$.0}")
+            }
+        } else {
+            write!(
+                f,
+                "{v:width$.prec$}",
+                v = v,
+                width = max_w,
+                prec = self.precision
+            )
+        }
+    }
+}
+
+struct IntFormatter<S: WithDType> {
+    _phantom: std::marker::PhantomData<S>,
+}
+
+impl<S: WithDType> IntFormatter<S> {
+    fn new() -> Self {
+        Self {
+            _phantom: std::marker::PhantomData,
+        }
+    }
+}
+
+impl<S> TensorFormatter for IntFormatter<S>
+where
+    S: WithDType + std::fmt::Display,
+{
+    type Elem = S;
+
+    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
+        write!(f, "{v:max_w$}")
+    }
+}
+
+fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
+    let dims = t.dims();
+    if dims.is_empty() {
+        Ok(t.clone())
+    } else if dims.len() == 1 {
+        if dims[0] > 2 * edge_items {
+            Tensor::cat(
+                &[
+                    t.narrow(0, 0, edge_items)?,
+                    t.narrow(0, dims[0] - edge_items, edge_items)?,
+                ],
+                0,
+            )
+        } else {
+            Ok(t.clone())
+        }
+    } else if dims[0] > 2 * edge_items {
+        let mut vs: Vec<_> = (0..edge_items)
+            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
+            .collect::<Result<Vec<_>>>()?;
+        for i in (dims[0] - edge_items)..dims[0] {
+            vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
+        }
+        Tensor::cat(&vs, 0)
+    } else {
+        let vs: Vec<_> = (0..dims[0])
+            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
+            .collect::<Result<Vec<_>>>()?;
+        Tensor::cat(&vs, 0)
+    }
+}
+
+impl std::fmt::Display for Tensor {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let po = PRINT_OPTS.lock().unwrap();
+        let summarize = self.elem_count() > po.threshold;
+        let to_display = if summarize {
+            match get_summarized_data(self, po.edge_items) {
+                Ok(v) => v,
+                Err(err) => return write!(f, "{err:?}"),
+            }
+        } else {
+            self.clone()
+        };
+        match self.dtype() {
+            DType::U32 => {
+                let tf: IntFormatter<u32> = IntFormatter::new();
+                let max_w = tf.max_width(&to_display);
+                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                writeln!(f)?;
+            }
+            DType::BF16 => {
+                if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
+                    let max_w = tf.max_width(&to_display);
+                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                    writeln!(f)?;
+                }
+            }
+            DType::F16 => {
+                if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
+                    let max_w = tf.max_width(&to_display);
+                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                    writeln!(f)?;
+                }
+            }
+            DType::F64 => {
+                if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
+                    let max_w = tf.max_width(&to_display);
+                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                    writeln!(f)?;
+                }
+            }
+            DType::F32 => {
+                if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
+                    let max_w = tf.max_width(&to_display);
+                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                    writeln!(f)?;
+                }
+            }
+        };
+        write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
+    }
+}
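The `Display` path above routes every dtype through a `TensorFormatter` (a float formatter that picks integer, fixed or scientific notation, and an integer formatter), padding each element to the width of the widest one, while the new `Debug` impl stays a compact one-liner. A minimal usage sketch, not part of the patch, assuming the crate is consumed under the `candle` name as in the tests further down; the outputs in the comments are indicative only:

```rust
use candle::{DType, Device::Cpu, Tensor};

fn main() -> anyhow::Result<()> {
    let t = (Tensor::ones((2, 3), DType::F32, &Cpu)? * 3.14159)?;
    // Debug keeps the terse form from fmt_dt, e.g. `Tensor[dims 2, 3; f32]`.
    println!("{t:?}");
    // Display goes through the global PRINT_OPTS (precision 4 by default),
    // so each element should come out roughly as `3.1416`.
    println!("{t}");
    // The "short" preset lowers the precision to 2 and keeps only
    // 2 edge items once a tensor is summarized.
    candle::display::set_print_options_short();
    println!("{t}");
    Ok(())
}
```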
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index b220dfb9..5771517f 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -3,6 +3,7 @@ mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;
 mod device;
+pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
 mod error;
@@ -12,7 +13,7 @@ mod shape;
 mod storage;
 mod strided_index;
 mod tensor;
-mod utils;
+pub mod utils;
 pub use cpu_backend::CpuStorage;
 pub use device::{Device, DeviceLocation};
unary_op { ($fn_name:ident, $op_name:ident) => { pub fn $fn_name(&self) -> Result { @@ -658,18 +652,22 @@ impl Tensor { } pub fn flatten(&self, start_dim: Option, end_dim: Option) -> Result { - let start_dim = start_dim.unwrap_or(0); - let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1); - if start_dim < end_dim { - let dims = self.dims(); - let mut dst_dims = dims[..start_dim].to_vec(); - dst_dims.push(dims[start_dim..end_dim + 1].iter().product::()); - if end_dim + 1 < dims.len() { - dst_dims.extend(&dims[end_dim + 1..]); - } - self.reshape(dst_dims) + if self.rank() == 0 { + self.reshape(1) } else { - Ok(self.clone()) + let start_dim = start_dim.unwrap_or(0); + let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1); + if start_dim < end_dim { + let dims = self.dims(); + let mut dst_dims = dims[..start_dim].to_vec(); + dst_dims.push(dims[start_dim..end_dim + 1].iter().product::()); + if end_dim + 1 < dims.len() { + dst_dims.extend(&dims[end_dim + 1..]); + } + self.reshape(dst_dims) + } else { + Ok(self.clone()) + } } } @@ -930,6 +928,36 @@ impl Tensor { } } + pub fn squeeze(&self, index: usize) -> Result { + // The PyTorch semantics are to return the same tensor if the target dimension + // does not have a size of 1. + let dims = self.dims(); + if dims[index] == 1 { + let mut dims = dims.to_vec(); + dims.remove(index); + self.reshape(dims) + } else { + Ok(self.clone()) + } + } + + pub fn unsqueeze(&self, index: usize) -> Result { + let mut dims = self.dims().to_vec(); + dims.insert(index, 1); + self.reshape(dims) + } + + pub fn stack>(args: &[A], dim: usize) -> Result { + if args.is_empty() { + return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }); + } + let args = args + .iter() + .map(|t| t.as_ref().unsqueeze(dim)) + .collect::>>()?; + Self::cat(&args, dim) + } + pub fn cat>(args: &[A], dim: usize) -> Result { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs index 0be63c66..4b1e941b 100644 --- a/candle-core/src/utils.rs +++ b/candle-core/src/utils.rs @@ -1,6 +1,6 @@ use std::str::FromStr; -pub(crate) fn get_num_threads() -> usize { +pub fn get_num_threads() -> usize { // Respond to the same environment variable as rayon. match std::env::var("RAYON_NUM_THREADS") .ok() diff --git a/candle-core/tests/display_tests.rs b/candle-core/tests/display_tests.rs new file mode 100644 index 00000000..eaa60180 --- /dev/null +++ b/candle-core/tests/display_tests.rs @@ -0,0 +1,84 @@ +use anyhow::Result; +use candle::{DType, Device::Cpu, Tensor}; + +#[test] +fn display_scalar() -> Result<()> { + let t = Tensor::new(1234u32, &Cpu)?; + let s = format!("{t}"); + assert_eq!(&s, "[1234]\nTensor[[], u32]"); + let t = t.to_dtype(DType::F32)?.neg()?; + let s = format!("{}", (&t / 10.0)?); + assert_eq!(&s, "[-123.4000]\nTensor[[], f32]"); + let s = format!("{}", (&t / 1e8)?); + assert_eq!(&s, "[-1.2340e-5]\nTensor[[], f32]"); + let s = format!("{}", (&t * 1e8)?); + assert_eq!(&s, "[-1.2340e11]\nTensor[[], f32]"); + let s = format!("{}", (&t * 0.)?); + assert_eq!(&s, "[0.]\nTensor[[], f32]"); + Ok(()) +} + +#[test] +fn display_vector() -> Result<()> { + let t = Tensor::new::<&[u32; 0]>(&[], &Cpu)?; + let s = format!("{t}"); + assert_eq!(&s, "[]\nTensor[[0], u32]"); + let t = Tensor::new(&[0.1234567, 1.0, -1.2, 4.1, f64::NAN], &Cpu)?; + let s = format!("{t}"); + assert_eq!( + &s, + "[ 0.1235, 1.0000, -1.2000, 4.1000, NaN]\nTensor[[5], f64]" + ); + let t = (Tensor::ones(50, DType::F32, &Cpu)? 
diff --git a/candle-core/tests/display_tests.rs b/candle-core/tests/display_tests.rs
new file mode 100644
index 00000000..eaa60180
--- /dev/null
+++ b/candle-core/tests/display_tests.rs
@@ -0,0 +1,84 @@
+use anyhow::Result;
+use candle::{DType, Device::Cpu, Tensor};
+
+#[test]
+fn display_scalar() -> Result<()> {
+    let t = Tensor::new(1234u32, &Cpu)?;
+    let s = format!("{t}");
+    assert_eq!(&s, "[1234]\nTensor[[], u32]");
+    let t = t.to_dtype(DType::F32)?.neg()?;
+    let s = format!("{}", (&t / 10.0)?);
+    assert_eq!(&s, "[-123.4000]\nTensor[[], f32]");
+    let s = format!("{}", (&t / 1e8)?);
+    assert_eq!(&s, "[-1.2340e-5]\nTensor[[], f32]");
+    let s = format!("{}", (&t * 1e8)?);
+    assert_eq!(&s, "[-1.2340e11]\nTensor[[], f32]");
+    let s = format!("{}", (&t * 0.)?);
+    assert_eq!(&s, "[0.]\nTensor[[], f32]");
+    Ok(())
+}
+
+#[test]
+fn display_vector() -> Result<()> {
+    let t = Tensor::new::<&[u32; 0]>(&[], &Cpu)?;
+    let s = format!("{t}");
+    assert_eq!(&s, "[]\nTensor[[0], u32]");
+    let t = Tensor::new(&[0.1234567, 1.0, -1.2, 4.1, f64::NAN], &Cpu)?;
+    let s = format!("{t}");
+    assert_eq!(
+        &s,
+        "[ 0.1235,  1.0000, -1.2000,  4.1000,     NaN]\nTensor[[5], f64]"
+    );
+    let t = (Tensor::ones(50, DType::F32, &Cpu)? * 42.)?;
+    let s = format!("\n{t}");
+    let expected = r#"
+[42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
+ 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
+ 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42.,
+ 42., 42.]
+Tensor[[50], f32]"#;
+    assert_eq!(&s, expected);
+    let t = (Tensor::ones(11000, DType::F32, &Cpu)? * 42.)?;
+    let s = format!("{t}");
+    assert_eq!(
+        &s,
+        "[42., 42., 42., ..., 42., 42., 42.]\nTensor[[11000], f32]"
+    );
+    Ok(())
+}
+
+#[test]
+fn display_multi_dim() -> Result<()> {
+    let t = (Tensor::ones((200, 100), DType::F32, &Cpu)? * 42.)?;
+    let s = format!("\n{t}");
+    let expected = r#"
+[[42., 42., 42., ..., 42., 42., 42.],
+ [42., 42., 42., ..., 42., 42., 42.],
+ [42., 42., 42., ..., 42., 42., 42.],
+ ...
+ [42., 42., 42., ..., 42., 42., 42.],
+ [42., 42., 42., ..., 42., 42., 42.],
+ [42., 42., 42., ..., 42., 42., 42.]]
+Tensor[[200, 100], f32]"#;
+    assert_eq!(&s, expected);
+    let t = t.reshape(&[2, 1, 1, 100, 100])?;
+    let t = format!("\n{t}");
+    let expected = r#"
+[[[[[42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    ...
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.]]]],
+ [[[[42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    ...
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.],
+    [42., 42., 42., ..., 42., 42., 42.]]]]]
+Tensor[[2, 1, 1, 100, 100], f32]"#;
+    assert_eq!(&t, expected);
+    Ok(())
+}
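The tests pin down the summarization behaviour: once `elem_count()` exceeds the configured threshold (1000 by default), only `edge_items` slices per dimension are printed around `...`. A hedged sketch of switching that off, not part of the patch and again assuming the `candle` crate name; `set_print_options_full` raises the threshold to `usize::MAX`:

```rust
use candle::{DType, Device::Cpu, Tensor};

fn main() -> anyhow::Result<()> {
    let t = (Tensor::ones(11000, DType::F32, &Cpu)? * 42.)?;
    // Above the default threshold this prints the summarized form
    // checked in display_vector: [42., 42., 42., ..., 42., 42., 42.]
    println!("{t}");

    // PRINT_OPTS is a global mutex, so this affects every subsequent
    // Display call in the process.
    candle::display::set_print_options_full();
    // With the threshold at usize::MAX all 11000 values are written out,
    // wrapped at roughly line_width (80) characters per line.
    println!("{t}");
    Ok(())
}
```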