From 1d504cc6b330e47158888921802ea13828ea593f Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 19:10:30 +0100 Subject: [PATCH 1/4] Rework the debug trait. --- candle-core/src/display.rs | 448 +++++++++++++++++++++++++++++++++++++ candle-core/src/lib.rs | 1 + candle-core/src/tensor.rs | 6 - 3 files changed, 449 insertions(+), 6 deletions(-) create mode 100644 candle-core/src/display.rs diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs new file mode 100644 index 00000000..e9053a1d --- /dev/null +++ b/candle-core/src/display.rs @@ -0,0 +1,448 @@ +/// Pretty printing of tensors +/// This implementation should be in line with the PyTorch version. +/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py +use crate::{DType, Tensor, WithDType}; +use half::{bf16, f16}; + +impl Tensor { + fn fmt_dt( + &self, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "Tensor[")?; + match self.dims() { + [] => { + if let Ok(v) = self.to_scalar::() { + write!(f, "{v}")? + } + } + [s] if *s < 10 => { + if let Ok(vs) = self.to_vec1::() { + for (i, v) in vs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{v}")?; + } + } + } + dims => { + write!(f, "dims ")?; + for (i, d) in dims.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{d}")?; + } + } + } + write!(f, "; {}]", self.dtype().as_str()) + } +} + +impl std::fmt::Debug for Tensor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self.dtype() { + DType::U32 => self.fmt_dt::(f), + DType::BF16 => self.fmt_dt::(f), + DType::F16 => self.fmt_dt::(f), + DType::F32 => self.fmt_dt::(f), + DType::F64 => self.fmt_dt::(f), + } + } +} + +/* +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum BasicKind { + Float, + Int, + Bool, + Complex, +} + +impl BasicKind { + fn for_tensor(t: &Tensor) -> BasicKind { + match t.dtype() { + DType::U32 => BasicKind::Int, + DType::BF16 | DType::F16 | DType::F32 | DType::F64 => BasicKind::Float, + } + } +} + + +/// Options for Tensor pretty printing +pub struct PrinterOptions { + precision: usize, + threshold: usize, + edge_items: usize, + line_width: usize, + sci_mode: Option, +} + +lazy_static! 
{ + static ref PRINT_OPTS: std::sync::Mutex = + std::sync::Mutex::new(Default::default()); +} + +pub fn set_print_options(options: PrinterOptions) { + *PRINT_OPTS.lock().unwrap() = options +} + +pub fn set_print_options_default() { + *PRINT_OPTS.lock().unwrap() = Default::default() +} + +pub fn set_print_options_short() { + *PRINT_OPTS.lock().unwrap() = PrinterOptions { + precision: 2, + threshold: 1000, + edge_items: 2, + line_width: 80, + sci_mode: None, + } +} + +pub fn set_print_options_full() { + *PRINT_OPTS.lock().unwrap() = PrinterOptions { + precision: 4, + threshold: usize::MAX, + edge_items: 3, + line_width: 80, + sci_mode: None, + } +} + +impl Default for PrinterOptions { + fn default() -> Self { + Self { + precision: 4, + threshold: 1000, + edge_items: 3, + line_width: 80, + sci_mode: None, + } + } +} + +trait TensorFormatter { + type Elem; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result; + + fn value(tensor: &Tensor) -> Self::Elem; + + fn values(tensor: &Tensor) -> Vec; + + fn max_width(&self, to_display: &Tensor) -> usize { + let mut max_width = 1; + for v in Self::values(to_display) { + let mut fmt_size = FmtSize::new(); + let _res = self.fmt(v, 1, &mut fmt_size); + max_width = usize::max(max_width, fmt_size.final_size()) + } + max_width + } + + fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f)?; + for _ in 0..i { + write!(f, " ")? + } + Ok(()) + } + + fn fmt_tensor( + &self, + t: &Tensor, + indent: usize, + max_w: usize, + summarize: bool, + po: &PrinterOptions, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + let size = t.size(); + let edge_items = po.edge_items as i64; + write!(f, "[")?; + match size.as_slice() { + [] => self.fmt(Self::value(t), max_w, f)?, + [v] if summarize && *v > 2 * edge_items => { + for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() { + self.fmt(v, max_w, f)?; + write!(f, ", ")?; + } + write!(f, "...")?; + for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() { + write!(f, ", ")?; + self.fmt(v, max_w, f)? + } + } + [_] => { + let elements_per_line = usize::max(1, po.line_width / (max_w + 2)); + for (i, v) in Self::values(t).into_iter().enumerate() { + if i > 0 { + if i % elements_per_line == 0 { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } else { + write!(f, ", ")?; + } + } + self.fmt(v, max_w, f)? + } + } + _ => { + if summarize && size[0] > 2 * edge_items { + for i in 0..edge_items { + self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + write!(f, "...")?; + Self::write_newline_indent(indent, f)?; + for i in size[0] - edge_items..size[0] { + self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; + if i + 1 != size[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } else { + for i in 0..size[0] { + self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; + if i + 1 != size[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? 
+ } + } + } + } + } + write!(f, "]")?; + Ok(()) + } +} + +struct FloatFormatter { + int_mode: bool, + sci_mode: bool, + precision: usize, +} + +struct FmtSize { + current_size: usize, +} + +impl FmtSize { + fn new() -> Self { + Self { current_size: 0 } + } + + fn final_size(self) -> usize { + self.current_size + } +} + +impl std::fmt::Write for FmtSize { + fn write_str(&mut self, s: &str) -> std::fmt::Result { + self.current_size += s.len(); + Ok(()) + } +} + +impl FloatFormatter { + fn new(t: &Tensor, po: &PrinterOptions) -> Self { + let mut int_mode = true; + let mut sci_mode = false; + + let _guard = crate::no_grad_guard(); + let t = t.to_device(crate::Device::Cpu); + + // Rather than containing all values, this should only include + // values that end up being displayed according to [threshold]. + let nonzero_finite_vals = { + let t = t.reshape([-1]); + t.masked_select(&t.isfinite().logical_and(&t.ne(0.))) + }; + + let values = Vec::::try_from(&nonzero_finite_vals).unwrap(); + if nonzero_finite_vals.numel() > 0 { + let nonzero_finite_abs = nonzero_finite_vals.abs(); + let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]); + let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]); + + for &value in values.iter() { + if value.ceil() != value { + int_mode = false; + break; + } + } + + sci_mode = nonzero_finite_max / nonzero_finite_min > 1000. + || nonzero_finite_max > 1e8 + || nonzero_finite_min < 1e-4 + } + + match po.sci_mode { + None => {} + Some(v) => sci_mode = v, + } + Self { + int_mode, + sci_mode, + precision: po.precision, + } + } +} + +impl TensorFormatter for FloatFormatter { + type Elem = f64; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { + if self.sci_mode { + write!( + f, + "{v:width$.prec$e}", + v = v, + width = max_w, + prec = self.precision + ) + } else if self.int_mode { + if v.is_finite() { + write!(f, "{v:width$.0}.", v = v, width = max_w - 1) + } else { + write!(f, "{v:max_w$.0}") + } + } else { + write!( + f, + "{v:width$.prec$}", + v = v, + width = max_w, + prec = self.precision + ) + } + } + + fn value(tensor: &Tensor) -> Self::Elem { + tensor.double_value(&[]) + } + + fn values(tensor: &Tensor) -> Vec { + Vec::::try_from(tensor.reshape(-1)).unwrap() + } +} + +struct IntFormatter; + +impl TensorFormatter for IntFormatter { + type Elem = i64; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { + write!(f, "{v:max_w$}") + } + + fn value(tensor: &Tensor) -> Self::Elem { + tensor.int64_value(&[]) + } + + fn values(tensor: &Tensor) -> Vec { + Vec::::try_from(tensor.reshape(-1)).unwrap() + } +} + +struct BoolFormatter; + +impl TensorFormatter for BoolFormatter { + type Elem = bool; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { + let v = if v { "true" } else { "false" }; + write!(f, "{v:max_w$}") + } + + fn value(tensor: &Tensor) -> Self::Elem { + tensor.int64_value(&[]) != 0 + } + + fn values(tensor: &Tensor) -> Vec { + Vec::::try_from(tensor.reshape(-1)).unwrap() + } +} + +fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor { + let size = t.size(); + if size.is_empty() { + t.shallow_clone() + } else if size.len() == 1 { + if size[0] > 2 * edge_items { + Tensor::cat( + &[ + t.slice(0, None, Some(edge_items), 1), + t.slice(0, Some(-edge_items), None, 1), + ], + 0, + ) + } else { + t.shallow_clone() + } + } else if size[0] > 2 * edge_items { + let mut vs: Vec<_> = (0..edge_items) + .map(|i| get_summarized_data(&t.get(i), edge_items)) + 
.collect(); + for i in (size[0] - edge_items)..size[0] { + vs.push(get_summarized_data(&t.get(i), edge_items)) + } + Tensor::stack(&vs, 0) + } else { + let vs: Vec<_> = (0..size[0]) + .map(|i| get_summarized_data(&t.get(i), edge_items)) + .collect(); + Tensor::stack(&vs, 0) + } +} + +impl std::fmt::Display for Tensor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.defined() { + let po = PRINT_OPTS.lock().unwrap(); + let summarize = self.numel() > po.threshold; + let basic_kind = BasicKind::for_tensor(self); + let to_display = if summarize { + get_summarized_data(self, po.edge_items as i64) + } else { + self.shallow_clone() + }; + match basic_kind { + BasicKind::Int => { + let tf = IntFormatter; + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + BasicKind::Float => { + let tf = FloatFormatter::new(&to_display, &po); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + BasicKind::Bool => { + let tf = BoolFormatter; + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + BasicKind::Complex => {} + }; + let kind = match self.f_kind() { + Ok(kind) => format!("{kind:?}"), + Err(err) => format!("{err:?}"), + }; + write!(f, "Tensor[{:?}, {}]", self.size(), kind) + } else { + write!(f, "Tensor[Undefined]") + } + } +} +*/ diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index b220dfb9..2084d7ca 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -3,6 +3,7 @@ mod cpu_backend; #[cfg(feature = "cuda")] mod cuda_backend; mod device; +mod display; mod dtype; mod dummy_cuda_backend; mod error; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index cb968487..fc67ae94 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -44,12 +44,6 @@ impl std::ops::Deref for Tensor { } } -impl std::fmt::Debug for Tensor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "[{:?}, {:?}]", &self.shape().dims(), self.device()) - } -} - macro_rules! unary_op { ($fn_name:ident, $op_name:ident) => { pub fn $fn_name(&self) -> Result { From 934655a60d9bbf773d5b02c6aff2a8c26edd9be8 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 19:32:00 +0100 Subject: [PATCH 2/4] Add squeeze/unsqueeze/stack. --- candle-core/src/tensor.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index fc67ae94..b64f63e1 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -924,6 +924,36 @@ impl Tensor { } } + pub fn squeeze(&self, index: usize) -> Result { + // The PyTorch semantics are to return the same tensor if the target dimension + // does not have a size of 1. 
+ let dims = self.dims(); + if dims[index] == 1 { + let mut dims = dims.to_vec(); + dims.remove(index); + self.reshape(dims) + } else { + Ok(self.clone()) + } + } + + pub fn unsqueeze(&self, index: usize) -> Result { + let mut dims = self.dims().to_vec(); + dims.insert(index, 1); + self.reshape(dims) + } + + pub fn stack>(args: &[A], dim: usize) -> Result { + if args.is_empty() { + return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }); + } + let args = args + .iter() + .map(|t| t.as_ref().unsqueeze(dim)) + .collect::>>()?; + Self::cat(&args, dim) + } + pub fn cat>(args: &[A], dim: usize) -> Result { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); From 8c81a7017092d405658eea2ab12f82c3c80d6b28 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 21:16:35 +0100 Subject: [PATCH 3/4] PyTorch like display implementation. --- candle-core/src/display.rs | 517 +++++++++++++++++++------------------ candle-core/src/lib.rs | 4 +- candle-core/src/utils.rs | 2 +- 3 files changed, 265 insertions(+), 258 deletions(-) diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index e9053a1d..8d1f16f6 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -1,7 +1,7 @@ /// Pretty printing of tensors /// This implementation should be in line with the PyTorch version. /// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py -use crate::{DType, Tensor, WithDType}; +use crate::{DType, Result, Tensor, WithDType}; use half::{bf16, f16}; impl Tensor { @@ -52,26 +52,7 @@ impl std::fmt::Debug for Tensor { } } -/* #[allow(dead_code)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum BasicKind { - Float, - Int, - Bool, - Complex, -} - -impl BasicKind { - fn for_tensor(t: &Tensor) -> BasicKind { - match t.dtype() { - DType::U32 => BasicKind::Int, - DType::BF16 | DType::F16 | DType::F32 | DType::F64 => BasicKind::Float, - } - } -} - - /// Options for Tensor pretty printing pub struct PrinterOptions { precision: usize, @@ -81,9 +62,20 @@ pub struct PrinterOptions { sci_mode: Option, } -lazy_static! { - static ref PRINT_OPTS: std::sync::Mutex = - std::sync::Mutex::new(Default::default()); +static PRINT_OPTS: std::sync::Mutex = + std::sync::Mutex::new(PrinterOptions::const_default()); + +impl PrinterOptions { + // We cannot use the default trait as it's not const. 
+ const fn const_default() -> Self { + Self { + precision: 4, + threshold: 1000, + edge_items: 3, + line_width: 80, + sci_mode: None, + } + } } pub fn set_print_options(options: PrinterOptions) { @@ -91,7 +83,7 @@ pub fn set_print_options(options: PrinterOptions) { } pub fn set_print_options_default() { - *PRINT_OPTS.lock().unwrap() = Default::default() + *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default() } pub fn set_print_options_short() { @@ -114,122 +106,6 @@ pub fn set_print_options_full() { } } -impl Default for PrinterOptions { - fn default() -> Self { - Self { - precision: 4, - threshold: 1000, - edge_items: 3, - line_width: 80, - sci_mode: None, - } - } -} - -trait TensorFormatter { - type Elem; - - fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result; - - fn value(tensor: &Tensor) -> Self::Elem; - - fn values(tensor: &Tensor) -> Vec; - - fn max_width(&self, to_display: &Tensor) -> usize { - let mut max_width = 1; - for v in Self::values(to_display) { - let mut fmt_size = FmtSize::new(); - let _res = self.fmt(v, 1, &mut fmt_size); - max_width = usize::max(max_width, fmt_size.final_size()) - } - max_width - } - - fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result { - writeln!(f)?; - for _ in 0..i { - write!(f, " ")? - } - Ok(()) - } - - fn fmt_tensor( - &self, - t: &Tensor, - indent: usize, - max_w: usize, - summarize: bool, - po: &PrinterOptions, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - let size = t.size(); - let edge_items = po.edge_items as i64; - write!(f, "[")?; - match size.as_slice() { - [] => self.fmt(Self::value(t), max_w, f)?, - [v] if summarize && *v > 2 * edge_items => { - for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() { - self.fmt(v, max_w, f)?; - write!(f, ", ")?; - } - write!(f, "...")?; - for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() { - write!(f, ", ")?; - self.fmt(v, max_w, f)? - } - } - [_] => { - let elements_per_line = usize::max(1, po.line_width / (max_w + 2)); - for (i, v) in Self::values(t).into_iter().enumerate() { - if i > 0 { - if i % elements_per_line == 0 { - write!(f, ",")?; - Self::write_newline_indent(indent, f)? - } else { - write!(f, ", ")?; - } - } - self.fmt(v, max_w, f)? - } - } - _ => { - if summarize && size[0] > 2 * edge_items { - for i in 0..edge_items { - self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; - write!(f, ",")?; - Self::write_newline_indent(indent, f)? - } - write!(f, "...")?; - Self::write_newline_indent(indent, f)?; - for i in size[0] - edge_items..size[0] { - self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; - if i + 1 != size[0] { - write!(f, ",")?; - Self::write_newline_indent(indent, f)? - } - } - } else { - for i in 0..size[0] { - self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?; - if i + 1 != size[0] { - write!(f, ",")?; - Self::write_newline_indent(indent, f)? 
- } - } - } - } - } - write!(f, "]")?; - Ok(()) - } -} - -struct FloatFormatter { - int_mode: bool, - sci_mode: bool, - precision: usize, -} - struct FmtSize { current_size: usize, } @@ -251,26 +127,161 @@ impl std::fmt::Write for FmtSize { } } -impl FloatFormatter { - fn new(t: &Tensor, po: &PrinterOptions) -> Self { +trait TensorFormatter { + type Elem: WithDType; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result; + + fn max_width(&self, to_display: &Tensor) -> usize { + let mut max_width = 1; + if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) { + for &v in vs.iter() { + let mut fmt_size = FmtSize::new(); + let _res = self.fmt(v, 1, &mut fmt_size); + max_width = usize::max(max_width, fmt_size.final_size()) + } + } + max_width + } + + fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f)?; + for _ in 0..i { + write!(f, " ")? + } + Ok(()) + } + + fn fmt_tensor( + &self, + t: &Tensor, + indent: usize, + max_w: usize, + summarize: bool, + po: &PrinterOptions, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + let dims = t.dims(); + let edge_items = po.edge_items; + write!(f, "[")?; + match dims { + [] => { + if let Ok(v) = t.to_scalar::() { + self.fmt(v, max_w, f)? + } + } + [v] if summarize && *v > 2 * edge_items => { + if let Ok(vs) = t + .narrow(0, 0, edge_items) + .and_then(|t| t.to_vec1::()) + { + for v in vs.into_iter() { + self.fmt(v, max_w, f)?; + write!(f, ", ")?; + } + } + write!(f, "...")?; + if let Ok(vs) = t + .narrow(0, v - edge_items, edge_items) + .and_then(|t| t.to_vec1::()) + { + for v in vs.into_iter() { + self.fmt(v, max_w, f)?; + write!(f, ", ")?; + } + } + } + [_] => { + let elements_per_line = usize::max(1, po.line_width / (max_w + 2)); + if let Ok(vs) = t.to_vec1::() { + for (i, v) in vs.into_iter().enumerate() { + if i > 0 { + if i % elements_per_line == 0 { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } else { + write!(f, ", ")?; + } + } + self.fmt(v, max_w, f)? + } + } + } + _ => { + if summarize && dims[0] > 2 * edge_items { + for i in 0..edge_items { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + write!(f, "...")?; + Self::write_newline_indent(indent, f)?; + for i in dims[0] - edge_items..dims[0] { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + if i + 1 != dims[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } else { + for i in 0..dims[0] { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + if i + 1 != dims[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } + } + } + write!(f, "]")?; + Ok(()) + } +} + +struct FloatFormatter { + int_mode: bool, + sci_mode: bool, + precision: usize, + _phantom: std::marker::PhantomData, +} + +impl FloatFormatter +where + S: WithDType + num_traits::Float, +{ + fn new(t: &Tensor, po: &PrinterOptions) -> Result { let mut int_mode = true; let mut sci_mode = false; - let _guard = crate::no_grad_guard(); - let t = t.to_device(crate::Device::Cpu); - // Rather than containing all values, this should only include // values that end up being displayed according to [threshold]. 
- let nonzero_finite_vals = { - let t = t.reshape([-1]); - t.masked_select(&t.isfinite().logical_and(&t.ne(0.))) - }; - - let values = Vec::::try_from(&nonzero_finite_vals).unwrap(); - if nonzero_finite_vals.numel() > 0 { - let nonzero_finite_abs = nonzero_finite_vals.abs(); - let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]); - let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]); + let values = t + .flatten_all()? + .to_vec1()? + .into_iter() + .filter(|v: &S| v.is_finite() && !v.is_zero()) + .collect::>(); + if !values.is_empty() { + let mut nonzero_finite_min = S::max_value(); + let mut nonzero_finite_max = S::min_value(); + for &v in values.iter() { + if v < nonzero_finite_min { + nonzero_finite_min = v + } + if v > nonzero_finite_max { + nonzero_finite_max = v + } + } for &value in values.iter() { if value.ceil() != value { @@ -279,25 +290,35 @@ impl FloatFormatter { } } - sci_mode = nonzero_finite_max / nonzero_finite_min > 1000. - || nonzero_finite_max > 1e8 - || nonzero_finite_min < 1e-4 + if let Some(v1) = S::from(1000.) { + if let Some(v2) = S::from(1e8) { + if let Some(v3) = S::from(1e-4) { + sci_mode = nonzero_finite_max / nonzero_finite_min > v1 + || nonzero_finite_max > v2 + || nonzero_finite_min < v3 + } + } + } } match po.sci_mode { None => {} Some(v) => sci_mode = v, } - Self { + Ok(Self { int_mode, sci_mode, precision: po.precision, - } + _phantom: std::marker::PhantomData, + }) } } -impl TensorFormatter for FloatFormatter { - type Elem = f64; +impl TensorFormatter for FloatFormatter +where + S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp, +{ + type Elem = S; fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { if self.sci_mode { @@ -324,125 +345,111 @@ impl TensorFormatter for FloatFormatter { ) } } +} - fn value(tensor: &Tensor) -> Self::Elem { - tensor.double_value(&[]) - } +struct IntFormatter { + _phantom: std::marker::PhantomData, +} - fn values(tensor: &Tensor) -> Vec { - Vec::::try_from(tensor.reshape(-1)).unwrap() +impl IntFormatter { + fn new() -> Self { + Self { + _phantom: std::marker::PhantomData, + } } } -struct IntFormatter; - -impl TensorFormatter for IntFormatter { - type Elem = i64; +impl TensorFormatter for IntFormatter +where + S: WithDType + std::fmt::Display, +{ + type Elem = S; fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { write!(f, "{v:max_w$}") } - - fn value(tensor: &Tensor) -> Self::Elem { - tensor.int64_value(&[]) - } - - fn values(tensor: &Tensor) -> Vec { - Vec::::try_from(tensor.reshape(-1)).unwrap() - } } -struct BoolFormatter; - -impl TensorFormatter for BoolFormatter { - type Elem = bool; - - fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { - let v = if v { "true" } else { "false" }; - write!(f, "{v:max_w$}") - } - - fn value(tensor: &Tensor) -> Self::Elem { - tensor.int64_value(&[]) != 0 - } - - fn values(tensor: &Tensor) -> Vec { - Vec::::try_from(tensor.reshape(-1)).unwrap() - } -} - -fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor { - let size = t.size(); - if size.is_empty() { - t.shallow_clone() - } else if size.len() == 1 { - if size[0] > 2 * edge_items { +fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result { + let dims = t.dims(); + if dims.is_empty() { + Ok(t.clone()) + } else if dims.len() == 1 { + if dims[0] > 2 * edge_items { Tensor::cat( &[ - t.slice(0, None, Some(edge_items), 1), - t.slice(0, Some(-edge_items), None, 1), + t.narrow(0, 0, edge_items)?, + 
t.narrow(0, dims[0] - edge_items, edge_items)?, ], 0, ) } else { - t.shallow_clone() + Ok(t.clone()) } - } else if size[0] > 2 * edge_items { + } else if dims[0] > 2 * edge_items { let mut vs: Vec<_> = (0..edge_items) - .map(|i| get_summarized_data(&t.get(i), edge_items)) - .collect(); - for i in (size[0] - edge_items)..size[0] { - vs.push(get_summarized_data(&t.get(i), edge_items)) + .map(|i| get_summarized_data(&t.get(i)?, edge_items)) + .collect::>>()?; + for i in (dims[0] - edge_items)..dims[0] { + vs.push(get_summarized_data(&t.get(i)?, edge_items)?) } - Tensor::stack(&vs, 0) + Tensor::cat(&vs, 0) } else { - let vs: Vec<_> = (0..size[0]) - .map(|i| get_summarized_data(&t.get(i), edge_items)) - .collect(); - Tensor::stack(&vs, 0) + let vs: Vec<_> = (0..dims[0]) + .map(|i| get_summarized_data(&t.get(i)?, edge_items)) + .collect::>>()?; + Tensor::cat(&vs, 0) } } impl std::fmt::Display for Tensor { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - if self.defined() { - let po = PRINT_OPTS.lock().unwrap(); - let summarize = self.numel() > po.threshold; - let basic_kind = BasicKind::for_tensor(self); - let to_display = if summarize { - get_summarized_data(self, po.edge_items as i64) - } else { - self.shallow_clone() - }; - match basic_kind { - BasicKind::Int => { - let tf = IntFormatter; - let max_w = tf.max_width(&to_display); - tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; - writeln!(f)?; - } - BasicKind::Float => { - let tf = FloatFormatter::new(&to_display, &po); - let max_w = tf.max_width(&to_display); - tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; - writeln!(f)?; - } - BasicKind::Bool => { - let tf = BoolFormatter; - let max_w = tf.max_width(&to_display); - tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; - writeln!(f)?; - } - BasicKind::Complex => {} - }; - let kind = match self.f_kind() { - Ok(kind) => format!("{kind:?}"), - Err(err) => format!("{err:?}"), - }; - write!(f, "Tensor[{:?}, {}]", self.size(), kind) + let po = PRINT_OPTS.lock().unwrap(); + let summarize = self.elem_count() > po.threshold; + let to_display = if summarize { + match get_summarized_data(self, po.edge_items) { + Ok(v) => v, + Err(err) => return write!(f, "{err:?}"), + } } else { - write!(f, "Tensor[Undefined]") - } + self.clone() + }; + match self.dtype() { + DType::U32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::BF16 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F16 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F64 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F32 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + }; + write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str()) } } -*/ diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 2084d7ca..5771517f 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -3,7 +3,7 @@ mod cpu_backend; #[cfg(feature = 
"cuda")] mod cuda_backend; mod device; -mod display; +pub mod display; mod dtype; mod dummy_cuda_backend; mod error; @@ -13,7 +13,7 @@ mod shape; mod storage; mod strided_index; mod tensor; -mod utils; +pub mod utils; pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs index 0be63c66..4b1e941b 100644 --- a/candle-core/src/utils.rs +++ b/candle-core/src/utils.rs @@ -1,6 +1,6 @@ use std::str::FromStr; -pub(crate) fn get_num_threads() -> usize { +pub fn get_num_threads() -> usize { // Respond to the same environment variable as rayon. match std::env::var("RAYON_NUM_THREADS") .ok() From b0f5f2d22d3eb199ce6ba03279401a4b0997087e Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 21:37:28 +0100 Subject: [PATCH 4/4] Add some display tests + bugfixes. --- candle-core/src/display.rs | 6 +-- candle-core/src/tensor.rs | 26 +++++---- candle-core/tests/display_tests.rs | 84 ++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 14 deletions(-) create mode 100644 candle-core/tests/display_tests.rs diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 8d1f16f6..81ca3c98 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -186,8 +186,8 @@ trait TensorFormatter { .and_then(|t| t.to_vec1::()) { for v in vs.into_iter() { - self.fmt(v, max_w, f)?; write!(f, ", ")?; + self.fmt(v, max_w, f)?; } } } @@ -257,7 +257,7 @@ struct FloatFormatter { impl FloatFormatter where - S: WithDType + num_traits::Float, + S: WithDType + num_traits::Float + std::fmt::Display, { fn new(t: &Tensor, po: &PrinterOptions) -> Result { let mut int_mode = true; @@ -275,6 +275,7 @@ where let mut nonzero_finite_min = S::max_value(); let mut nonzero_finite_max = S::min_value(); for &v in values.iter() { + let v = v.abs(); if v < nonzero_finite_min { nonzero_finite_min = v } @@ -289,7 +290,6 @@ where break; } } - if let Some(v1) = S::from(1000.) 
{ if let Some(v2) = S::from(1e8) { if let Some(v3) = S::from(1e-4) { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index b64f63e1..95254bab 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -652,18 +652,22 @@ impl Tensor { } pub fn flatten(&self, start_dim: Option, end_dim: Option) -> Result { - let start_dim = start_dim.unwrap_or(0); - let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1); - if start_dim < end_dim { - let dims = self.dims(); - let mut dst_dims = dims[..start_dim].to_vec(); - dst_dims.push(dims[start_dim..end_dim + 1].iter().product::()); - if end_dim + 1 < dims.len() { - dst_dims.extend(&dims[end_dim + 1..]); - } - self.reshape(dst_dims) + if self.rank() == 0 { + self.reshape(1) } else { - Ok(self.clone()) + let start_dim = start_dim.unwrap_or(0); + let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1); + if start_dim < end_dim { + let dims = self.dims(); + let mut dst_dims = dims[..start_dim].to_vec(); + dst_dims.push(dims[start_dim..end_dim + 1].iter().product::()); + if end_dim + 1 < dims.len() { + dst_dims.extend(&dims[end_dim + 1..]); + } + self.reshape(dst_dims) + } else { + Ok(self.clone()) + } } } diff --git a/candle-core/tests/display_tests.rs b/candle-core/tests/display_tests.rs new file mode 100644 index 00000000..eaa60180 --- /dev/null +++ b/candle-core/tests/display_tests.rs @@ -0,0 +1,84 @@ +use anyhow::Result; +use candle::{DType, Device::Cpu, Tensor}; + +#[test] +fn display_scalar() -> Result<()> { + let t = Tensor::new(1234u32, &Cpu)?; + let s = format!("{t}"); + assert_eq!(&s, "[1234]\nTensor[[], u32]"); + let t = t.to_dtype(DType::F32)?.neg()?; + let s = format!("{}", (&t / 10.0)?); + assert_eq!(&s, "[-123.4000]\nTensor[[], f32]"); + let s = format!("{}", (&t / 1e8)?); + assert_eq!(&s, "[-1.2340e-5]\nTensor[[], f32]"); + let s = format!("{}", (&t * 1e8)?); + assert_eq!(&s, "[-1.2340e11]\nTensor[[], f32]"); + let s = format!("{}", (&t * 0.)?); + assert_eq!(&s, "[0.]\nTensor[[], f32]"); + Ok(()) +} + +#[test] +fn display_vector() -> Result<()> { + let t = Tensor::new::<&[u32; 0]>(&[], &Cpu)?; + let s = format!("{t}"); + assert_eq!(&s, "[]\nTensor[[0], u32]"); + let t = Tensor::new(&[0.1234567, 1.0, -1.2, 4.1, f64::NAN], &Cpu)?; + let s = format!("{t}"); + assert_eq!( + &s, + "[ 0.1235, 1.0000, -1.2000, 4.1000, NaN]\nTensor[[5], f64]" + ); + let t = (Tensor::ones(50, DType::F32, &Cpu)? * 42.)?; + let s = format!("\n{t}"); + let expected = r#" +[42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., + 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., + 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., + 42., 42.] +Tensor[[50], f32]"#; + assert_eq!(&s, expected); + let t = (Tensor::ones(11000, DType::F32, &Cpu)? * 42.)?; + let s = format!("{t}"); + assert_eq!( + &s, + "[42., 42., 42., ..., 42., 42., 42.]\nTensor[[11000], f32]" + ); + Ok(()) +} + +#[test] +fn display_multi_dim() -> Result<()> { + let t = (Tensor::ones((200, 100), DType::F32, &Cpu)? * 42.)?; + let s = format!("\n{t}"); + let expected = r#" +[[42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + ... 
+ [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.]] +Tensor[[200, 100], f32]"#; + assert_eq!(&s, expected); + let t = t.reshape(&[2, 1, 1, 100, 100])?; + let t = format!("\n{t}"); + let expected = r#" +[[[[[42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + ... + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.]]]], + [[[[42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + ... + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.]]]]] +Tensor[[2, 1, 1, 100, 100], f32]"#; + assert_eq!(&t, expected); + Ok(()) +}
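
A few caller-side sketches of the new behavior follow; none of this code is part of the patches themselves. First, the reworked Debug (PATCH 1/4) keeps a compact one-liner while the new Display (PATCH 3/4) renders values PyTorch-style followed by a shape/dtype summary, with wrapping and summarization driven by the global print options. A minimal sketch, assuming the crate is importable as `candle` (as in the new display_tests.rs) and that `candle::display` is reachable through the `pub mod display` change in PATCH 3/4:

use anyhow::Result;
use candle::{DType, Device::Cpu, Tensor};

fn main() -> Result<()> {
    let t = (Tensor::ones(50, DType::F32, &Cpu)? * 42.)?;
    // Debug keeps the compact form from PATCH 1/4, e.g. "Tensor[dims 50; f32]".
    println!("{t:?}");
    // Display wraps values at PrinterOptions::line_width (80 by default) and
    // appends a "Tensor[[50], f32]" summary line, as the tests above expect.
    println!("{t}");

    // Past `threshold` elements (1000 by default) the output is summarized
    // with "...", keeping `edge_items` values on each side of every axis.
    let big = (Tensor::ones(11000, DType::F32, &Cpu)? * 42.)?;
    candle::display::set_print_options_short(); // edge_items: 2, precision: 2
    println!("{big}"); // [42., 42., ..., 42., 42.] (int_mode elides decimals)
    candle::display::set_print_options_default();
    Ok(())
}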
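
FloatFormatter::new picks between fixed and scientific notation by scanning the nonzero finite values; PATCH 4/4 fixes the scan to take absolute values before computing the min/max. The heuristic restated as a standalone function for illustration, where wants_sci_mode is a hypothetical name, not part of the patch:

// Mirrors FloatFormatter::new (after the PATCH 4/4 abs() fix): scientific
// notation kicks in when the dynamic range exceeds 1000x or the magnitudes
// leave the [1e-4, 1e8] window.
fn wants_sci_mode(values: &[f64]) -> bool {
    let abs: Vec<f64> = values
        .iter()
        .filter(|v| v.is_finite() && **v != 0.0)
        .map(|v| v.abs())
        .collect();
    if abs.is_empty() {
        return false;
    }
    let min = abs.iter().copied().fold(f64::INFINITY, f64::min);
    let max = abs.iter().copied().fold(0.0_f64, f64::max);
    max / min > 1000. || max > 1e8 || min < 1e-4
}

fn main() {
    // Matches the display_scalar/display_vector expectations above:
    // 0.1235..4.1 prints fixed, -1.2340e-5 switches to scientific.
    assert!(!wants_sci_mode(&[0.1234567, 1.0, -1.2, 4.1, f64::NAN]));
    assert!(wants_sci_mode(&[-1.234e-5]));
}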
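
The column widths that fmt_tensor aligns on come from max_width, which formats each element into FmtSize, a std::fmt::Write sink that only counts bytes instead of building a string. The same trick standalone, using only the standard library:

use std::fmt::Write;

// Identical to the FmtSize helper in display.rs: accumulate the byte length
// of everything written, never allocate.
struct FmtSize {
    current_size: usize,
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.current_size += s.len();
        Ok(())
    }
}

fn main() {
    // Measure how wide "-1.2340e-5" renders without allocating a String.
    let mut sink = FmtSize { current_size: 0 };
    write!(sink, "{:.4e}", -1.234e-5).unwrap();
    assert_eq!(sink.current_size, "-1.2340e-5".len());
}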
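
Finally, the squeeze/unsqueeze/stack trio from PATCH 2/4 composes out of reshape and cat. A sketch of the intended semantics; passing &Tensor where stack expects AsRef<Tensor> is an assumption here:

use anyhow::Result;
use candle::{Device::Cpu, Tensor};

fn main() -> Result<()> {
    // A (2, 3) tensor built from a flat vector.
    let t = Tensor::new(&[1u32, 2, 3, 4, 5, 6], &Cpu)?.reshape(&[2, 3])?;

    // unsqueeze inserts a size-1 dimension; squeeze removes it via reshape.
    let u = t.unsqueeze(0)?;
    assert_eq!(u.dims(), &[1, 2, 3]);
    assert_eq!(u.squeeze(0)?.dims(), &[2, 3]);

    // Per the PyTorch semantics quoted in the patch, squeezing a dimension
    // whose size is not 1 returns the tensor unchanged (a cheap clone).
    assert_eq!(t.squeeze(1)?.dims(), &[2, 3]);

    // stack unsqueezes every argument at `dim`, then concatenates: two
    // (2, 3) tensors stacked along dim 0 yield a (2, 2, 3) result.
    let s = Tensor::stack(&[&t, &t], 0)?;
    assert_eq!(s.dims(), &[2, 2, 3]);
    Ok(())
}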