mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add support for i64 (#563)
* Add the i64 dtype.
* Adapt the CUDA kernels.
This commit is contained in:
@ -9,6 +9,7 @@ use half::{bf16, f16};
|
||||
pub enum CpuStorage {
|
||||
U8(Vec<u8>),
|
||||
U32(Vec<u32>),
|
||||
I64(Vec<i64>),
|
||||
BF16(Vec<bf16>),
|
||||
F16(Vec<f16>),
|
||||
F32(Vec<f32>),
|
||||
@ -25,6 +26,7 @@ pub trait Map1 {
|
||||
match vs {
|
||||
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
|
||||
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
|
||||
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
|
||||
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
|
||||
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
|
||||
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
|
||||
@ -45,6 +47,7 @@ pub trait Map1Any {
|
||||
match vs {
|
||||
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
|
||||
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
|
||||
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
|
||||
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
|
||||
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
|
||||
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
|
||||
@ -68,6 +71,7 @@ pub trait Map2 {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
||||
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
|
||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
||||
@ -96,6 +100,7 @@ pub trait Map2U8 {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
@ -1527,6 +1532,7 @@ impl BackendStorage for CpuStorage {
|
||||
match self {
|
||||
Self::U8(_) => DType::U8,
|
||||
Self::U32(_) => DType::U32,
|
||||
Self::I64(_) => DType::I64,
|
||||
Self::BF16(_) => DType::BF16,
|
||||
Self::F16(_) => DType::F16,
|
||||
Self::F32(_) => DType::F32,
|
||||
@ -1545,6 +1551,10 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::I64(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::BF16(data))
|
||||
@ -1569,6 +1579,10 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::I64(storage), DType::F16) => {
|
||||
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F16) => {
|
||||
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
|
||||
Ok(Self::F16(data))
|
||||
@ -1593,6 +1607,10 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::I64(storage), DType::F32) => {
|
||||
let data = unary_map(storage, layout, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F32) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f32());
|
||||
Ok(Self::F32(data))
|
||||
@ -1629,18 +1647,26 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v as u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
(Self::U8(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U8) => {
|
||||
let data = unary_map(storage, layout, |v| v as u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
(Self::I64(storage), DType::U8) => {
|
||||
let data = unary_map(storage, layout, |v| v as u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
(Self::U8(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::I64(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::U32) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
|
||||
Ok(Self::U32(data))
|
||||
@ -1657,6 +1683,34 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U8(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::U32(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::I64(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::F16(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::F32(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::F64(storage), DType::I64) => {
|
||||
let data = unary_map(storage, layout, |v| v as i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::U8(storage), DType::F64) => {
|
||||
let data = unary_map(storage, layout, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
@ -1665,6 +1719,10 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::I64(storage), DType::F64) => {
|
||||
let data = unary_map(storage, layout, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F64) => {
|
||||
let data = unary_map(storage, layout, |v| v.to_f64());
|
||||
Ok(Self::F64(data))
|
||||
@ -1791,6 +1849,7 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
|
||||
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
|
||||
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1840,6 +1899,10 @@ impl BackendStorage for CpuStorage {
|
||||
let data = unary_map(storage, layout, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
Self::I64(storage) => {
|
||||
let data = unary_map(storage, layout, B::i64);
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1890,6 +1953,14 @@ impl BackendStorage for CpuStorage {
|
||||
};
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::I64(lhs), Self::I64(rhs)) => {
|
||||
let data = if B::I64_VEC {
|
||||
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
|
||||
} else {
|
||||
binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
|
||||
};
|
||||
Ok(Self::I64(data))
|
||||
}
|
||||
(Self::U8(lhs), Self::U8(rhs)) => {
|
||||
let data = if B::U8_VEC {
|
||||
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
|
||||
@ -1914,6 +1985,7 @@ impl BackendStorage for CpuStorage {
|
||||
match (self, dst) {
|
||||
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
@ -1942,6 +2014,7 @@ impl BackendStorage for CpuStorage {
|
||||
match self {
|
||||
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
|
||||
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
|
||||
Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
|
||||
}
|
||||
}
|
||||
@ -1970,6 +2043,7 @@ impl BackendStorage for CpuStorage {
|
||||
match ids {
|
||||
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
|
||||
}
|
||||
}
|
||||
@ -1978,6 +2052,7 @@ impl BackendStorage for CpuStorage {
|
||||
match ids {
|
||||
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
|
||||
}
|
||||
}
|
||||
@ -1994,6 +2069,7 @@ impl BackendStorage for CpuStorage {
|
||||
match ids {
|
||||
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
|
||||
}
|
||||
}
|
||||
@ -2022,6 +2098,13 @@ impl BackendStorage for CpuStorage {
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
Self::I64(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
|
||||
}
|
||||
}
|
||||
@ -2074,7 +2157,9 @@ impl BackendDevice for CpuDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let mut rng = rand::thread_rng();
|
||||
match dtype {
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
|
||||
DType::U8 | DType::U32 | DType::I64 => {
|
||||
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
|
||||
}
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let uniform =
|
||||
@ -2118,7 +2203,9 @@ impl BackendDevice for CpuDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let mut rng = rand::thread_rng();
|
||||
match dtype {
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
|
||||
DType::U8 | DType::U32 | DType::I64 => {
|
||||
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
|
||||
}
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
|
||||
@ -2162,6 +2249,7 @@ impl BackendDevice for CpuDevice {
|
||||
let storage = match dtype {
|
||||
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
|
||||
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
|
||||
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
|
||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
|
||||
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
|
||||
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
|
||||
@ -2175,6 +2263,7 @@ impl BackendDevice for CpuDevice {
|
||||
let storage = match dtype {
|
||||
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
|
||||
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
|
||||
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
|
||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
|
||||
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
|
||||
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
|
||||
|
Reference in New Issue
Block a user