PyTorch-like display implementation.

commit 8c81a70170
parent 934655a60d
Author: laurent
Date: 2023-06-27 21:16:35 +01:00

3 changed files with 265 additions and 258 deletions

View File

@@ -1,7 +1,7 @@
 /// Pretty printing of tensors
 /// This implementation should be in line with the PyTorch version.
 /// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
-use crate::{DType, Tensor, WithDType};
+use crate::{DType, Result, Tensor, WithDType};
 use half::{bf16, f16};
 
 impl Tensor {
@@ -52,26 +52,7 @@ impl std::fmt::Debug for Tensor {
     }
 }
 
-/*
 #[allow(dead_code)]
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-enum BasicKind {
-    Float,
-    Int,
-    Bool,
-    Complex,
-}
-
-impl BasicKind {
-    fn for_tensor(t: &Tensor) -> BasicKind {
-        match t.dtype() {
-            DType::U32 => BasicKind::Int,
-            DType::BF16 | DType::F16 | DType::F32 | DType::F64 => BasicKind::Float,
-        }
-    }
-}
-
 /// Options for Tensor pretty printing
 pub struct PrinterOptions {
     precision: usize,
@@ -81,9 +62,20 @@ pub struct PrinterOptions {
     sci_mode: Option<bool>,
 }
 
-lazy_static! {
-    static ref PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
-        std::sync::Mutex::new(Default::default());
+static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
+    std::sync::Mutex::new(PrinterOptions::const_default());
+
+impl PrinterOptions {
+    // We cannot use the default trait as it's not const.
+    const fn const_default() -> Self {
+        Self {
+            precision: 4,
+            threshold: 1000,
+            edge_items: 3,
+            line_width: 80,
+            sci_mode: None,
+        }
+    }
 }
 
 pub fn set_print_options(options: PrinterOptions) {
@@ -91,7 +83,7 @@ pub fn set_print_options(options: PrinterOptions) {
 }
 
 pub fn set_print_options_default() {
-    *PRINT_OPTS.lock().unwrap() = Default::default()
+    *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
 }
 
 pub fn set_print_options_short() {
@@ -114,122 +106,6 @@ pub fn set_print_options_full() {
     }
 }
 
-impl Default for PrinterOptions {
-    fn default() -> Self {
-        Self {
-            precision: 4,
-            threshold: 1000,
-            edge_items: 3,
-            line_width: 80,
-            sci_mode: None,
-        }
-    }
-}
-
-trait TensorFormatter {
-    type Elem;
-
-    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
-
-    fn value(tensor: &Tensor) -> Self::Elem;
-
-    fn values(tensor: &Tensor) -> Vec<Self::Elem>;
-
-    fn max_width(&self, to_display: &Tensor) -> usize {
-        let mut max_width = 1;
-        for v in Self::values(to_display) {
-            let mut fmt_size = FmtSize::new();
-            let _res = self.fmt(v, 1, &mut fmt_size);
-            max_width = usize::max(max_width, fmt_size.final_size())
-        }
-        max_width
-    }
-
-    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        writeln!(f)?;
-        for _ in 0..i {
-            write!(f, " ")?
-        }
-        Ok(())
-    }
-
-    fn fmt_tensor(
-        &self,
-        t: &Tensor,
-        indent: usize,
-        max_w: usize,
-        summarize: bool,
-        po: &PrinterOptions,
-        f: &mut std::fmt::Formatter,
-    ) -> std::fmt::Result {
-        let size = t.size();
-        let edge_items = po.edge_items as i64;
-        write!(f, "[")?;
-        match size.as_slice() {
-            [] => self.fmt(Self::value(t), max_w, f)?,
-            [v] if summarize && *v > 2 * edge_items => {
-                for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() {
-                    self.fmt(v, max_w, f)?;
-                    write!(f, ", ")?;
-                }
-                write!(f, "...")?;
-                for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() {
-                    write!(f, ", ")?;
-                    self.fmt(v, max_w, f)?
-                }
-            }
-            [_] => {
-                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
-                for (i, v) in Self::values(t).into_iter().enumerate() {
-                    if i > 0 {
-                        if i % elements_per_line == 0 {
-                            write!(f, ",")?;
-                            Self::write_newline_indent(indent, f)?
-                        } else {
-                            write!(f, ", ")?;
-                        }
-                    }
-                    self.fmt(v, max_w, f)?
-                }
-            }
-            _ => {
-                if summarize && size[0] > 2 * edge_items {
-                    for i in 0..edge_items {
-                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
-                        write!(f, ",")?;
-                        Self::write_newline_indent(indent, f)?
-                    }
-                    write!(f, "...")?;
-                    Self::write_newline_indent(indent, f)?;
-                    for i in size[0] - edge_items..size[0] {
-                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
-                        if i + 1 != size[0] {
-                            write!(f, ",")?;
-                            Self::write_newline_indent(indent, f)?
-                        }
-                    }
-                } else {
-                    for i in 0..size[0] {
-                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
-                        if i + 1 != size[0] {
-                            write!(f, ",")?;
-                            Self::write_newline_indent(indent, f)?
-                        }
-                    }
-                }
-            }
-        }
-        write!(f, "]")?;
-        Ok(())
-    }
-}
-
-struct FloatFormatter {
-    int_mode: bool,
-    sci_mode: bool,
-    precision: usize,
-}
-
 struct FmtSize {
     current_size: usize,
 }
@@ -251,26 +127,161 @@ impl std::fmt::Write for FmtSize {
     }
 }
 
-impl FloatFormatter {
-    fn new(t: &Tensor, po: &PrinterOptions) -> Self {
+trait TensorFormatter {
+    type Elem: WithDType;
+
+    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
+
+    fn max_width(&self, to_display: &Tensor) -> usize {
+        let mut max_width = 1;
+        if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
+            for &v in vs.iter() {
+                let mut fmt_size = FmtSize::new();
+                let _res = self.fmt(v, 1, &mut fmt_size);
+                max_width = usize::max(max_width, fmt_size.final_size())
+            }
+        }
+        max_width
+    }
+
+    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        writeln!(f)?;
+        for _ in 0..i {
+            write!(f, " ")?
+        }
+        Ok(())
+    }
+
+    fn fmt_tensor(
+        &self,
+        t: &Tensor,
+        indent: usize,
+        max_w: usize,
+        summarize: bool,
+        po: &PrinterOptions,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        let dims = t.dims();
+        let edge_items = po.edge_items;
+        write!(f, "[")?;
+        match dims {
+            [] => {
+                if let Ok(v) = t.to_scalar::<Self::Elem>() {
+                    self.fmt(v, max_w, f)?
+                }
+            }
+            [v] if summarize && *v > 2 * edge_items => {
+                if let Ok(vs) = t
+                    .narrow(0, 0, edge_items)
+                    .and_then(|t| t.to_vec1::<Self::Elem>())
+                {
+                    for v in vs.into_iter() {
+                        self.fmt(v, max_w, f)?;
+                        write!(f, ", ")?;
+                    }
+                }
+                write!(f, "...")?;
+                if let Ok(vs) = t
+                    .narrow(0, v - edge_items, edge_items)
+                    .and_then(|t| t.to_vec1::<Self::Elem>())
+                {
+                    for v in vs.into_iter() {
+                        self.fmt(v, max_w, f)?;
+                        write!(f, ", ")?;
+                    }
+                }
+            }
+            [_] => {
+                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
+                if let Ok(vs) = t.to_vec1::<Self::Elem>() {
+                    for (i, v) in vs.into_iter().enumerate() {
+                        if i > 0 {
+                            if i % elements_per_line == 0 {
+                                write!(f, ",")?;
+                                Self::write_newline_indent(indent, f)?
+                            } else {
+                                write!(f, ", ")?;
+                            }
+                        }
+                        self.fmt(v, max_w, f)?
+                    }
+                }
+            }
+            _ => {
+                if summarize && dims[0] > 2 * edge_items {
+                    for i in 0..edge_items {
+                        match t.get(i) {
+                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
+                            Err(e) => write!(f, "{e:?}")?,
+                        }
+                        write!(f, ",")?;
+                        Self::write_newline_indent(indent, f)?
+                    }
+                    write!(f, "...")?;
+                    Self::write_newline_indent(indent, f)?;
+                    for i in dims[0] - edge_items..dims[0] {
+                        match t.get(i) {
+                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
+                            Err(e) => write!(f, "{e:?}")?,
+                        }
+                        if i + 1 != dims[0] {
+                            write!(f, ",")?;
+                            Self::write_newline_indent(indent, f)?
+                        }
+                    }
+                } else {
+                    for i in 0..dims[0] {
+                        match t.get(i) {
+                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
+                            Err(e) => write!(f, "{e:?}")?,
+                        }
+                        if i + 1 != dims[0] {
+                            write!(f, ",")?;
+                            Self::write_newline_indent(indent, f)?
+                        }
+                    }
+                }
+            }
+        }
+        write!(f, "]")?;
+        Ok(())
+    }
+}
+
+struct FloatFormatter<S: WithDType> {
+    int_mode: bool,
+    sci_mode: bool,
+    precision: usize,
+    _phantom: std::marker::PhantomData<S>,
+}
+
+impl<S> FloatFormatter<S>
+where
+    S: WithDType + num_traits::Float,
+{
+    fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
         let mut int_mode = true;
         let mut sci_mode = false;
 
-        let _guard = crate::no_grad_guard();
-        let t = t.to_device(crate::Device::Cpu);
-
         // Rather than containing all values, this should only include
        // values that end up being displayed according to [threshold].
-        let nonzero_finite_vals = {
-            let t = t.reshape([-1]);
-            t.masked_select(&t.isfinite().logical_and(&t.ne(0.)))
-        };
-
-        let values = Vec::<f64>::try_from(&nonzero_finite_vals).unwrap();
-        if nonzero_finite_vals.numel() > 0 {
-            let nonzero_finite_abs = nonzero_finite_vals.abs();
-            let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]);
-            let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]);
+        let values = t
+            .flatten_all()?
+            .to_vec1()?
+            .into_iter()
+            .filter(|v: &S| v.is_finite() && !v.is_zero())
+            .collect::<Vec<_>>();
+        if !values.is_empty() {
+            let mut nonzero_finite_min = S::max_value();
+            let mut nonzero_finite_max = S::min_value();
+            for &v in values.iter() {
+                if v < nonzero_finite_min {
+                    nonzero_finite_min = v
+                }
+                if v > nonzero_finite_max {
+                    nonzero_finite_max = v
+                }
+            }
 
             for &value in values.iter() {
                 if value.ceil() != value {
@@ -279,25 +290,35 @@ impl FloatFormatter {
                 }
             }
 
-            sci_mode = nonzero_finite_max / nonzero_finite_min > 1000.
-                || nonzero_finite_max > 1e8
-                || nonzero_finite_min < 1e-4
+            if let Some(v1) = S::from(1000.) {
+                if let Some(v2) = S::from(1e8) {
+                    if let Some(v3) = S::from(1e-4) {
+                        sci_mode = nonzero_finite_max / nonzero_finite_min > v1
+                            || nonzero_finite_max > v2
+                            || nonzero_finite_min < v3
+                    }
+                }
+            }
         }
 
         match po.sci_mode {
             None => {}
             Some(v) => sci_mode = v,
         }
-        Self {
-            int_mode,
-            sci_mode,
-            precision: po.precision,
-        }
+        Ok(Self {
+            int_mode,
+            sci_mode,
+            precision: po.precision,
+            _phantom: std::marker::PhantomData,
+        })
     }
 }
 
-impl TensorFormatter for FloatFormatter {
-    type Elem = f64;
+impl<S> TensorFormatter for FloatFormatter<S>
+where
+    S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
+{
+    type Elem = S;
 
     fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
         if self.sci_mode {
@@ -324,125 +345,111 @@ impl TensorFormatter for FloatFormatter {
             )
         }
     }
+}
 
-    fn value(tensor: &Tensor) -> Self::Elem {
-        tensor.double_value(&[])
-    }
+struct IntFormatter<S: WithDType> {
+    _phantom: std::marker::PhantomData<S>,
+}
 
-    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
-        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
+impl<S: WithDType> IntFormatter<S> {
+    fn new() -> Self {
+        Self {
+            _phantom: std::marker::PhantomData,
+        }
     }
 }
 
-struct IntFormatter;
-
-impl TensorFormatter for IntFormatter {
-    type Elem = i64;
+impl<S> TensorFormatter for IntFormatter<S>
+where
+    S: WithDType + std::fmt::Display,
+{
+    type Elem = S;
 
     fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
         write!(f, "{v:max_w$}")
     }
-
-    fn value(tensor: &Tensor) -> Self::Elem {
-        tensor.int64_value(&[])
-    }
-
-    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
-        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
-    }
 }
 
-struct BoolFormatter;
-
-impl TensorFormatter for BoolFormatter {
-    type Elem = bool;
-
-    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
-        let v = if v { "true" } else { "false" };
-        write!(f, "{v:max_w$}")
-    }
-
-    fn value(tensor: &Tensor) -> Self::Elem {
-        tensor.int64_value(&[]) != 0
-    }
-
-    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
-        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
-    }
-}
-
-fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor {
-    let size = t.size();
-    if size.is_empty() {
-        t.shallow_clone()
-    } else if size.len() == 1 {
-        if size[0] > 2 * edge_items {
+fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
+    let dims = t.dims();
+    if dims.is_empty() {
+        Ok(t.clone())
+    } else if dims.len() == 1 {
+        if dims[0] > 2 * edge_items {
             Tensor::cat(
                 &[
-                    t.slice(0, None, Some(edge_items), 1),
-                    t.slice(0, Some(-edge_items), None, 1),
+                    t.narrow(0, 0, edge_items)?,
+                    t.narrow(0, dims[0] - edge_items, edge_items)?,
                 ],
                 0,
             )
         } else {
-            t.shallow_clone()
+            Ok(t.clone())
         }
-    } else if size[0] > 2 * edge_items {
+    } else if dims[0] > 2 * edge_items {
         let mut vs: Vec<_> = (0..edge_items)
-            .map(|i| get_summarized_data(&t.get(i), edge_items))
-            .collect();
-        for i in (size[0] - edge_items)..size[0] {
-            vs.push(get_summarized_data(&t.get(i), edge_items))
+            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
+            .collect::<Result<Vec<_>>>()?;
+        for i in (dims[0] - edge_items)..dims[0] {
+            vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
         }
-        Tensor::stack(&vs, 0)
+        Tensor::cat(&vs, 0)
     } else {
-        let vs: Vec<_> = (0..size[0])
-            .map(|i| get_summarized_data(&t.get(i), edge_items))
-            .collect();
-        Tensor::stack(&vs, 0)
+        let vs: Vec<_> = (0..dims[0])
+            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
+            .collect::<Result<Vec<_>>>()?;
+        Tensor::cat(&vs, 0)
     }
 }
 
 impl std::fmt::Display for Tensor {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        if self.defined() {
-            let po = PRINT_OPTS.lock().unwrap();
-            let summarize = self.numel() > po.threshold;
-            let basic_kind = BasicKind::for_tensor(self);
-            let to_display = if summarize {
-                get_summarized_data(self, po.edge_items as i64)
-            } else {
-                self.shallow_clone()
-            };
-            match basic_kind {
-                BasicKind::Int => {
-                    let tf = IntFormatter;
-                    let max_w = tf.max_width(&to_display);
-                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
-                    writeln!(f)?;
-                }
-                BasicKind::Float => {
-                    let tf = FloatFormatter::new(&to_display, &po);
-                    let max_w = tf.max_width(&to_display);
-                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
-                    writeln!(f)?;
-                }
-                BasicKind::Bool => {
-                    let tf = BoolFormatter;
-                    let max_w = tf.max_width(&to_display);
-                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
-                    writeln!(f)?;
-                }
-                BasicKind::Complex => {}
-            };
-            let kind = match self.f_kind() {
-                Ok(kind) => format!("{kind:?}"),
-                Err(err) => format!("{err:?}"),
-            };
-            write!(f, "Tensor[{:?}, {}]", self.size(), kind)
-        } else {
-            write!(f, "Tensor[Undefined]")
-        }
+        let po = PRINT_OPTS.lock().unwrap();
+        let summarize = self.elem_count() > po.threshold;
+        let to_display = if summarize {
+            match get_summarized_data(self, po.edge_items) {
+                Ok(v) => v,
+                Err(err) => return write!(f, "{err:?}"),
+            }
+        } else {
+            self.clone()
+        };
+        match self.dtype() {
+            DType::U32 => {
+                let tf: IntFormatter<u32> = IntFormatter::new();
+                let max_w = tf.max_width(&to_display);
+                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                writeln!(f)?;
+            }
+            DType::BF16 => {
+                if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
+                    let max_w = tf.max_width(&to_display);
+                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                    writeln!(f)?;
+                }
+            }
+            DType::F16 => {
+                if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
+                    let max_w = tf.max_width(&to_display);
+                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                    writeln!(f)?;
+                }
+            }
+            DType::F64 => {
+                if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
+                    let max_w = tf.max_width(&to_display);
+                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                    writeln!(f)?;
+                }
+            }
+            DType::F32 => {
+                if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
+                    let max_w = tf.max_width(&to_display);
+                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                    writeln!(f)?;
+                }
+            }
+        };
+        write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
     }
 }
-*/
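A note on the `PRINT_OPTS` change above: `std::sync::Mutex::new` is a `const fn` since Rust 1.63, which makes the `lazy_static!` dependency unnecessary, but `Default::default` is a trait method and cannot be called in a const context, hence the hand-written `const_default`. A minimal self-contained sketch of the pattern, with illustrative names:

use std::sync::Mutex;

struct Opts {
    precision: usize,
}

impl Opts {
    // Trait methods cannot be `const`, so the default value is provided
    // through an inherent `const fn` instead of `impl Default`.
    const fn const_default() -> Self {
        Self { precision: 4 }
    }
}

// `Mutex::new` is const since Rust 1.63: no lazy initialization needed.
static OPTS: Mutex<Opts> = Mutex::new(Opts::const_default());

fn main() {
    OPTS.lock().unwrap().precision = 8;
    assert_eq!(OPTS.lock().unwrap().precision, 8);
}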

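End to end, the new `Display` impl dispatches on `DType`, summarizes with `...` once `elem_count()` exceeds `threshold` (1000 by default, keeping `edge_items = 3` values per edge), and switches floats to scientific notation when max/min > 1000, max > 1e8, or min < 1e-4 over the nonzero finite values. A hypothetical usage sketch; the crate name and the `Tensor::new`/`Tensor::arange` constructors are assumptions, only the printing behavior comes from this diff:

use candle::{Device, Tensor}; // crate name assumed

fn main() -> candle::Result<()> {
    // Small tensor: printed in full, followed by a "Tensor[[2, 3], f32]" line.
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    println!("{t}");

    // 2000 elements exceed the default threshold of 1000, so only
    // edge_items = 3 values per edge are shown: [0, 1, 2, ..., 1997, 1998, 1999]
    let big = Tensor::arange(0u32, 2000u32, &Device::Cpu)?;
    println!("{big}");

    // Mixing 1e-6 with 1.0 makes the nonzero finite min < 1e-4 (and the
    // max/min ratio > 1000), which triggers sci_mode.
    let spread = Tensor::new(&[1e-6f32, 1.0], &Device::Cpu)?;
    println!("{spread}");
    Ok(())
}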
View File

@@ -3,7 +3,7 @@ mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;
 mod device;
-mod display;
+pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
 mod error;
@@ -13,7 +13,7 @@ mod shape;
 mod storage;
 mod strided_index;
 mod tensor;
-mod utils;
+pub mod utils;
 
 pub use cpu_backend::CpuStorage;
 pub use device::{Device, DeviceLocation};

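Making `display` and `utils` public lets downstream code reach the printing knobs and the thread-count helper directly. A hypothetical caller (crate name assumed, module paths as exported above):

fn main() {
    // Tune global tensor printing, then check how many threads the
    // CPU backend will use.
    candle::display::set_print_options_short();
    println!("threads: {}", candle::utils::get_num_threads());
}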
View File

@@ -1,6 +1,6 @@
 use std::str::FromStr;
 
-pub(crate) fn get_num_threads() -> usize {
+pub fn get_num_threads() -> usize {
     // Respond to the same environment variable as rayon.
     match std::env::var("RAYON_NUM_THREADS")
         .ok()
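`get_num_threads` goes from `pub(crate)` to `pub` and keeps honoring rayon's `RAYON_NUM_THREADS` variable. The hunk is cut off after the env lookup, so the fallback below is an assumption; a self-contained sketch of the pattern rather than the commit's exact code:

use std::str::FromStr;

pub fn get_num_threads() -> usize {
    // Respond to the same environment variable as rayon.
    match std::env::var("RAYON_NUM_THREADS")
        .ok()
        .and_then(|s| usize::from_str(&s).ok())
    {
        // A positive value wins; anything else falls back to the number of
        // available CPUs (fallback assumed, the diff is truncated here).
        Some(x) if x > 0 => x,
        _ => std::thread::available_parallelism().map_or(1, |n| n.get()),
    }
}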