Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 18:28:24 +00:00
Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval.
* Fix pyo3 bindings.
* Black tweak.
* Formatting.
* Also update the pyo3-onnx formatting.
* Apply black.
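The call-site change visible in the hunks below is that `Tensor::detach` becomes infallible, returning a `Tensor` directly instead of a `Result<Tensor>`, so the trailing `?` on each `detach()` call goes away. A minimal sketch of the post-commit usage, assuming the `candle_core` crate as of this change:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let state = Tensor::zeros((3,), DType::F32, &Device::Cpu)?;
    // Before this commit: state.detach()?.unsqueeze(0)?
    // After:              state.detach().unsqueeze(0)?
    // `detach` now returns a plain `Tensor`; the fallible ops around it keep `?`.
    let batched = state.detach().unsqueeze(0)?;
    println!("{:?}", batched.dims()); // [1, 3]
    Ok(())
}
```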
@@ -411,7 +411,7 @@ impl DDPG<'_> {
     pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
         let actions = self
             .actor
-            .forward(&state.detach()?.unsqueeze(0)?)?
+            .forward(&state.detach().unsqueeze(0)?)?
             .squeeze(0)?;
         let actions = if self.train {
             (actions + self.ou_noise.sample()?)?
@@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
     loop {
         let action = {
             let action_probs: Vec<f32> =
-                softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
+                softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
                     .squeeze(0)?
                     .to_vec1()?;
             weighted_sample(action_probs, &mut rng)? as i64
@@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
 
         let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
             .to_dtype(DType::F32)?
-            .detach()?;
+            .detach();
 
         let actions_mask = {
             let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
                     .unwrap()
                 })
                 .collect();
-            Tensor::stack(&actions_mask, 0)?.detach()?
+            Tensor::stack(&actions_mask, 0)?.detach()
         };
 
         let states = {
             let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
-            Tensor::stack(&states, 0)?.detach()?
+            Tensor::stack(&states, 0)?.detach()
         };
 
         let log_probs = actions_mask
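For context on why these call sites detach at all: in the policy-gradient example, the rewards, action masks, and stacked states are constants with respect to the model parameters, so detaching cuts them out of the autograd graph before they enter the loss. A hedged sketch of that pattern (the `policy_loss` helper and its tensors are illustrative, not the example's exact code):

```rust
use candle_core::{Device, Result, Tensor};

// Illustrative REINFORCE-style loss: the rewards act as constants, so they
// are detached before being multiplied into the log-probabilities.
fn policy_loss(log_probs: &Tensor, rewards: &Tensor) -> Result<Tensor> {
    let rewards = rewards.detach(); // infallible after this commit
    log_probs.mul(&rewards)?.neg()?.mean_all()
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let log_probs = Tensor::new(&[-0.1f32, -0.7, -0.4, -1.2], &dev)?;
    let rewards = Tensor::new(&[1f32, 0.5, 0.25, 0.125], &dev)?;
    println!("{}", policy_loss(&log_probs, &rewards)?);
    Ok(())
}
```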