mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the fill kernel and use it for 'ones'.
This commit is contained in:
@ -23,6 +23,18 @@ extern "C" __global__ void affine_f32(
|
|||||||
}
|
}
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
|
// CUDA source for the `fill_*` kernels, compiled at runtime with NVRTC.
//
// Each `fill_<dtype>` entry point writes `value` into the first `numel` elements of `buf`
// using a grid-stride loop, so any launch configuration with at least one thread covers
// the whole buffer.
//
// The loop index and the stride are computed in `size_t`: `numel` is a `size_t`, and a
// 32-bit index would wrap around for buffers holding more than `UINT_MAX` elements, which
// would make the loop condition permanently true.
//
// NOTE(review): `__half` is used without a visible `#include` — confirm that NVRTC resolves
// it in the configuration used here.
const FILL_CU: &str = r#"
template<typename T>
__device__ void fill_with(T *buf, T value, const size_t numel) {
    // Widen before multiplying so neither the start index nor the stride can overflow.
    const size_t stride = (size_t)blockDim.x * (size_t)gridDim.x;
    for (size_t i = (size_t)blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < numel; i += stride) {
        buf[i] = value;
    }
}
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
"#;
|
||||||
|
|
||||||
impl CudaDevice {
|
impl CudaDevice {
|
||||||
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||||
@ -47,6 +59,45 @@ impl CudaDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let dev = &self.0;
|
||||||
|
match dtype {
|
||||||
|
DType::F32 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||||
|
let module_name = "fill_f32";
|
||||||
|
if !dev.has_func(module_name, module_name) {
|
||||||
|
let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
|
||||||
|
dev.load_ptx(ptx, module_name, &[module_name])?;
|
||||||
|
}
|
||||||
|
let fwd_fn = dev.get_func(module_name, module_name).unwrap();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
|
let params = (&data, v as f32, elem_count);
|
||||||
|
unsafe { fwd_fn.launch(cfg, params) }?;
|
||||||
|
Ok(CudaStorage::F32(data))
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { dev.alloc::<f64>(elem_count) }?;
|
||||||
|
let module_name = "fill_f64";
|
||||||
|
if !dev.has_func(module_name, module_name) {
|
||||||
|
let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
|
||||||
|
dev.load_ptx(ptx, module_name, &[module_name])?;
|
||||||
|
}
|
||||||
|
let fwd_fn = dev.get_func(module_name, module_name).unwrap();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
|
let params = (&data, v, elem_count);
|
||||||
|
unsafe { fwd_fn.launch(cfg, params) }?;
|
||||||
|
Ok(CudaStorage::F64(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a device buffer of the given `shape` and `dtype` by delegating to
/// [`Self::const_impl`] with the constant `1.0`.
///
/// # Errors
/// Returns whatever error `const_impl` reports.
pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
    const ONE: f64 = 1.0;
    self.const_impl(ONE, shape, dtype)
}
|
||||||
|
|
||||||
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
match storage {
|
match storage {
|
||||||
CpuStorage::F32(storage) => {
|
CpuStorage::F32(storage) => {
|
||||||
|
@ -82,10 +82,7 @@ impl Device {
|
|||||||
Ok(Storage::Cpu(storage))
|
Ok(Storage::Cpu(storage))
|
||||||
}
|
}
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
// TODO: Instead of allocating memory on the host and transfering it,
|
let storage = device.ones_impl(shape, dtype)?;
|
||||||
// allocate some zeros on the device and use a shader to set them to 1.
|
|
||||||
let storage = CpuStorage::ones_impl(shape, dtype);
|
|
||||||
let storage = device.cuda_from_cpu_storage(&storage)?;
|
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,10 @@ impl CudaDevice {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
|
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user