Simplify Tensor::randn. (#255)

* Simplify Tensor::randn.

* Also switch Tensor::rand to use a generic dtype.

* Support sampling for f16.

* Cleanup.
This commit is contained in:
Laurent Mazare
2023-07-27 07:40:36 +01:00
committed by GitHub
parent 89ba005962
commit 6475bfadfe
10 changed files with 111 additions and 72 deletions

View File

@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice {
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {