Update for pyo3 0.21. (#1985)
* Update for pyo3 0.21.
* Also adapt the RL example.
* Fix for the pyo3-onnx bindings...
* Print details on failures.
* Revert pyi.
@@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
 impl VecGymEnv {
     pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
         Python::with_gil(|py| {
-            let sys = py.import("sys")?;
+            let sys = py.import_bound("sys")?;
             let path = sys.getattr("path")?;
             let _ = path.call_method1(
                 "append",
                 ("candle-examples/examples/reinforcement-learning",),
             )?;
-            let gym = py.import("atari_wrappers")?;
+            let gym = py.import_bound("atari_wrappers")?;
             let make = gym.getattr("make")?;
             let env = make.call1((name, img_dir, nprocesses))?;
             let action_space = env.getattr("action_space")?;
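
For reference, this hunk is the pyo3 0.21 move from GIL-ref returns to the Bound<'py, T> smart pointers: py.import becomes py.import_bound and yields a Bound<'py, PyModule>, while getattr/call_method1 keep the same call-sites but now come from the PyAnyMethods trait in the prelude. A minimal standalone sketch of the same import-and-append pattern, assuming pyo3 0.21 with the auto-initialize feature; the appended directory is a placeholder, not from the commit:

use pyo3::prelude::*;

fn main() -> PyResult<()> {
    // Assumes pyo3 0.21 with `auto-initialize` so `with_gil` can start
    // an embedded interpreter when run standalone.
    Python::with_gil(|py| {
        // 0.21: `import_bound` replaces `import` and returns a
        // `Bound<'py, PyModule>` instead of a `&'py PyModule` GIL-ref.
        let sys = py.import_bound("sys")?;
        let path = sys.getattr("path")?;
        // Unchanged call-sites, now provided by the `PyAnyMethods` trait.
        path.call_method1("append", ("some/extra/dir",))?; // placeholder dir
        println!("sys.path has {} entries", path.len()?);
        Ok(())
    })
}
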
@@ -60,10 +60,10 @@ impl VecGymEnv {
 
     pub fn step(&self, action: Vec<usize>) -> Result<Step> {
         let (obs, reward, is_done) = Python::with_gil(|py| {
-            let step = self.env.call_method(py, "step", (action,), None)?;
-            let step = step.as_ref(py);
+            let step = self.env.call_method_bound(py, "step", (action,), None)?;
+            let step = step.bind(py);
             let obs = step.get_item(0)?.call_method("flatten", (), None)?;
-            let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
+            let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
             let obs: Vec<u8> = obs_buffer.to_vec(py)?;
             let reward: Vec<f32> = step.get_item(1)?.extract()?;
             let is_done: Vec<f32> = step.get_item(2)?.extract()?;
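
The step hunk applies the same migration to an owned Py<PyAny> handle: call_method becomes call_method_bound, as_ref(py) becomes bind(py) (yielding a &Bound<'py, PyAny>), and PyBuffer::get becomes PyBuffer::get_bound. A minimal sketch of that chain, again assuming pyo3 0.21 with auto-initialize, using a Python bytearray as a stand-in for the env's observation object:

use pyo3::buffer::PyBuffer;
use pyo3::prelude::*;

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // A bytearray standing in for the stored `Py<PyAny>` env handle.
        let obj: Py<PyAny> = py.eval_bound("bytearray(b'abc')", None, None)?.unbind();
        // 0.21: `call_method_bound` replaces `call_method` on `Py<T>`.
        let copy = obj.call_method_bound(py, "copy", (), None)?;
        // 0.21: `bind` replaces `as_ref(py)`, giving a `&Bound<'py, PyAny>`.
        let copy = copy.bind(py);
        // 0.21: `get_bound` replaces `get` and takes a `&Bound<'_, PyAny>`.
        let buf: PyBuffer<u8> = PyBuffer::get_bound(copy)?;
        assert_eq!(buf.to_vec(py)?, b"abc".to_vec());
        Ok(())
    })
}
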