Add Policy Gradient to Reinforcement Learning examples (#1500)
* added policy_gradient, modified main, ddpg and README
* fixed typo in README
* removed unnecessary imports
* small refactor
* Use clap for picking up the subcommand to run.

Co-authored-by: Laurent <laurent.mazare@gmail.com>
README — `@@ -8,9 +8,16 @@ Python package with:`

```bash
pip install "gymnasium[accept-rom-license]"
```

In order to run the examples, use the following commands. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.

For the Policy Gradient example:
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
```

For the Deep Deterministic Policy Gradient example:
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg
```
ddpg.rs — `@@ -8,6 +8,8 @@ use candle_nn::{`: the module gains a `use super::gym_env::GymEnv;` import, since the training loop now lives here.

ddpg.rs — `@@ -449,3 +451,106 @@ impl DDPG<'_> {`: the hyper-parameters and the training/testing loop move into this file as `ddpg::run`:

```rust
// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;

// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;

const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;

pub fn run() -> Result<()> {
    let env = GymEnv::new("Pendulum-v1")?;
    println!("action space: {}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let size_state = env.observation_space().iter().product::<usize>();
    let size_action = env.action_space();

    let mut agent = DDPG::new(
        &Device::Cpu,
        size_state,
        size_action,
        true,
        ACTOR_LEARNING_RATE,
        CRITIC_LEARNING_RATE,
        GAMMA,
        TAU,
        REPLAY_BUFFER_CAPACITY,
        OuNoise::new(MU, THETA, SIGMA, size_action)?,
    )?;

    let mut rng = rand::thread_rng();

    for episode in 0..MAX_EPISODES {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;

        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            agent.remember(
                &state,
                &Tensor::new(vec![action], &Device::Cpu)?,
                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
                &step.state,
                step.terminated,
                step.truncated,
            );

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }

        println!("episode {episode} with total reward of {total_reward}");

        for _ in 0..TRAINING_ITERATIONS {
            agent.train(TRAINING_BATCH_SIZE)?;
        }
    }

    println!("Testing...");
    agent.train = false;
    for episode in 0..10 {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }
        println!("episode {episode} with total reward of {total_reward}");
    }
    Ok(())
}
```
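The constants above only parameterize the update performed by `DDPG::train`, which this commit does not modify. As a point of reference, here is a minimal, self-contained sketch (not the crate's implementation) of the two standard places GAMMA and TAU enter a DDPG-style update:

```rust
// Illustrative sketch only -- the real update lives in `DDPG::train`,
// which this commit does not touch.
const GAMMA: f64 = 0.99; // discount on the next state's Q value
const TAU: f64 = 0.005; // soft-update weight for the target networks

/// Bellman target for the critic: the reward plus the discounted Q value of
/// the next state, masked out when the episode terminated.
fn critic_target(reward: f64, next_q: f64, terminated: bool) -> f64 {
    reward + if terminated { 0.0 } else { GAMMA * next_q }
}

/// Soft update of a single target-network weight: move it a TAU-sized step
/// towards the corresponding online-network weight.
fn soft_update(target_w: f64, online_w: f64) -> f64 {
    TAU * online_w + (1.0 - TAU) * target_w
}

fn main() {
    println!("critic target: {}", critic_target(-1.0, -10.0, false)); // -10.9
    println!("soft update:   {}", soft_update(0.0, 1.0)); // 0.005
}
```

With TAU = 0.005 the target networks track the online networks slowly, which is what keeps the bootstrapped critic target stable.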
main.rs — `@@ -6,139 +6,32 @@ extern crate intel_mkl_src;`: the DDPG constants, the `cpu`/`tracing` flags (together with the Chrome tracing setup), and the whole training/testing loop are removed from `main.rs`; that logic now lives in `ddpg::run` above, and `main` only dispatches on a clap subcommand:

```rust
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::Result;
use clap::{Parser, Subcommand};

mod gym_env;
mod vec_gym_env;

mod ddpg;
mod policy_gradient;

#[derive(Parser)]
struct Args {
    #[command(subcommand)]
    command: Command,
}

#[derive(Subcommand)]
enum Command {
    Pg,
    Ddpg,
}

fn main() -> Result<()> {
    let args = Args::parse();
    match args.command {
        Command::Pg => policy_gradient::run()?,
        Command::Ddpg => ddpg::run()?,
    }
    Ok(())
}
```
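For reference, clap's derive lower-cases the variant names, which is why the README invokes the examples with `-- pg` and `-- ddpg`. A minimal standalone sketch of that mapping (the extra derives and the argv values are illustrative, not part of the commit):

```rust
use clap::{Parser, Subcommand};

#[derive(Parser)]
struct Args {
    #[command(subcommand)]
    command: Command,
}

#[derive(Subcommand, Debug, PartialEq)]
enum Command {
    Pg,
    Ddpg,
}

fn main() {
    // Everything after `--` on the cargo command line is forwarded to the
    // example binary, so `-- pg` ends up parsed as the `Pg` subcommand.
    let args = Args::parse_from(["reinforcement-learning", "pg"]);
    assert_eq!(args.command, Command::Pg);
}
```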
policy_gradient.rs (new file) — `@@ -0,0 +1,146 @@`: the policy gradient example, trained on CartPole-v1:

```rust
use super::gym_env::{GymEnv, Step};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::{
    linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
    ParamsAdamW, VarBuilder, VarMap,
};
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};

fn new_model(
    input_shape: &[usize],
    num_actions: usize,
    dtype: DType,
    device: &Device,
) -> Result<(impl Module, VarMap)> {
    let input_size = input_shape.iter().product();

    let mut varmap = VarMap::new();
    let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);

    let model = seq()
        .add(linear(input_size, 32, var_builder.pp("lin1"))?)
        .add(Activation::Relu)
        .add(linear(32, num_actions, var_builder.pp("lin2"))?);

    Ok((model, varmap))
}

fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
    let mut rewards: Vec<f64> = steps.iter().map(|s| s.reward).collect();
    let mut acc_reward = 0f64;
    for (i, reward) in rewards.iter_mut().enumerate().rev() {
        if steps[i].terminated {
            acc_reward = 0.0;
        }
        acc_reward += *reward;
        *reward = acc_reward;
    }
    rewards
}

fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
    let mut rng = rng;
    Ok(distribution.sample(&mut rng))
}

pub fn run() -> Result<()> {
    let env = GymEnv::new("CartPole-v1")?;

    println!("action space: {:?}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let (model, varmap) = new_model(
        env.observation_space(),
        env.action_space(),
        DType::F32,
        &Device::Cpu,
    )?;

    let optimizer_params = ParamsAdamW {
        lr: 0.01,
        weight_decay: 0.01,
        ..Default::default()
    };

    let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;

    let mut rng = rand::thread_rng();

    for epoch_idx in 0..100 {
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut steps: Vec<Step<i64>> = vec![];

        loop {
            let action = {
                let action_probs: Vec<f32> =
                    softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
                        .squeeze(0)?
                        .to_vec1()?;
                weighted_sample(action_probs, &mut rng)? as i64
            };

            let step = env.step(action)?;
            steps.push(step.copy_with_obs(&state));

            if step.terminated || step.truncated {
                state = env.reset(rng.gen::<u64>())?;
                if steps.len() > 5000 {
                    break;
                }
            } else {
                state = step.state;
            }
        }

        let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
        let episodes: i64 = steps
            .iter()
            .map(|s| (s.terminated || s.truncated) as i64)
            .sum();
        println!(
            "epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}",
            epoch_idx,
            episodes,
            total_reward / episodes as f64
        );

        let batch_size = steps.len();

        let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
            .to_dtype(DType::F32)?
            .detach()?;

        let actions_mask = {
            let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
            let actions_mask: Vec<Tensor> = actions
                .iter()
                .map(|&action| {
                    // One-hot encoding
                    let mut action_mask = vec![0.0; env.action_space()];
                    action_mask[action as usize] = 1.0;

                    Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu)
                        .unwrap()
                        .to_dtype(DType::F32)
                        .unwrap()
                })
                .collect();
            Tensor::stack(&actions_mask, 0)?.detach()?
        };

        let states = {
            let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
            Tensor::stack(&states, 0)?.detach()?
        };

        let log_probs = actions_mask
            .mul(&log_softmax(&model.forward(&states)?, 1)?)?
            .sum(1)?;

        let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?;
        optimizer.backward_step(&loss)?;
    }

    Ok(())
}
```
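The gradient step at the end of `run` is the classic REINFORCE estimator: the loss is `-(R * log π(a|s)).mean()`, where `R` comes from `accumulate_rewards`. That helper computes the undiscounted reward-to-go, resetting the accumulator at terminated steps. A standalone sketch of the same computation on plain `(reward, terminated)` pairs (illustrative only; it does not use the example's `Step` type):

```rust
// Same backward accumulation as `accumulate_rewards`, on plain tuples.
fn reward_to_go(steps: &[(f64, bool)]) -> Vec<f64> {
    let mut rewards: Vec<f64> = steps.iter().map(|s| s.0).collect();
    let mut acc = 0.0;
    for (i, r) in rewards.iter_mut().enumerate().rev() {
        if steps[i].1 {
            // A terminated step starts a fresh accumulation for its episode.
            acc = 0.0;
        }
        acc += *r;
        *r = acc;
    }
    rewards
}

fn main() {
    // Two episodes of CartPole-style unit rewards; the second step ends the
    // first episode, so accumulation restarts there.
    let steps = [(1.0, false), (1.0, true), (1.0, false), (1.0, false)];
    assert_eq!(reward_to_go(&steps), vec![2.0, 1.0, 2.0, 1.0]);
}
```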