mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Convert the logits to f32 before extracting them. (#102)
This commit is contained in:
@@ -52,7 +52,7 @@ impl TextGeneration {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
- let logits = logits.squeeze(0)?;
|
||||
+ let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
let next_token = if let Some(temperature) = TEMPERATURE {
|
||||
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
||||
|
Reference in New Issue
Block a user