Convert the logits to f32 before extracting them. (#102)

This commit is contained in:
Laurent Mazare
2023-07-07 08:07:57 +01:00
committed by GitHub
parent 02b5c38049
commit d38a926c14

View File

@ -52,7 +52,7 @@ impl TextGeneration {
let start_gen = std::time::Instant::now();
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = TEMPERATURE {
let prs = (&logits / temperature)?.softmax(D::Minus1)?;