diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs new file mode 100644 index 00000000..e9053a1d --- /dev/null +++ b/candle-core/src/display.rs @@ -0,0 +1,448 @@ +/// Pretty printing of tensors +/// This implementation should be in line with the PyTorch version. +/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py +use crate::{DType, Tensor, WithDType}; +use half::{bf16, f16}; + +impl Tensor { + fn fmt_dt( + &self, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "Tensor[")?; + match self.dims() { + [] => { + if let Ok(v) = self.to_scalar::() { + write!(f, "{v}")? + } + } + [s] if *s < 10 => { + if let Ok(vs) = self.to_vec1::() { + for (i, v) in vs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{v}")?; + } + } + } + dims => { + write!(f, "dims ")?; + for (i, d) in dims.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{d}")?; + } + } + } + write!(f, "; {}]", self.dtype().as_str()) + } +} + +impl std::fmt::Debug for Tensor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self.dtype() { + DType::U32 => self.fmt_dt::(f), + DType::BF16 => self.fmt_dt::(f), + DType::F16 => self.fmt_dt::(f), + DType::F32 => self.fmt_dt::(f), + DType::F64 => self.fmt_dt::(f), + } + } +} + +/* +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum BasicKind { + Float, + Int, + Bool, + Complex, +} + +impl BasicKind { + fn for_tensor(t: &Tensor) -> BasicKind { + match t.dtype() { + DType::U32 => BasicKind::Int, + DType::BF16 | DType::F16 | DType::F32 | DType::F64 => BasicKind::Float, + } + } +} + + +/// Options for Tensor pretty printing +pub struct PrinterOptions { + precision: usize, + threshold: usize, + edge_items: usize, + line_width: usize, + sci_mode: Option, +} + +lazy_static! { + static ref PRINT_OPTS: std::sync::Mutex = + std::sync::Mutex::new(Default::default()); +} + +pub fn set_print_options(options: PrinterOptions) { + *PRINT_OPTS.lock().unwrap() = options +} + +pub fn set_print_options_default() { + *PRINT_OPTS.lock().unwrap() = Default::default() +} + +pub fn set_print_options_short() { + *PRINT_OPTS.lock().unwrap() = PrinterOptions { + precision: 2, + threshold: 1000, + edge_items: 2, + line_width: 80, + sci_mode: None, + } +} + +pub fn set_print_options_full() { + *PRINT_OPTS.lock().unwrap() = PrinterOptions { + precision: 4, + threshold: usize::MAX, + edge_items: 3, + line_width: 80, + sci_mode: None, + } +} + +impl Default for PrinterOptions { + fn default() -> Self { + Self { + precision: 4, + threshold: 1000, + edge_items: 3, + line_width: 80, + sci_mode: None, + } + } +} + +trait TensorFormatter { + type Elem; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result; + + fn value(tensor: &Tensor) -> Self::Elem; + + fn values(tensor: &Tensor) -> Vec; + + fn max_width(&self, to_display: &Tensor) -> usize { + let mut max_width = 1; + for v in Self::values(to_display) { + let mut fmt_size = FmtSize::new(); + let _res = self.fmt(v, 1, &mut fmt_size); + max_width = usize::max(max_width, fmt_size.final_size()) + } + max_width + } + + fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f)?; + for _ in 0..i { + write!(f, " ")? + } + Ok(()) + } + + fn fmt_tensor( + &self, + t: &Tensor, + indent: usize, + max_w: usize, + summarize: bool, + po: &PrinterOptions, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + let size = t.size(); + let edge_items = po.edge_items as i64; + write!(f, "[")?; + match size.as_slice() { + [] => self.fmt(Self::value(t), max_w, f)?, + [v] if summarize && *v > 2 * edge_items => { + for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() { + self.fmt(v, max_w, f)?; + write!(f, ", ")?; + } + write!(f, "...")?; + for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() { + write!(f, ", ")?; + self.fmt(v, max_w, f)? + } + } + [_] => { + let elements_per_line = usize::max(1, po.line_width / (max_w + 2)); + for (i, v) in Self::values(t).into_iter().enumerate() { + if i > 0 { + if i % elements_per_line == 0 { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } else { + write!(f, ", ")?; + } + } + self.fmt(v, max_w, f)? + } + } + _ => { + if summarize && size[0] > 2 * edge_items { + for i in 0..edge_items { + self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + write!(f, "...")?; + Self::write_newline_indent(indent, f)?; + for i in size[0] - edge_items..size[0] { + self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; + if i + 1 != size[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } else { + for i in 0..size[0] { + self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; + if i + 1 != size[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } + } + } + write!(f, "]")?; + Ok(()) + } +} + +struct FloatFormatter { + int_mode: bool, + sci_mode: bool, + precision: usize, +} + +struct FmtSize { + current_size: usize, +} + +impl FmtSize { + fn new() -> Self { + Self { current_size: 0 } + } + + fn final_size(self) -> usize { + self.current_size + } +} + +impl std::fmt::Write for FmtSize { + fn write_str(&mut self, s: &str) -> std::fmt::Result { + self.current_size += s.len(); + Ok(()) + } +} + +impl FloatFormatter { + fn new(t: &Tensor, po: &PrinterOptions) -> Self { + let mut int_mode = true; + let mut sci_mode = false; + + let _guard = crate::no_grad_guard(); + let t = t.to_device(crate::Device::Cpu); + + // Rather than containing all values, this should only include + // values that end up being displayed according to [threshold]. + let nonzero_finite_vals = { + let t = t.reshape([-1]); + t.masked_select(&t.isfinite().logical_and(&t.ne(0.))) + }; + + let values = Vec::::try_from(&nonzero_finite_vals).unwrap(); + if nonzero_finite_vals.numel() > 0 { + let nonzero_finite_abs = nonzero_finite_vals.abs(); + let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]); + let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]); + + for &value in values.iter() { + if value.ceil() != value { + int_mode = false; + break; + } + } + + sci_mode = nonzero_finite_max / nonzero_finite_min > 1000. + || nonzero_finite_max > 1e8 + || nonzero_finite_min < 1e-4 + } + + match po.sci_mode { + None => {} + Some(v) => sci_mode = v, + } + Self { + int_mode, + sci_mode, + precision: po.precision, + } + } +} + +impl TensorFormatter for FloatFormatter { + type Elem = f64; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { + if self.sci_mode { + write!( + f, + "{v:width$.prec$e}", + v = v, + width = max_w, + prec = self.precision + ) + } else if self.int_mode { + if v.is_finite() { + write!(f, "{v:width$.0}.", v = v, width = max_w - 1) + } else { + write!(f, "{v:max_w$.0}") + } + } else { + write!( + f, + "{v:width$.prec$}", + v = v, + width = max_w, + prec = self.precision + ) + } + } + + fn value(tensor: &Tensor) -> Self::Elem { + tensor.double_value(&[]) + } + + fn values(tensor: &Tensor) -> Vec { + Vec::::try_from(tensor.reshape(-1)).unwrap() + } +} + +struct IntFormatter; + +impl TensorFormatter for IntFormatter { + type Elem = i64; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { + write!(f, "{v:max_w$}") + } + + fn value(tensor: &Tensor) -> Self::Elem { + tensor.int64_value(&[]) + } + + fn values(tensor: &Tensor) -> Vec { + Vec::::try_from(tensor.reshape(-1)).unwrap() + } +} + +struct BoolFormatter; + +impl TensorFormatter for BoolFormatter { + type Elem = bool; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { + let v = if v { "true" } else { "false" }; + write!(f, "{v:max_w$}") + } + + fn value(tensor: &Tensor) -> Self::Elem { + tensor.int64_value(&[]) != 0 + } + + fn values(tensor: &Tensor) -> Vec { + Vec::::try_from(tensor.reshape(-1)).unwrap() + } +} + +fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor { + let size = t.size(); + if size.is_empty() { + t.shallow_clone() + } else if size.len() == 1 { + if size[0] > 2 * edge_items { + Tensor::cat( + &[ + t.slice(0, None, Some(edge_items), 1), + t.slice(0, Some(-edge_items), None, 1), + ], + 0, + ) + } else { + t.shallow_clone() + } + } else if size[0] > 2 * edge_items { + let mut vs: Vec<_> = (0..edge_items) + .map(|i| get_summarized_data(&t.get(i), edge_items)) + .collect(); + for i in (size[0] - edge_items)..size[0] { + vs.push(get_summarized_data(&t.get(i), edge_items)) + } + Tensor::stack(&vs, 0) + } else { + let vs: Vec<_> = (0..size[0]) + .map(|i| get_summarized_data(&t.get(i), edge_items)) + .collect(); + Tensor::stack(&vs, 0) + } +} + +impl std::fmt::Display for Tensor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.defined() { + let po = PRINT_OPTS.lock().unwrap(); + let summarize = self.numel() > po.threshold; + let basic_kind = BasicKind::for_tensor(self); + let to_display = if summarize { + get_summarized_data(self, po.edge_items as i64) + } else { + self.shallow_clone() + }; + match basic_kind { + BasicKind::Int => { + let tf = IntFormatter; + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + BasicKind::Float => { + let tf = FloatFormatter::new(&to_display, &po); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + BasicKind::Bool => { + let tf = BoolFormatter; + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + BasicKind::Complex => {} + }; + let kind = match self.f_kind() { + Ok(kind) => format!("{kind:?}"), + Err(err) => format!("{err:?}"), + }; + write!(f, "Tensor[{:?}, {}]", self.size(), kind) + } else { + write!(f, "Tensor[Undefined]") + } + } +} +*/ diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index b220dfb9..2084d7ca 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -3,6 +3,7 @@ mod cpu_backend; #[cfg(feature = "cuda")] mod cuda_backend; mod device; +mod display; mod dtype; mod dummy_cuda_backend; mod error; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index cb968487..fc67ae94 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -44,12 +44,6 @@ impl std::ops::Deref for Tensor { } } -impl std::fmt::Debug for Tensor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "[{:?}, {:?}]", &self.shape().dims(), self.device()) - } -} - macro_rules! unary_op { ($fn_name:ident, $op_name:ident) => { pub fn $fn_name(&self) -> Result {