mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add support for i64 (#563)
* Add the i64 dtype. * Adapt the cuda kernels.
This commit is contained in:
@ -49,6 +49,7 @@ impl std::fmt::Debug for Tensor {
|
||||
match self.dtype() {
|
||||
DType::U8 => self.fmt_dt::<u8>(f),
|
||||
DType::U32 => self.fmt_dt::<u32>(f),
|
||||
DType::I64 => self.fmt_dt::<i64>(f),
|
||||
DType::BF16 => self.fmt_dt::<bf16>(f),
|
||||
DType::F16 => self.fmt_dt::<f16>(f),
|
||||
DType::F32 => self.fmt_dt::<f32>(f),
|
||||
@ -431,6 +432,12 @@ impl std::fmt::Display for Tensor {
|
||||
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
DType::I64 => {
|
||||
let tf: IntFormatter<i64> = IntFormatter::new();
|
||||
let max_w = tf.max_width(&to_display);
|
||||
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
DType::BF16 => {
|
||||
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
|
||||
let max_w = tf.max_width(&to_display);
|
||||
|
Reference in New Issue
Block a user