Fix reinforcement learning example (#2837)

This commit is contained in:
LongYinan
2025-03-26 08:27:45 -07:00
committed by GitHub
parent 0d4097031c
commit cb02b389d5
3 changed files with 15 additions and 16 deletions

View File

@ -1,9 +1,8 @@
use std::collections::VecDeque;
use rand::distributions::Uniform;
use rand::{thread_rng, Rng};
use rand::{distr::Uniform, rng, Rng};
use candle::{DType, Device, Module, Result, Tensor};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::loss::mse;
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};
@ -65,8 +64,8 @@ pub fn run() -> Result<()> {
// fed to the model so that it performs a backward pass.
if memory.len() > BATCH_SIZE {
// Sample randomly from the memory.
let batch = thread_rng()
.sample_iter(Uniform::from(0..memory.len()))
let batch = rng()
.sample_iter(Uniform::try_from(0..memory.len()).map_err(Error::wrap)?)
.take(BATCH_SIZE)
.map(|i| memory.get(i).unwrap().clone())
.collect::<Vec<_>>();