Add support for i64 (#563)

* Add the i64 dtype.

* Adapt the cuda kernels.
This commit is contained in:
Laurent Mazare
2023-08-23 10:42:19 +01:00
committed by GitHub
parent 3743bed2d7
commit 9a5c7db91a
16 changed files with 313 additions and 36 deletions

View File

@ -9,6 +9,7 @@ use half::{bf16, f16};
pub enum CpuStorage {
U8(Vec<u8>),
U32(Vec<u32>),
I64(Vec<i64>),
BF16(Vec<bf16>),
F16(Vec<f16>),
F32(Vec<f32>),
@ -25,6 +26,7 @@ pub trait Map1 {
match vs {
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
@ -45,6 +47,7 @@ pub trait Map1Any {
match vs {
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
@ -68,6 +71,7 @@ pub trait Map2 {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
@ -96,6 +100,7 @@ pub trait Map2U8 {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
@ -1527,6 +1532,7 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(_) => DType::U8,
Self::U32(_) => DType::U32,
Self::I64(_) => DType::I64,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
@ -1545,6 +1551,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::I64(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::BF16(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::BF16(data))
@ -1569,6 +1579,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::I64(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::BF16(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
@ -1593,6 +1607,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::I64(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::BF16(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v.to_f32());
Ok(Self::F32(data))
@ -1629,18 +1647,26 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::U8(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U32(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::I64(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::U8(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U32(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U32(data))
}
(Self::I64(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::BF16(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
Ok(Self::U32(data))
@ -1657,6 +1683,34 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U8(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::U32(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::I64(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::I64(data))
}
(Self::BF16(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
Ok(Self::I64(data))
}
(Self::F16(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
Ok(Self::I64(data))
}
(Self::F32(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::F64(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::U8(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
@ -1665,6 +1719,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::I64(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::BF16(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v.to_f64());
Ok(Self::F64(data))
@ -1791,6 +1849,7 @@ impl BackendStorage for CpuStorage {
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
@ -1840,6 +1899,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, B::u32);
Ok(Self::U32(data))
}
Self::I64(storage) => {
let data = unary_map(storage, layout, B::i64);
Ok(Self::I64(data))
}
}
}
@ -1890,6 +1953,14 @@ impl BackendStorage for CpuStorage {
};
Ok(Self::U32(data))
}
(Self::I64(lhs), Self::I64(rhs)) => {
let data = if B::I64_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
};
Ok(Self::I64(data))
}
(Self::U8(lhs), Self::U8(rhs)) => {
let data = if B::U8_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
@ -1914,6 +1985,7 @@ impl BackendStorage for CpuStorage {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@ -1942,6 +2014,7 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
}
}
@ -1970,6 +2043,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
}
}
@ -1978,6 +2052,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
}
}
@ -1994,6 +2069,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
}
}
@ -2022,6 +2098,13 @@ impl BackendStorage for CpuStorage {
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
}
}
@ -2074,7 +2157,9 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
DType::U8 | DType::U32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
}
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
@ -2118,7 +2203,9 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::U8 | DType::U32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
}
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
@ -2162,6 +2249,7 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
@ -2175,6 +2263,7 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),