From cb02b389d53a1cf5547dfa69b5168bdc1a50d325 Mon Sep 17 00:00:00 2001
From: LongYinan
Date: Wed, 26 Mar 2025 08:27:45 -0700
Subject: [PATCH] Fix reinforcement learning example (#2837)

---
 .../examples/reinforcement-learning/ddpg.rs   | 12 ++++++------
 .../examples/reinforcement-learning/dqn.rs    |  9 ++++-----
 .../reinforcement-learning/policy_gradient.rs | 10 +++++-----
 3 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
index 389caac1..541dc796 100644
--- a/candle-examples/examples/reinforcement-learning/ddpg.rs
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -5,7 +5,7 @@ use candle_nn::{
     func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
     VarBuilder, VarMap,
 };
-use rand::{distributions::Uniform, thread_rng, Rng};
+use rand::{distr::Uniform, rng, Rng};
 
 use super::gym_env::GymEnv;
 
@@ -103,8 +103,8 @@ impl ReplayBuffer {
         if self.size < batch_size {
             Ok(None)
         } else {
-            let transitions: Vec<&Transition> = thread_rng()
-                .sample_iter(Uniform::from(0..self.size))
+            let transitions: Vec<&Transition> = rng()
+                .sample_iter(Uniform::try_from(0..self.size).map_err(Error::wrap)?)
                 .take(batch_size)
                 .map(|i| self.buffer.get(i).unwrap())
                 .collect();
@@ -498,11 +498,11 @@ pub fn run() -> Result<()> {
         OuNoise::new(MU, THETA, SIGMA, size_action)?,
     )?;
 
-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();
 
     for episode in 0..MAX_EPISODES {
         // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
 
         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
@@ -538,7 +538,7 @@ pub fn run() -> Result<()> {
     agent.train = false;
     for episode in 0..10 {
         // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
             let mut action = 2.0 * agent.actions(&state)?;
diff --git a/candle-examples/examples/reinforcement-learning/dqn.rs b/candle-examples/examples/reinforcement-learning/dqn.rs
index 83457810..f08e84b0 100644
--- a/candle-examples/examples/reinforcement-learning/dqn.rs
+++ b/candle-examples/examples/reinforcement-learning/dqn.rs
@@ -1,9 +1,8 @@
 use std::collections::VecDeque;
 
-use rand::distributions::Uniform;
-use rand::{thread_rng, Rng};
+use rand::{distr::Uniform, rng, Rng};
 
-use candle::{DType, Device, Module, Result, Tensor};
+use candle::{DType, Device, Error, Module, Result, Tensor};
 use candle_nn::loss::mse;
 use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};
 
@@ -65,8 +64,8 @@ pub fn run() -> Result<()> {
         // fed to the model so that it performs a backward pass.
         if memory.len() > BATCH_SIZE {
             // Sample randomly from the memory.
-            let batch = thread_rng()
-                .sample_iter(Uniform::from(0..memory.len()))
+            let batch = rng()
+                .sample_iter(Uniform::try_from(0..memory.len()).map_err(Error::wrap)?)
                 .take(BATCH_SIZE)
                 .map(|i| memory.get(i).unwrap().clone())
                 .collect::<Vec<_>>();
diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
index 3ae2617d..8f797358 100644
--- a/candle-examples/examples/reinforcement-learning/policy_gradient.rs
+++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
@@ -4,7 +4,7 @@ use candle_nn::{
     linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
     ParamsAdamW, VarBuilder, VarMap,
 };
-use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
+use rand::{distr::Distribution, rngs::ThreadRng, Rng};
 
 fn new_model(
     input_shape: &[usize],
@@ -39,7 +39,7 @@ fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
 }
 
 fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
-    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
+    let distribution = rand::distr::weighted::WeightedIndex::new(probs).map_err(Error::wrap)?;
     let mut rng = rng;
     Ok(distribution.sample(&mut rng))
 }
@@ -65,10 +65,10 @@ pub fn run() -> Result<()> {
 
     let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
 
-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();
 
     for epoch_idx in 0..100 {
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
         let mut steps: Vec<Step<i64>> = vec![];
 
         loop {
@@ -84,7 +84,7 @@ pub fn run() -> Result<()> {
             steps.push(step.copy_with_obs(&state));
 
             if step.terminated || step.truncated {
-                state = env.reset(rng.gen::<u64>())?;
+                state = env.reset(rng.random::<u64>())?;
                 if steps.len() > 5000 {
                     break;
                 }
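
Note (not part of the patch): the edits above track the rand 0.9 API renames. A minimal,
self-contained sketch of the calls this patch switches to, assuming a `rand = "0.9"`
dependency, is:

    use rand::distr::weighted::WeightedIndex;
    use rand::distr::{Distribution, Uniform};
    use rand::Rng;

    fn main() {
        // `rand::rng()` replaces the old `rand::thread_rng()`.
        let mut rng = rand::rng();

        // `Rng::random` replaces `Rng::gen`.
        let seed = rng.random::<u64>();

        // Building a `Uniform` is now fallible: `try_from` returns a `Result`
        // where the old `Uniform::from` panicked on an invalid range.
        let uniform = Uniform::try_from(0..10).expect("valid range");
        let index = uniform.sample(&mut rng);

        // `WeightedIndex` now lives under `rand::distr::weighted`.
        let weights = WeightedIndex::new([0.1f64, 0.7, 0.2]).expect("valid weights");
        let choice = weights.sample(&mut rng);

        println!("seed={seed} index={index} choice={choice}");
    }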