Add the fill kernel and use it for 'ones'.

This commit is contained in:
laurent
2023-06-22 08:33:32 +01:00
parent fc26bab3ed
commit 0a758ffa05
3 changed files with 56 additions and 4 deletions

View File

@ -82,10 +82,7 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Instead of allocating memory on the host and transferring it,
// allocate some zeros on the device and use a shader to set them to 1.
let storage = CpuStorage::ones_impl(shape, dtype);
let storage = device.cuda_from_cpu_storage(&storage)?;
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
}