Update for pyo3 0.21. (#1985)

* Update for pyo3 0.21.

* Also adapt the RL example.

* Fix for the pyo3-onnx bindings...

* Print details on failures.

* Revert pyi.
This commit is contained in:
Laurent Mazare
2024-04-01 17:07:02 +02:00
committed by GitHub
parent 5522bbc57c
commit b20acd622c
8 changed files with 84 additions and 59 deletions

View File

@ -42,7 +42,7 @@ impl GymEnv {
/// Creates a new session of the specified OpenAI Gym environment.
pub fn new(name: &str) -> Result<GymEnv> {
Python::with_gil(|py| {
let gym = py.import("gymnasium")?;
let gym = py.import_bound("gymnasium")?;
let make = gym.getattr("make")?;
let env = make.call1((name,))?;
let action_space = env.getattr("action_space")?;
@ -66,10 +66,10 @@ impl GymEnv {
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new(py);
let kwargs = PyDict::new_bound(py);
kwargs.set_item("seed", seed)?;
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
state.as_ref(py).get_item(0)?.extract()
let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
state.bind(py).get_item(0)?.extract()
})
.map_err(w)?;
Tensor::new(state, &Device::Cpu)
@ -81,8 +81,10 @@ impl GymEnv {
action: A,
) -> Result<Step<A>> {
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
let step = self
.env
.call_method_bound(py, "step", (action.clone(),), None)?;
let step = step.bind(py);
let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
let terminated: bool = step.get_item(2)?.extract()?;