Simplify Tensor::randn. (#255)

* Simplify Tensor::randn.

* Also switch Tensor::rand to use a generic dtype.

* Support sampling for f16.

* Cleanup.
This commit is contained in:
Laurent Mazare
2023-07-27 07:40:36 +01:00
committed by GitHub
parent 89ba005962
commit 6475bfadfe
10 changed files with 111 additions and 72 deletions

View File

@ -5,6 +5,11 @@ use anyhow::Result;
use candle::{Device, Tensor};
fn main() -> Result<()> {
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
let c = a.matmul(&b)?;
println!("{a} {b} {c}");
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
let t1 = Tensor::new(data, &Device::Cpu)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];