mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add the fill kernel and use it for 'ones'.
This commit is contained in:
@ -82,10 +82,7 @@ impl Device {
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
// TODO: Instead of allocating memory on the host and transfering it,
|
||||
// allocate some zeros on the device and use a shader to set them to 1.
|
||||
let storage = CpuStorage::ones_impl(shape, dtype);
|
||||
let storage = device.cuda_from_cpu_storage(&storage)?;
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user