Align tensor device print more with PyTorch (#590)

* Improve tensor print

* Use CudaDevice only if enabled with cuda feature

* run rust fmt

* up

* improve

* rustfmt
This commit is contained in:
Patrick von Platen
2023-08-26 12:20:22 +02:00
committed by GitHub
parent 6559eae72c
commit 71518caeee

View File

@ -9,11 +9,14 @@ impl Tensor {
&self,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
let prefix = match self.device() {
crate::Device::Cpu => "Cpu",
crate::Device::Cuda(_) => "Cuda",
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
};
write!(f, "{prefix}Tensor[")?;
write!(f, "Tensor[")?;
match self.dims() {
[] => {
if let Ok(v) = self.to_scalar::<T>() {
@ -40,7 +43,7 @@ impl Tensor {
}
}
}
write!(f, "; {}]", self.dtype().as_str())
write!(f, "; {} ,{}]", self.dtype().as_str(), device_str)
}
}
@ -467,6 +470,20 @@ impl std::fmt::Display for Tensor {
}
}
};
write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
};
write!(
f,
"Tensor[{:?}, {}{}]",
self.dims(),
self.dtype().as_str(),
device_str
)
}
}