Mirror of https://github.com/huggingface/candle.git (synced 2025-06-21 04:10:46 +00:00)
Add DDPG and fix Gym wrapper (#1207)
* Fix Gym wrapper
  - It was returning things in the wrong order
  - Gym now differentiates between terminated and truncated

* Add DDPG

* Apply fixes

* Remove Result annotations

* Also remove Vec annotation

* rustfmt

* Various small improvements (avoid cloning, mutability, get clippy to pass, ...)

---------

Co-authored-by: Travis Hammond <travis.hammond@alexanderthamm.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
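The terminated/truncated split follows the Gymnasium API, where a real end state (`terminated`) is reported separately from a time-limit cutoff (`truncated`). Below is a minimal sketch of how a caller might consume the patched `Step` fields; the `run_episode` helper, the policy closure, and the `candle` import alias are illustrative assumptions, not part of this commit:

use candle::{Result, Tensor};

/// Runs one episode against the patched wrapper and returns the total
/// reward. A sketch, assuming the `GymEnv`/`Step` API from the diff
/// below; `select_action` stands in for an arbitrary policy.
fn run_episode<F>(env: &GymEnv, mut select_action: F, seed: u64) -> Result<f64>
where
    F: FnMut(&Tensor) -> usize,
{
    let mut state = env.reset(seed)?;
    let mut total_reward = 0.0;
    loop {
        let step = env.step(select_action(&state))?;
        total_reward += step.reward;
        // Either flag ends the episode, but only `terminated` should
        // zero out bootstrapped value estimates during training.
        if step.terminated || step.truncated {
            return Ok(total_reward);
        }
        state = step.state;
    }
}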
@@ -7,20 +7,22 @@ use pyo3::types::PyDict;
 /// The return value for a step.
 #[derive(Debug)]
 pub struct Step<A> {
-    pub obs: Tensor,
+    pub state: Tensor,
     pub action: A,
     pub reward: f64,
-    pub is_done: bool,
+    pub terminated: bool,
+    pub truncated: bool,
 }
 
 impl<A: Copy> Step<A> {
     /// Returns a copy of this step changing the observation tensor.
-    pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
+    pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
         Step {
-            obs: obs.clone(),
+            state: state.clone(),
            action: self.action,
            reward: self.reward,
-            is_done: self.is_done,
+            terminated: self.terminated,
+            truncated: self.truncated,
        }
    }
 }
@@ -63,14 +65,14 @@ impl GymEnv {
 
     /// Resets the environment, returning the observation tensor.
     pub fn reset(&self, seed: u64) -> Result<Tensor> {
-        let obs: Vec<f32> = Python::with_gil(|py| {
+        let state: Vec<f32> = Python::with_gil(|py| {
             let kwargs = PyDict::new(py);
             kwargs.set_item("seed", seed)?;
-            let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
-            obs.as_ref(py).get_item(0)?.extract()
+            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
+            state.as_ref(py).get_item(0)?.extract()
         })
         .map_err(w)?;
-        Tensor::new(obs, &Device::Cpu)
+        Tensor::new(state, &Device::Cpu)
     }
 
     /// Applies an environment step using the specified action.
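Context for the reset hunk above: Gymnasium's `reset` returns an `(observation, info)` tuple rather than a bare observation, which is why the wrapper extracts item 0 before building the tensor. A hedged usage sketch; `GymEnv::new`, the environment name, and the seed are assumptions for illustration:

use candle::{Result, Tensor};

// A sketch, assuming a `GymEnv::new(name)` constructor on this
// wrapper; "CartPole-v1" and the seed are placeholder choices.
fn demo_reset() -> Result<()> {
    let env = GymEnv::new("CartPole-v1")?;
    let state: Tensor = env.reset(42)?;
    println!("initial state shape: {:?}", state.shape());
    Ok(())
}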
@@ -78,21 +80,23 @@ impl GymEnv {
         &self,
         action: A,
     ) -> Result<Step<A>> {
-        let (obs, reward, is_done) = Python::with_gil(|py| {
+        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
             let step = self.env.call_method(py, "step", (action.clone(),), None)?;
             let step = step.as_ref(py);
-            let obs: Vec<f32> = step.get_item(0)?.extract()?;
+            let state: Vec<f32> = step.get_item(0)?.extract()?;
             let reward: f64 = step.get_item(1)?.extract()?;
-            let is_done: bool = step.get_item(2)?.extract()?;
-            Ok((obs, reward, is_done))
+            let terminated: bool = step.get_item(2)?.extract()?;
+            let truncated: bool = step.get_item(3)?.extract()?;
+            Ok((state, reward, terminated, truncated))
         })
         .map_err(w)?;
-        let obs = Tensor::new(obs, &Device::Cpu)?;
+        let state = Tensor::new(state, &Device::Cpu)?;
         Ok(Step {
-            obs,
-            reward,
-            is_done,
+            state,
             action,
+            reward,
+            terminated,
+            truncated,
         })
     }
 
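Since the commit also adds DDPG, the distinction matters beyond bookkeeping: a critic's TD target should drop the bootstrap term only on true termination, not on a time-limit truncation. A sketch of that idea under the `candle` import alias of the examples; `next_q`, `gamma`, and the function itself are illustrative rather than taken from the DDPG example:

use candle::{Result, Tensor};

// Illustrative TD target for a DDPG-style critic update. `next_q`
// would come from the target critic evaluated at the next state.
// Only `terminated` masks the bootstrap term; a truncated episode
// still bootstraps because the underlying MDP did not actually end.
fn td_target(reward: f64, terminated: bool, next_q: &Tensor, gamma: f64) -> Result<Tensor> {
    let not_done = if terminated { 0.0 } else { 1.0 };
    (next_q * (gamma * not_done))? + reward
}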