Add support for i64 (#563)

* Add the i64 dtype.

* Adapt the cuda kernels.
This commit is contained in:
Laurent Mazare
2023-08-23 10:42:19 +01:00
committed by GitHub
parent 3743bed2d7
commit 9a5c7db91a
16 changed files with 313 additions and 36 deletions

View File

@ -10,6 +10,7 @@ impl From<DType> for st::Dtype {
match value {
DType::U8 => st::Dtype::U8,
DType::U32 => st::Dtype::U32,
DType::I64 => st::Dtype::I64,
DType::BF16 => st::Dtype::BF16,
DType::F16 => st::Dtype::F16,
DType::F32 => st::Dtype::F32,
@ -24,6 +25,7 @@ impl TryFrom<st::Dtype> for DType {
match value {
st::Dtype::U8 => Ok(DType::U8),
st::Dtype::U32 => Ok(DType::U32),
st::Dtype::I64 => Ok(DType::I64),
st::Dtype::BF16 => Ok(DType::BF16),
st::Dtype::F16 => Ok(DType::F16),
st::Dtype::F32 => Ok(DType::F32),
@ -189,6 +191,7 @@ impl Tensor {
match dtype {
DType::U8 => convert_slice::<u8>(data, shape, device),
DType::U32 => convert_slice::<u32>(data, shape, device),
DType::I64 => convert_slice::<i64>(data, shape, device),
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
DType::F16 => convert_slice::<half::f16>(data, shape, device),
DType::F32 => convert_slice::<f32>(data, shape, device),
@ -205,24 +208,15 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_with_cast_::<u16, u32, _>(view, device, conv)
}
st::Dtype::U32 => convert_::<u32>(view, device),
st::Dtype::I32 => {
let conv = |x| Ok(i64::from(x));
convert_with_cast_::<i32, i64, _>(view, device, conv)
}
st::Dtype::I64 => convert_::<i64>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
st::Dtype::I32 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i32, u32, _>(view, device, conv)
}
st::Dtype::I64 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i64, u32, _>(view, device, conv)
}
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
@ -233,6 +227,7 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
match tensor.dtype() {
DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),