Fix reinforcement learning example (#2837)

Author:       LongYinan
Date:         2025-03-26 08:27:45 -07:00
Committed by: GitHub
Parent:       0d4097031c
Commit:       cb02b389d5
3 changed files with 15 additions and 16 deletions
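The changes below track the rand 0.9 API: rand::thread_rng() becomes rand::rng(), Rng::gen becomes Rng::random, the distributions module is renamed to distr, and Uniform construction is now fallible. As a rough, self-contained sketch of the new calls (not part of this commit; assumes only rand 0.9):

// Standalone illustration of the rand 0.9 calls used in this commit (assumes rand = "0.9").
use rand::{distr::Uniform, rng, Rng};

fn main() {
    // rand::rng() replaces rand::thread_rng()
    let mut rng = rng();

    // Rng::random::<T>() replaces Rng::gen::<T>()
    let seed: u64 = rng.random();

    // Uniform construction is fallible in rand 0.9: try_from replaces From
    let dist = Uniform::try_from(0..10usize).expect("non-empty range");
    let samples: Vec<usize> = (&mut rng).sample_iter(dist).take(4).collect();

    println!("seed = {seed}, samples = {samples:?}");
}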

View File

@@ -5,7 +5,7 @@ use candle_nn::{
     func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
     VarBuilder, VarMap,
 };
-use rand::{distributions::Uniform, thread_rng, Rng};
+use rand::{distr::Uniform, rng, Rng};
 
 use super::gym_env::GymEnv;
@@ -103,8 +103,8 @@ impl ReplayBuffer {
         if self.size < batch_size {
             Ok(None)
         } else {
-            let transitions: Vec<&Transition> = thread_rng()
-                .sample_iter(Uniform::from(0..self.size))
+            let transitions: Vec<&Transition> = rng()
+                .sample_iter(Uniform::try_from(0..self.size).map_err(Error::wrap)?)
                 .take(batch_size)
                 .map(|i| self.buffer.get(i).unwrap())
                 .collect();
@@ -498,11 +498,11 @@ pub fn run() -> Result<()> {
         OuNoise::new(MU, THETA, SIGMA, size_action)?,
     )?;
 
-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();
 
     for episode in 0..MAX_EPISODES {
         // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
 
         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
@@ -538,7 +538,7 @@ pub fn run() -> Result<()> {
     agent.train = false;
     for episode in 0..10 {
         // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
             let mut action = 2.0 * agent.actions(&state)?;
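Since Uniform::try_from returns a Result in rand 0.9, the sampling code above maps the construction error into candle's error type before applying ?, which is also why the next file gains an Error import. A rough sketch of that pattern, using a hypothetical sample_indices helper (not part of this commit):

// Hypothetical helper mirroring the fallible-Uniform pattern used in this diff.
// Assumes candle's Error::wrap converts a std error, as the changed lines do.
use candle::{Error, Result};
use rand::{distr::Uniform, rng, Rng};

fn sample_indices(len: usize, count: usize) -> Result<Vec<usize>> {
    // Uniform::try_from can fail (e.g. on an empty range), so wrap the error for ?.
    let dist = Uniform::try_from(0..len).map_err(Error::wrap)?;
    Ok(rng().sample_iter(dist).take(count).collect())
}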

View File

@@ -1,9 +1,8 @@
 use std::collections::VecDeque;
 
-use rand::distributions::Uniform;
-use rand::{thread_rng, Rng};
+use rand::{distr::Uniform, rng, Rng};
 
-use candle::{DType, Device, Module, Result, Tensor};
+use candle::{DType, Device, Error, Module, Result, Tensor};
 use candle_nn::loss::mse;
 use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};
@@ -65,8 +64,8 @@ pub fn run() -> Result<()> {
         // fed to the model so that it performs a backward pass.
         if memory.len() > BATCH_SIZE {
             // Sample randomly from the memory.
-            let batch = thread_rng()
-                .sample_iter(Uniform::from(0..memory.len()))
+            let batch = rng()
+                .sample_iter(Uniform::try_from(0..memory.len()).map_err(Error::wrap)?)
                 .take(BATCH_SIZE)
                 .map(|i| memory.get(i).unwrap().clone())
                 .collect::<Vec<_>>();

View File

@@ -4,7 +4,7 @@ use candle_nn::{
     linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
     ParamsAdamW, VarBuilder, VarMap,
 };
-use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
+use rand::{distr::Distribution, rngs::ThreadRng, Rng};
 
 fn new_model(
     input_shape: &[usize],
@@ -39,7 +39,7 @@ fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
 }
 
 fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
-    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
+    let distribution = rand::distr::weighted::WeightedIndex::new(probs).map_err(Error::wrap)?;
     let mut rng = rng;
     Ok(distribution.sample(&mut rng))
 }
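In rand 0.9, WeightedIndex lives under rand::distr::weighted rather than rand::distributions, and sampling still goes through the Distribution trait as in the hunk above. A minimal standalone sketch of that usage (not from this commit; assumes only rand 0.9):

// Weighted sampling with rand 0.9 (assumes rand = "0.9").
use rand::distr::{weighted::WeightedIndex, Distribution};
use rand::rng;

fn main() {
    let probs = [0.1f32, 0.2, 0.7];
    // WeightedIndex::new is fallible (e.g. all-zero or negative weights).
    let dist = WeightedIndex::new(probs).expect("valid weights");
    let mut rng = rng();
    // Draw an index into probs, proportionally to the weights.
    let idx = dist.sample(&mut rng);
    println!("sampled index: {idx}");
}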
@@ -65,10 +65,10 @@ pub fn run() -> Result<()> {
     let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
 
-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();
 
     for epoch_idx in 0..100 {
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
         let mut steps: Vec<Step<i64>> = vec![];
         loop {
@@ -84,7 +84,7 @@ pub fn run() -> Result<()> {
             steps.push(step.copy_with_obs(&state));
 
             if step.terminated || step.truncated {
-                state = env.reset(rng.gen::<u64>())?;
+                state = env.reset(rng.random::<u64>())?;
                 if steps.len() > 5000 {
                     break;
                 }