mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Align tensor device print more with PyTorch (#590)
* Improve tensor print * Use CudaDevice only if enabled with cuda feature * run rust fmt * up * improve * rustfmt
This commit is contained in:

committed by
GitHub

parent
6559eae72c
commit
71518caeee
@ -9,11 +9,14 @@ impl Tensor {
|
||||
&self,
|
||||
f: &mut std::fmt::Formatter,
|
||||
) -> std::fmt::Result {
|
||||
let prefix = match self.device() {
|
||||
crate::Device::Cpu => "Cpu",
|
||||
crate::Device::Cuda(_) => "Cuda",
|
||||
let device_str = match self.device().location() {
|
||||
crate::DeviceLocation::Cpu => "".to_owned(),
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
write!(f, "{prefix}Tensor[")?;
|
||||
|
||||
write!(f, "Tensor[")?;
|
||||
match self.dims() {
|
||||
[] => {
|
||||
if let Ok(v) = self.to_scalar::<T>() {
|
||||
@ -40,7 +43,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
}
|
||||
write!(f, "; {}]", self.dtype().as_str())
|
||||
write!(f, "; {} ,{}]", self.dtype().as_str(), device_str)
|
||||
}
|
||||
}
|
||||
|
||||
@ -467,6 +470,20 @@ impl std::fmt::Display for Tensor {
|
||||
}
|
||||
}
|
||||
};
|
||||
write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
|
||||
|
||||
let device_str = match self.device().location() {
|
||||
crate::DeviceLocation::Cpu => "".to_owned(),
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
|
||||
write!(
|
||||
f,
|
||||
"Tensor[{:?}, {}{}]",
|
||||
self.dims(),
|
||||
self.dtype().as_str(),
|
||||
device_str
|
||||
)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user