Add support for i64 (#563)

* Add the i64 dtype.

* Adapt the cuda kernels.
This commit is contained in:
Laurent Mazare
2023-08-23 10:42:19 +01:00
committed by GitHub
parent 3743bed2d7
commit 9a5c7db91a
16 changed files with 313 additions and 36 deletions

View File

@ -49,6 +49,7 @@ impl std::fmt::Debug for Tensor {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
DType::F16 => self.fmt_dt::<f16>(f),
DType::F32 => self.fmt_dt::<f32>(f),
@ -431,6 +432,12 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I64 => {
let tf: IntFormatter<i64> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::BF16 => {
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);