Further randn tweaks: use the appropriate rng rather than the f64 one, some cleanup. (#383)

This commit is contained in:
Laurent Mazare
2023-08-10 06:48:19 +02:00
committed by GitHub
parent 3bbc08a8df
commit c7f92f985e
2 changed files with 21 additions and 37 deletions

View File

@ -9,23 +9,6 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}
fn randn_hasneg(device: &Device) -> Result<()> {
let s = 200;
let t = Tensor::randn(
0f32,
1f32, s
as usize,
&Device::Cpu
)?
.to_vec1::<f32>()?;
for i in t {
if i < 0. {
return Ok(())
}
}
panic!("randn failed to generate a negative number")
}
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
@ -866,7 +849,6 @@ fn broadcasting(device: &Device) -> Result<()> {
}
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(randn_hasneg, randn_hasneg_cpu, randn_hasneg_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
@ -887,3 +869,14 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
#[test]
fn randn_hasneg() -> Result<()> {
let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::<f32>()?;
if t.iter().all(|&v| v >= 0.) {
candle_core::bail!("all values in tensors are non-negative")
}
Ok(())
}