mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Update for pyo3 0.21. (#1985)
* Update for pyo3 0.21. * Also adapt the RL example. * Fix for the pyo3-onnx bindings... * Print details on failures. * Revert pyi.
This commit is contained in:
@ -42,7 +42,7 @@ impl GymEnv {
|
||||
/// Creates a new session of the specified OpenAI Gym environment.
|
||||
pub fn new(name: &str) -> Result<GymEnv> {
|
||||
Python::with_gil(|py| {
|
||||
let gym = py.import("gymnasium")?;
|
||||
let gym = py.import_bound("gymnasium")?;
|
||||
let make = gym.getattr("make")?;
|
||||
let env = make.call1((name,))?;
|
||||
let action_space = env.getattr("action_space")?;
|
||||
@ -66,10 +66,10 @@ impl GymEnv {
|
||||
/// Resets the environment, returning the observation tensor.
|
||||
pub fn reset(&self, seed: u64) -> Result<Tensor> {
|
||||
let state: Vec<f32> = Python::with_gil(|py| {
|
||||
let kwargs = PyDict::new(py);
|
||||
let kwargs = PyDict::new_bound(py);
|
||||
kwargs.set_item("seed", seed)?;
|
||||
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
|
||||
state.as_ref(py).get_item(0)?.extract()
|
||||
let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
|
||||
state.bind(py).get_item(0)?.extract()
|
||||
})
|
||||
.map_err(w)?;
|
||||
Tensor::new(state, &Device::Cpu)
|
||||
@ -81,8 +81,10 @@ impl GymEnv {
|
||||
action: A,
|
||||
) -> Result<Step<A>> {
|
||||
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
|
||||
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
|
||||
let step = step.as_ref(py);
|
||||
let step = self
|
||||
.env
|
||||
.call_method_bound(py, "step", (action.clone(),), None)?;
|
||||
let step = step.bind(py);
|
||||
let state: Vec<f32> = step.get_item(0)?.extract()?;
|
||||
let reward: f64 = step.get_item(1)?.extract()?;
|
||||
let terminated: bool = step.get_item(2)?.extract()?;
|
||||
|
Reference in New Issue
Block a user