mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Apply the cast before the scaling. (#2135)
This commit is contained in:
@@ -70,7 +70,7 @@ pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
|
|||||||
let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
|
let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
|
||||||
let scale = 1.0 / (1.0 - drop_p as f64);
|
let scale = 1.0 / (1.0 - drop_p as f64);
|
||||||
let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
|
let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
|
||||||
let mask = (rand.ge(&drop_p)? * scale)?.to_dtype(xs.dtype())?;
|
let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
|
||||||
xs * mask
|
xs * mask
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user