From 8c81a7017092d405658eea2ab12f82c3c80d6b28 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 21:16:35 +0100 Subject: [PATCH] PyTorch like display implementation. --- candle-core/src/display.rs | 517 +++++++++++++++++++------------------ candle-core/src/lib.rs | 4 +- candle-core/src/utils.rs | 2 +- 3 files changed, 265 insertions(+), 258 deletions(-) diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index e9053a1d..8d1f16f6 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -1,7 +1,7 @@ /// Pretty printing of tensors /// This implementation should be in line with the PyTorch version. /// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py -use crate::{DType, Tensor, WithDType}; +use crate::{DType, Result, Tensor, WithDType}; use half::{bf16, f16}; impl Tensor { @@ -52,26 +52,7 @@ impl std::fmt::Debug for Tensor { } } -/* #[allow(dead_code)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum BasicKind { - Float, - Int, - Bool, - Complex, -} - -impl BasicKind { - fn for_tensor(t: &Tensor) -> BasicKind { - match t.dtype() { - DType::U32 => BasicKind::Int, - DType::BF16 | DType::F16 | DType::F32 | DType::F64 => BasicKind::Float, - } - } -} - - /// Options for Tensor pretty printing pub struct PrinterOptions { precision: usize, @@ -81,9 +62,20 @@ pub struct PrinterOptions { sci_mode: Option, } -lazy_static! { - static ref PRINT_OPTS: std::sync::Mutex = - std::sync::Mutex::new(Default::default()); +static PRINT_OPTS: std::sync::Mutex = + std::sync::Mutex::new(PrinterOptions::const_default()); + +impl PrinterOptions { + // We cannot use the default trait as it's not const. + const fn const_default() -> Self { + Self { + precision: 4, + threshold: 1000, + edge_items: 3, + line_width: 80, + sci_mode: None, + } + } } pub fn set_print_options(options: PrinterOptions) { @@ -91,7 +83,7 @@ pub fn set_print_options(options: PrinterOptions) { } pub fn set_print_options_default() { - *PRINT_OPTS.lock().unwrap() = Default::default() + *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default() } pub fn set_print_options_short() { @@ -114,122 +106,6 @@ pub fn set_print_options_full() { } } -impl Default for PrinterOptions { - fn default() -> Self { - Self { - precision: 4, - threshold: 1000, - edge_items: 3, - line_width: 80, - sci_mode: None, - } - } -} - -trait TensorFormatter { - type Elem; - - fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result; - - fn value(tensor: &Tensor) -> Self::Elem; - - fn values(tensor: &Tensor) -> Vec; - - fn max_width(&self, to_display: &Tensor) -> usize { - let mut max_width = 1; - for v in Self::values(to_display) { - let mut fmt_size = FmtSize::new(); - let _res = self.fmt(v, 1, &mut fmt_size); - max_width = usize::max(max_width, fmt_size.final_size()) - } - max_width - } - - fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result { - writeln!(f)?; - for _ in 0..i { - write!(f, " ")? - } - Ok(()) - } - - fn fmt_tensor( - &self, - t: &Tensor, - indent: usize, - max_w: usize, - summarize: bool, - po: &PrinterOptions, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - let size = t.size(); - let edge_items = po.edge_items as i64; - write!(f, "[")?; - match size.as_slice() { - [] => self.fmt(Self::value(t), max_w, f)?, - [v] if summarize && *v > 2 * edge_items => { - for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() { - self.fmt(v, max_w, f)?; - write!(f, ", ")?; - } - write!(f, "...")?; - for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() { - write!(f, ", ")?; - self.fmt(v, max_w, f)? - } - } - [_] => { - let elements_per_line = usize::max(1, po.line_width / (max_w + 2)); - for (i, v) in Self::values(t).into_iter().enumerate() { - if i > 0 { - if i % elements_per_line == 0 { - write!(f, ",")?; - Self::write_newline_indent(indent, f)? - } else { - write!(f, ", ")?; - } - } - self.fmt(v, max_w, f)? - } - } - _ => { - if summarize && size[0] > 2 * edge_items { - for i in 0..edge_items { - self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; - write!(f, ",")?; - Self::write_newline_indent(indent, f)? - } - write!(f, "...")?; - Self::write_newline_indent(indent, f)?; - for i in size[0] - edge_items..size[0] { - self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; - if i + 1 != size[0] { - write!(f, ",")?; - Self::write_newline_indent(indent, f)? - } - } - } else { - for i in 0..size[0] { - self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; - if i + 1 != size[0] { - write!(f, ",")?; - Self::write_newline_indent(indent, f)? - } - } - } - } - } - write!(f, "]")?; - Ok(()) - } -} - -struct FloatFormatter { - int_mode: bool, - sci_mode: bool, - precision: usize, -} - struct FmtSize { current_size: usize, } @@ -251,26 +127,161 @@ impl std::fmt::Write for FmtSize { } } -impl FloatFormatter { - fn new(t: &Tensor, po: &PrinterOptions) -> Self { +trait TensorFormatter { + type Elem: WithDType; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result; + + fn max_width(&self, to_display: &Tensor) -> usize { + let mut max_width = 1; + if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) { + for &v in vs.iter() { + let mut fmt_size = FmtSize::new(); + let _res = self.fmt(v, 1, &mut fmt_size); + max_width = usize::max(max_width, fmt_size.final_size()) + } + } + max_width + } + + fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f)?; + for _ in 0..i { + write!(f, " ")? + } + Ok(()) + } + + fn fmt_tensor( + &self, + t: &Tensor, + indent: usize, + max_w: usize, + summarize: bool, + po: &PrinterOptions, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + let dims = t.dims(); + let edge_items = po.edge_items; + write!(f, "[")?; + match dims { + [] => { + if let Ok(v) = t.to_scalar::() { + self.fmt(v, max_w, f)? + } + } + [v] if summarize && *v > 2 * edge_items => { + if let Ok(vs) = t + .narrow(0, 0, edge_items) + .and_then(|t| t.to_vec1::()) + { + for v in vs.into_iter() { + self.fmt(v, max_w, f)?; + write!(f, ", ")?; + } + } + write!(f, "...")?; + if let Ok(vs) = t + .narrow(0, v - edge_items, edge_items) + .and_then(|t| t.to_vec1::()) + { + for v in vs.into_iter() { + self.fmt(v, max_w, f)?; + write!(f, ", ")?; + } + } + } + [_] => { + let elements_per_line = usize::max(1, po.line_width / (max_w + 2)); + if let Ok(vs) = t.to_vec1::() { + for (i, v) in vs.into_iter().enumerate() { + if i > 0 { + if i % elements_per_line == 0 { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } else { + write!(f, ", ")?; + } + } + self.fmt(v, max_w, f)? + } + } + } + _ => { + if summarize && dims[0] > 2 * edge_items { + for i in 0..edge_items { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + write!(f, "...")?; + Self::write_newline_indent(indent, f)?; + for i in dims[0] - edge_items..dims[0] { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + if i + 1 != dims[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } else { + for i in 0..dims[0] { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + if i + 1 != dims[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } + } + } + write!(f, "]")?; + Ok(()) + } +} + +struct FloatFormatter { + int_mode: bool, + sci_mode: bool, + precision: usize, + _phantom: std::marker::PhantomData, +} + +impl FloatFormatter +where + S: WithDType + num_traits::Float, +{ + fn new(t: &Tensor, po: &PrinterOptions) -> Result { let mut int_mode = true; let mut sci_mode = false; - let _guard = crate::no_grad_guard(); - let t = t.to_device(crate::Device::Cpu); - // Rather than containing all values, this should only include // values that end up being displayed according to [threshold]. - let nonzero_finite_vals = { - let t = t.reshape([-1]); - t.masked_select(&t.isfinite().logical_and(&t.ne(0.))) - }; - - let values = Vec::::try_from(&nonzero_finite_vals).unwrap(); - if nonzero_finite_vals.numel() > 0 { - let nonzero_finite_abs = nonzero_finite_vals.abs(); - let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]); - let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]); + let values = t + .flatten_all()? + .to_vec1()? + .into_iter() + .filter(|v: &S| v.is_finite() && !v.is_zero()) + .collect::>(); + if !values.is_empty() { + let mut nonzero_finite_min = S::max_value(); + let mut nonzero_finite_max = S::min_value(); + for &v in values.iter() { + if v < nonzero_finite_min { + nonzero_finite_min = v + } + if v > nonzero_finite_max { + nonzero_finite_max = v + } + } for &value in values.iter() { if value.ceil() != value { @@ -279,25 +290,35 @@ impl FloatFormatter { } } - sci_mode = nonzero_finite_max / nonzero_finite_min > 1000. - || nonzero_finite_max > 1e8 - || nonzero_finite_min < 1e-4 + if let Some(v1) = S::from(1000.) { + if let Some(v2) = S::from(1e8) { + if let Some(v3) = S::from(1e-4) { + sci_mode = nonzero_finite_max / nonzero_finite_min > v1 + || nonzero_finite_max > v2 + || nonzero_finite_min < v3 + } + } + } } match po.sci_mode { None => {} Some(v) => sci_mode = v, } - Self { + Ok(Self { int_mode, sci_mode, precision: po.precision, - } + _phantom: std::marker::PhantomData, + }) } } -impl TensorFormatter for FloatFormatter { - type Elem = f64; +impl TensorFormatter for FloatFormatter +where + S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp, +{ + type Elem = S; fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { if self.sci_mode { @@ -324,125 +345,111 @@ impl TensorFormatter for FloatFormatter { ) } } +} - fn value(tensor: &Tensor) -> Self::Elem { - tensor.double_value(&[]) - } +struct IntFormatter { + _phantom: std::marker::PhantomData, +} - fn values(tensor: &Tensor) -> Vec { - Vec::::try_from(tensor.reshape(-1)).unwrap() +impl IntFormatter { + fn new() -> Self { + Self { + _phantom: std::marker::PhantomData, + } } } -struct IntFormatter; - -impl TensorFormatter for IntFormatter { - type Elem = i64; +impl TensorFormatter for IntFormatter +where + S: WithDType + std::fmt::Display, +{ + type Elem = S; fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { write!(f, "{v:max_w$}") } - - fn value(tensor: &Tensor) -> Self::Elem { - tensor.int64_value(&[]) - } - - fn values(tensor: &Tensor) -> Vec { - Vec::::try_from(tensor.reshape(-1)).unwrap() - } } -struct BoolFormatter; - -impl TensorFormatter for BoolFormatter { - type Elem = bool; - - fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { - let v = if v { "true" } else { "false" }; - write!(f, "{v:max_w$}") - } - - fn value(tensor: &Tensor) -> Self::Elem { - tensor.int64_value(&[]) != 0 - } - - fn values(tensor: &Tensor) -> Vec { - Vec::::try_from(tensor.reshape(-1)).unwrap() - } -} - -fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor { - let size = t.size(); - if size.is_empty() { - t.shallow_clone() - } else if size.len() == 1 { - if size[0] > 2 * edge_items { +fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result { + let dims = t.dims(); + if dims.is_empty() { + Ok(t.clone()) + } else if dims.len() == 1 { + if dims[0] > 2 * edge_items { Tensor::cat( &[ - t.slice(0, None, Some(edge_items), 1), - t.slice(0, Some(-edge_items), None, 1), + t.narrow(0, 0, edge_items)?, + t.narrow(0, dims[0] - edge_items, edge_items)?, ], 0, ) } else { - t.shallow_clone() + Ok(t.clone()) } - } else if size[0] > 2 * edge_items { + } else if dims[0] > 2 * edge_items { let mut vs: Vec<_> = (0..edge_items) - .map(|i| get_summarized_data(&t.get(i), edge_items)) - .collect(); - for i in (size[0] - edge_items)..size[0] { - vs.push(get_summarized_data(&t.get(i), edge_items)) + .map(|i| get_summarized_data(&t.get(i)?, edge_items)) + .collect::>>()?; + for i in (dims[0] - edge_items)..dims[0] { + vs.push(get_summarized_data(&t.get(i)?, edge_items)?) } - Tensor::stack(&vs, 0) + Tensor::cat(&vs, 0) } else { - let vs: Vec<_> = (0..size[0]) - .map(|i| get_summarized_data(&t.get(i), edge_items)) - .collect(); - Tensor::stack(&vs, 0) + let vs: Vec<_> = (0..dims[0]) + .map(|i| get_summarized_data(&t.get(i)?, edge_items)) + .collect::>>()?; + Tensor::cat(&vs, 0) } } impl std::fmt::Display for Tensor { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - if self.defined() { - let po = PRINT_OPTS.lock().unwrap(); - let summarize = self.numel() > po.threshold; - let basic_kind = BasicKind::for_tensor(self); - let to_display = if summarize { - get_summarized_data(self, po.edge_items as i64) - } else { - self.shallow_clone() - }; - match basic_kind { - BasicKind::Int => { - let tf = IntFormatter; - let max_w = tf.max_width(&to_display); - tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; - writeln!(f)?; - } - BasicKind::Float => { - let tf = FloatFormatter::new(&to_display, &po); - let max_w = tf.max_width(&to_display); - tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; - writeln!(f)?; - } - BasicKind::Bool => { - let tf = BoolFormatter; - let max_w = tf.max_width(&to_display); - tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; - writeln!(f)?; - } - BasicKind::Complex => {} - }; - let kind = match self.f_kind() { - Ok(kind) => format!("{kind:?}"), - Err(err) => format!("{err:?}"), - }; - write!(f, "Tensor[{:?}, {}]", self.size(), kind) + let po = PRINT_OPTS.lock().unwrap(); + let summarize = self.elem_count() > po.threshold; + let to_display = if summarize { + match get_summarized_data(self, po.edge_items) { + Ok(v) => v, + Err(err) => return write!(f, "{err:?}"), + } } else { - write!(f, "Tensor[Undefined]") - } + self.clone() + }; + match self.dtype() { + DType::U32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::BF16 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F16 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F64 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F32 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + }; + write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str()) } } -*/ diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 2084d7ca..5771517f 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -3,7 +3,7 @@ mod cpu_backend; #[cfg(feature = "cuda")] mod cuda_backend; mod device; -mod display; +pub mod display; mod dtype; mod dummy_cuda_backend; mod error; @@ -13,7 +13,7 @@ mod shape; mod storage; mod strided_index; mod tensor; -mod utils; +pub mod utils; pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs index 0be63c66..4b1e941b 100644 --- a/candle-core/src/utils.rs +++ b/candle-core/src/utils.rs @@ -1,6 +1,6 @@ use std::str::FromStr; -pub(crate) fn get_num_threads() -> usize { +pub fn get_num_threads() -> usize { // Respond to the same environment variable as rayon. match std::env::var("RAYON_NUM_THREADS") .ok()