Add DDPG and fix Gym wrapper (#1207)

* Fix Gym wrapper
- It was returning things in the wrong order
- Gym now differentiates between terminated and truncated

* Add DDPG

* Apply fixes

* Remove Result annotations

* Also remove Vec annotation

* rustfmt

* Various small improvements (avoid cloning, mutability, get clippy to pass, ...)

---------

Co-authored-by: Travis Hammond <travis.hammond@alexanderthamm.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Travis Hammond
2023-10-28 20:53:34 +02:00
committed by GitHub
parent 012ae0090e
commit 498c50348c
3 changed files with 549 additions and 25 deletions

View File

@@ -7,20 +7,22 @@ use pyo3::types::PyDict;
/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
pub obs: Tensor,
pub state: Tensor,
pub action: A,
pub reward: f64,
pub is_done: bool,
pub terminated: bool,
pub truncated: bool,
}
impl<A: Copy> Step<A> {
/// Returns a copy of this step changing the observation tensor.
pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
Step {
obs: obs.clone(),
state: state.clone(),
action: self.action,
reward: self.reward,
is_done: self.is_done,
terminated: self.terminated,
truncated: self.truncated,
}
}
}
@@ -63,14 +65,14 @@ impl GymEnv {
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
let obs: Vec<f32> = Python::with_gil(|py| {
let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("seed", seed)?;
let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
obs.as_ref(py).get_item(0)?.extract()
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
state.as_ref(py).get_item(0)?.extract()
})
.map_err(w)?;
Tensor::new(obs, &Device::Cpu)
Tensor::new(state, &Device::Cpu)
}
/// Applies an environment step using the specified action.
@@ -78,21 +80,23 @@ impl GymEnv {
&self,
action: A,
) -> Result<Step<A>> {
let (obs, reward, is_done) = Python::with_gil(|py| {
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
let obs: Vec<f32> = step.get_item(0)?.extract()?;
let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
let is_done: bool = step.get_item(2)?.extract()?;
Ok((obs, reward, is_done))
let terminated: bool = step.get_item(2)?.extract()?;
let truncated: bool = step.get_item(3)?.extract()?;
Ok((state, reward, terminated, truncated))
})
.map_err(w)?;
let obs = Tensor::new(obs, &Device::Cpu)?;
let state = Tensor::new(state, &Device::Cpu)?;
Ok(Step {
obs,
reward,
is_done,
state,
action,
reward,
terminated,
truncated,
})
}