Mirror of https://github.com/huggingface/candle.git
Softmax numerical stability. (#267)

* Softmax numerical stability.
* Fix the flash-attn test.
```diff
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(D::Minus1)?;
+        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
```
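The only change in this hunk is swapping the `Tensor::softmax` method for `candle_nn::ops::softmax`, which is where the stability fix lives. A minimal sketch of the standard max-subtraction trick a numerically stable softmax typically uses (illustrative Rust over plain slices, not candle's actual kernel):

```rust
// Sketch of the max-subtraction trick behind a numerically stable softmax.
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    // Subtracting the row max makes every exponent <= 0, so exp() cannot
    // overflow to infinity even for large logits. Masked entries set to
    // f32::NEG_INFINITY become exp(-inf) = 0, exactly as intended.
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // A naive exp(1000.0) overflows to infinity and yields NaNs after the
    // division; with the max subtracted we get a well-defined distribution.
    let probs = stable_softmax(&[1000.0, 999.0, f32::NEG_INFINITY]);
    println!("{probs:?}"); // ~[0.731, 0.269, 0.0]
}
```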
```diff
@@ -1,7 +1,7 @@
 use crate::model::{Cache, Config, Llama};
 use byteorder::{LittleEndian, ReadBytesExt};
 use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
 use rand::{distributions::Distribution, SeedableRng};
 use serde::{Deserialize, Serialize};
 use wasm_bindgen::prelude::*;
```
```diff
@@ -88,7 +88,7 @@ impl LogitsProcessor {
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
         let next_token = if let Some(temperature) = self.temperature {
-            let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+            let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
             let prs: Vec<f32> = prs.to_vec1()?;
             let distr =
                 rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
```
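For context, the temperature branch above scales the logits by `1/temperature`, softmaxes them, and draws a token from a `WeightedIndex`. A self-contained sketch of that pipeline on plain `Vec<f32>` logits (hypothetical helper name; rand 0.8-era `distributions` API, matching the paths in the diff):

```rust
use rand::{
    distributions::{Distribution, WeightedIndex},
    rngs::StdRng,
    SeedableRng,
};

// Illustrative stand-in for LogitsProcessor::sample's temperature branch,
// operating on a plain slice instead of a candle Tensor.
fn sample_with_temperature(logits: &[f32], temperature: f32, rng: &mut StdRng) -> u32 {
    // Scale by 1/temperature: lower temperature sharpens the distribution.
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    // Numerically stable softmax (subtract the max before exponentiating).
    let max = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let prs: Vec<f32> = exps.iter().map(|&e| e / sum).collect();
    // WeightedIndex samples an index with probability proportional to its weight.
    let distr = WeightedIndex::new(&prs).expect("weights must be positive");
    distr.sample(rng) as u32
}

fn main() {
    let mut rng = StdRng::seed_from_u64(42);
    let token = sample_with_temperature(&[2.0, 1.0, 0.5], 0.8, &mut rng);
    println!("sampled token id: {token}");
}
```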