diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index da2adf37..8390a4a0 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -9,11 +9,14 @@ impl Tensor { &self, f: &mut std::fmt::Formatter, ) -> std::fmt::Result { - let prefix = match self.device() { - crate::Device::Cpu => "Cpu", - crate::Device::Cuda(_) => "Cuda", + let device_str = match self.device().location() { + crate::DeviceLocation::Cpu => "".to_owned(), + crate::DeviceLocation::Cuda { gpu_id } => { + format!(", cuda:{}", gpu_id) + } }; - write!(f, "{prefix}Tensor[")?; + + write!(f, "Tensor[")?; match self.dims() { [] => { if let Ok(v) = self.to_scalar::() { @@ -40,7 +43,7 @@ impl Tensor { } } } - write!(f, "; {}]", self.dtype().as_str()) + write!(f, "; {} ,{}]", self.dtype().as_str(), device_str) } } @@ -467,6 +470,20 @@ impl std::fmt::Display for Tensor { } } }; - write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str()) + + let device_str = match self.device().location() { + crate::DeviceLocation::Cpu => "".to_owned(), + crate::DeviceLocation::Cuda { gpu_id } => { + format!(", cuda:{}", gpu_id) + } + }; + + write!( + f, + "Tensor[{:?}, {}{}]", + self.dims(), + self.dtype().as_str(), + device_str + ) } }