mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add support for i64 (#563)
* Add the i64 dtype. * Adapt the cuda kernels.
This commit is contained in:
@ -139,6 +139,14 @@ impl CudaDevice {
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
||||
let params = (&data, v as i64, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||
@ -236,6 +244,10 @@ impl BackendDevice for CudaDevice {
|
||||
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
@ -265,11 +277,13 @@ impl BackendDevice for CudaDevice {
|
||||
let slice = match dtype {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
})
|
||||
.w()?,
|
||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
})
|
||||
.w()?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
curand.0.fill_with_uniform(&mut data).w()?;
|
||||
@ -297,11 +311,13 @@ impl BackendDevice for CudaDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_normal",
|
||||
})
|
||||
.w()?,
|
||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_normal",
|
||||
})
|
||||
.w()?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
curand
|
||||
@ -336,6 +352,10 @@ impl BackendDevice for CudaDevice {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
@ -364,6 +384,7 @@ impl BackendDevice for CudaDevice {
|
||||
enum CudaStorageSlice {
|
||||
U8(CudaSlice<u8>),
|
||||
U32(CudaSlice<u32>),
|
||||
I64(CudaSlice<i64>),
|
||||
BF16(CudaSlice<bf16>),
|
||||
F16(CudaSlice<f16>),
|
||||
F32(CudaSlice<f32>),
|
||||
@ -383,6 +404,7 @@ trait Map1 {
|
||||
let out = match s {
|
||||
S::U8(s) => S::U8(self.f(s, d, l)?),
|
||||
S::U32(s) => S::U32(self.f(s, d, l)?),
|
||||
S::I64(s) => S::I64(self.f(s, d, l)?),
|
||||
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
||||
S::F16(s) => S::F16(self.f(s, d, l)?),
|
||||
S::F32(s) => S::F32(self.f(s, d, l)?),
|
||||
@ -406,6 +428,7 @@ trait Map2 {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||
@ -437,6 +460,7 @@ trait Map2InPlace {
|
||||
match (dst, src) {
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
@ -459,6 +483,7 @@ trait Map1Any {
|
||||
let out = match s {
|
||||
S::U8(s) => self.f(s, d, l, S::U8)?,
|
||||
S::U32(s) => self.f(s, d, l, S::U32)?,
|
||||
S::I64(s) => self.f(s, d, l, S::I64)?,
|
||||
S::BF16(s) => self.f(s, d, l, S::BF16)?,
|
||||
S::F16(s) => self.f(s, d, l, S::F16)?,
|
||||
S::F32(s) => self.f(s, d, l, S::F32)?,
|
||||
@ -482,6 +507,7 @@ trait Map2Any {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
@ -714,6 +740,9 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
CudaStorageSlice::U8(slice) => {
|
||||
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::I64(slice) => {
|
||||
("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "index_select ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
@ -773,8 +802,11 @@ impl<'a> Map1 for Gather<'a> {
|
||||
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::I64(slice) => {
|
||||
("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "gather ids should be u8 or u32",
|
||||
msg: "gather ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
@ -820,9 +852,10 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
|
||||
};
|
||||
let (name, ids) = match &ids.slice {
|
||||
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "index-add ids should be u8 or u32",
|
||||
msg: "index-add ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
@ -867,9 +900,10 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
|
||||
};
|
||||
let (name, ids) = match &ids.slice {
|
||||
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "scatter-add ids should be u8 or u32",
|
||||
msg: "scatter-add ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
@ -1080,8 +1114,12 @@ impl<'a> Map2 for WhereCond<'a> {
|
||||
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
|
||||
(ptr, "where_u32")
|
||||
}
|
||||
CudaStorageSlice::I64(slice) => {
|
||||
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
|
||||
(ptr, "where_i64")
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "where conditions should be u8 or u32",
|
||||
msg: "where conditions should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: self.0.dtype(),
|
||||
})
|
||||
@ -1225,6 +1263,7 @@ macro_rules! cuda_dtype {
|
||||
}
|
||||
cuda_dtype!(u8, U8);
|
||||
cuda_dtype!(u32, U32);
|
||||
cuda_dtype!(i64, I64);
|
||||
cuda_dtype!(f16, F16);
|
||||
cuda_dtype!(bf16, BF16);
|
||||
cuda_dtype!(f32, F32);
|
||||
@ -1338,6 +1377,7 @@ impl BackendStorage for CudaStorage {
|
||||
match self.slice {
|
||||
CudaStorageSlice::U8(_) => DType::U8,
|
||||
CudaStorageSlice::U32(_) => DType::U32,
|
||||
CudaStorageSlice::I64(_) => DType::I64,
|
||||
CudaStorageSlice::BF16(_) => DType::BF16,
|
||||
CudaStorageSlice::F16(_) => DType::F16,
|
||||
CudaStorageSlice::F32(_) => DType::F32,
|
||||
@ -1363,6 +1403,7 @@ impl BackendStorage for CudaStorage {
|
||||
let inp = match &self.slice {
|
||||
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
@ -1385,6 +1426,12 @@ impl BackendStorage for CudaStorage {
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
DType::I64 => {
|
||||
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::I64(out)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
@ -1469,6 +1516,11 @@ impl BackendStorage for CudaStorage {
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
||||
Ok(CpuStorage::U32(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::I64(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
||||
Ok(CpuStorage::I64(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::BF16(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
||||
@ -1588,6 +1640,7 @@ impl BackendStorage for CudaStorage {
|
||||
S::F64(out)
|
||||
}
|
||||
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
|
||||
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
|
||||
};
|
||||
Ok(Self { slice, device })
|
||||
@ -1802,6 +1855,18 @@ impl BackendStorage for CudaStorage {
|
||||
unsafe { func.launch(cfg, params) }.w()?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst).w()?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
|
Reference in New Issue
Block a user