Fix reinforcement learning example (#2837)

This commit is contained in:
LongYinan
2025-03-26 08:27:45 -07:00
committed by GitHub
parent 0d4097031c
commit cb02b389d5
3 changed files with 15 additions and 16 deletions

View File

@ -5,7 +5,7 @@ use candle_nn::{
func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};
use rand::{distr::Uniform, rng, Rng};
use super::gym_env::GymEnv;
@ -103,8 +103,8 @@ impl ReplayBuffer {
if self.size < batch_size {
Ok(None)
} else {
let transitions: Vec<&Transition> = thread_rng()
.sample_iter(Uniform::from(0..self.size))
let transitions: Vec<&Transition> = rng()
.sample_iter(Uniform::try_from(0..self.size).map_err(Error::wrap)?)
.take(batch_size)
.map(|i| self.buffer.get(i).unwrap())
.collect();
@ -498,11 +498,11 @@ pub fn run() -> Result<()> {
OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut state = env.reset(rng.random::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
@ -538,7 +538,7 @@ pub fn run() -> Result<()> {
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut state = env.reset(rng.random::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;