Metal random-generation bug fixes (#1811)

* The use_resource API was misunderstood: calls are not additive, so several usages of one resource must be bit-ORed together in a single call.

* The seeding was incorrect: it used the address of the passed-in seed instead of its value.

* Add a check that is likely to expose a failure to update the seed between the generation of successive random tensors.

* Fix a buffer overrun: the length given to the std::ptr::copy call was in bytes rather than in 32-bit units.

* By default seed the RNG with a time-based value, so that different runs may produce different output, just like the CPU engine.
Use device.set_seed if determinism is warranted.

* Revert "By default seed the RNG with a time-based value, so that different runs may produce different output, just like the CPU engine. Use device.set_seed if determinism is warranted."

This reverts commit d7302de9

Discussion in https://github.com/huggingface/candle/pull/1811#issuecomment-1983079119

* The Metal random kernel failed to set element N/2 of tensors with N elements when N is even. The reason was that every thread except thread 0 created 2 random samples, while thread 0 created only one, giving an odd total. To produce an even number of samples, the early termination of thread 0 should only ever occur for odd-sized tensors.

* Add a test catching any deterministic tensor element in rand and randn output.
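
For illustration only, the std::ptr::copy overrun described above can be reproduced in isolation. This is a minimal sketch, not the code from this commit: `write_seed`, the guard words and the seed value 42 are hypothetical. The one fact it relies on is that the third argument of std::ptr::copy counts elements of the pointee type, not bytes.

```rust
use std::mem::size_of;

/// Hypothetical stand-in for writing a seed into a device buffer.
/// `slot` plays the role of the buffer contents; `src` carries extra
/// guard words so that an over-long copy stays within bounds here and
/// is visible in the result instead of being undefined behaviour.
fn write_seed(slot: &mut [u32], seed: u32, count: usize) {
    let src = [seed, 0xAAAA_AAAA, 0xBBBB_BBBB, 0xCCCC_CCCC];
    assert!(count <= src.len() && count <= slot.len());
    // `count` is a number of u32 elements, so the call moves
    // count * size_of::<u32>() bytes.
    unsafe { std::ptr::copy(src.as_ptr(), slot.as_mut_ptr(), count) };
}

fn main() {
    let seed: u32 = 42;

    // Correct: one 32-bit unit is copied, the rest of the buffer is untouched.
    let mut correct = [0u32; 4];
    write_seed(&mut correct, seed, 1);
    assert_eq!(correct, [seed, 0, 0, 0]);

    // Mistake: the byte length (4) is passed as the element count,
    // so four u32 values (16 bytes) are copied instead of one.
    let mut overrun = [0u32; 4];
    write_seed(&mut overrun, seed, size_of::<u32>());
    assert_eq!(overrun, [seed, 0xAAAA_AAAA, 0xBBBB_BBBB, 0xCCCC_CCCC]);
}
```

In a real seed buffer that holds a single u32 there are no guard words, so the over-long copy would read and write past both allocations.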

---------

Co-authored-by: niklas <niklas@appli.se>
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Niklas Hallqvist
2024-03-08 16:11:50 +01:00
committed by GitHub
parent ea984d0421
commit be5b68cd0b
4 changed files with 50 additions and 13 deletions

@@ -1558,8 +1558,10 @@ pub fn call_random_uniform(
 set_params!(encoder, (length, min, max, seed, buffer));
-encoder.use_resource(seed, metal::MTLResourceUsage::Read);
-encoder.use_resource(seed, metal::MTLResourceUsage::Write);
+encoder.use_resource(
+    seed,
+    metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write,
+);
 encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
 encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
 encoder.end_encoding();
@@ -1589,8 +1591,10 @@ pub fn call_random_normal(
 set_params!(encoder, (length, mean, stddev, seed, buffer));
-encoder.use_resource(seed, metal::MTLResourceUsage::Read);
-encoder.use_resource(seed, metal::MTLResourceUsage::Write);
+encoder.use_resource(
+    seed,
+    metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write,
+);
 encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
 encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
 encoder.end_encoding();