Fix reinforcement learning example (#2837)
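This change migrates the reinforcement-learning examples (DDPG, DQN, policy gradient) from the rand 0.8 API to rand 0.9: thread_rng() is renamed to rng(), Rng::gen to Rng::random, the distributions module to distr, and Uniform construction becomes fallible (try_from instead of From). A minimal sketch of the renames, separate from the diff below:

    use rand::Rng;

    fn seed() -> u64 {
        // rand 0.9: rand::rng() replaces the old rand::thread_rng(),
        // and random() replaces the old gen() method.
        let mut r = rand::rng();
        r.random::<u64>()
    }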
candle-examples/examples/reinforcement-learning/ddpg.rs

@@ -5,7 +5,7 @@ use candle_nn::{
     func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
     VarBuilder, VarMap,
 };
-use rand::{distributions::Uniform, thread_rng, Rng};
+use rand::{distr::Uniform, rng, Rng};

 use super::gym_env::GymEnv;

@@ -103,8 +103,8 @@ impl ReplayBuffer {
         if self.size < batch_size {
             Ok(None)
         } else {
-            let transitions: Vec<&Transition> = thread_rng()
-                .sample_iter(Uniform::from(0..self.size))
+            let transitions: Vec<&Transition> = rng()
+                .sample_iter(Uniform::try_from(0..self.size).map_err(Error::wrap)?)
                 .take(batch_size)
                 .map(|i| self.buffer.get(i).unwrap())
                 .collect();
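For reference, the new sampling pattern in isolation; a minimal sketch assuming rand 0.9, with hypothetical names len and count (not from the diff):

    use rand::{distr::Uniform, rng, Rng};

    // Draw `count` random indices in [0, len) with rand 0.9.
    // Uniform::try_from replaces the old Uniform::from and returns
    // an error instead of panicking on an empty range.
    fn sample_indices(len: usize, count: usize) -> Result<Vec<usize>, rand::distr::uniform::Error> {
        let dist = Uniform::try_from(0..len)?;
        Ok(rng().sample_iter(dist).take(count).collect())
    }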
@@ -498,11 +498,11 @@ pub fn run() -> Result<()> {
         OuNoise::new(MU, THETA, SIGMA, size_action)?,
     )?;

-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();

     for episode in 0..MAX_EPISODES {
         // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;

         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
@@ -538,7 +538,7 @@ pub fn run() -> Result<()> {
     agent.train = false;
     for episode in 0..10 {
         // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
             let mut action = 2.0 * agent.actions(&state)?;
candle-examples/examples/reinforcement-learning/dqn.rs

@@ -1,9 +1,8 @@
 use std::collections::VecDeque;

-use rand::distributions::Uniform;
-use rand::{thread_rng, Rng};
+use rand::{distr::Uniform, rng, Rng};

-use candle::{DType, Device, Module, Result, Tensor};
+use candle::{DType, Device, Error, Module, Result, Tensor};
 use candle_nn::loss::mse;
 use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};

@@ -65,8 +64,8 @@ pub fn run() -> Result<()> {
         // fed to the model so that it performs a backward pass.
         if memory.len() > BATCH_SIZE {
             // Sample randomly from the memory.
-            let batch = thread_rng()
-                .sample_iter(Uniform::from(0..memory.len()))
+            let batch = rng()
+                .sample_iter(Uniform::try_from(0..memory.len()).map_err(Error::wrap)?)
                 .take(BATCH_SIZE)
                 .map(|i| memory.get(i).unwrap().clone())
                 .collect::<Vec<_>>();
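The Error import added to the dqn.rs use list above is what makes the map_err(Error::wrap) calls compile: candle's Error::wrap converts an arbitrary std error, such as the one returned by Uniform::try_from, into candle's error type. A minimal sketch of the pattern under that assumption:

    use candle::{Error, Result};
    use rand::distr::Uniform;

    // Wrap rand's uniform-construction error into candle::Error so
    // the ? operator works in functions returning candle::Result.
    fn uniform_upto(n: usize) -> Result<Uniform<usize>> {
        Uniform::try_from(0..n).map_err(Error::wrap)
    }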
candle-examples/examples/reinforcement-learning/policy_gradient.rs

@@ -4,7 +4,7 @@ use candle_nn::{
     linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
     ParamsAdamW, VarBuilder, VarMap,
 };
-use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
+use rand::{distr::Distribution, rngs::ThreadRng, Rng};

 fn new_model(
     input_shape: &[usize],
@@ -39,7 +39,7 @@ fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
 }

 fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
-    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
+    let distribution = rand::distr::weighted::WeightedIndex::new(probs).map_err(Error::wrap)?;
     let mut rng = rng;
     Ok(distribution.sample(&mut rng))
 }
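WeightedIndex moved from rand::distributions to the rand::distr::weighted module in rand 0.9; the Distribution::sample call itself is unchanged. A standalone sketch with hypothetical weights:

    use rand::distr::{weighted::WeightedIndex, Distribution};

    // Pick an index with probability proportional to its weight;
    // returns None if the weights are invalid (e.g. all zero).
    fn pick(weights: &[f32]) -> Option<usize> {
        let dist = WeightedIndex::new(weights).ok()?;
        Some(dist.sample(&mut rand::rng()))
    }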
@@ -65,10 +65,10 @@ pub fn run() -> Result<()> {

     let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;

-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();

     for epoch_idx in 0..100 {
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
         let mut steps: Vec<Step<i64>> = vec![];

         loop {
@@ -84,7 +84,7 @@ pub fn run() -> Result<()> {
             steps.push(step.copy_with_obs(&state));

             if step.terminated || step.truncated {
-                state = env.reset(rng.gen::<u64>())?;
+                state = env.reset(rng.random::<u64>())?;
                 if steps.len() > 5000 {
                     break;
                 }