Mirror of https://github.com/huggingface/candle.git, synced 2025-06-21 12:20:46 +00:00
Trace softmax (#568)
* Trace the softmax op.
* Inline the sum.
* Add min/max vec operations.
@@ -88,6 +88,7 @@ struct CrossAttention {
     slice_size: Option<usize>,
     span: tracing::Span,
     span_attn: tracing::Span,
+    span_softmax: tracing::Span,
     use_flash_attn: bool,
 }
 
@@ -111,6 +112,7 @@ impl CrossAttention {
         let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
         let span = tracing::span!(tracing::Level::TRACE, "xa");
         let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
         Ok(Self {
             to_q,
             to_k,
@@ -121,6 +123,7 @@ impl CrossAttention {
             slice_size,
             span,
             span_attn,
+            span_softmax,
             use_flash_attn,
         })
     }
@@ -193,9 +196,11 @@ impl CrossAttention {
             let key = key.to_dtype(DType::F32)?;
             let value = value.to_dtype(DType::F32)?;
             let xs = query.matmul(&(key.t()? * self.scale)?)?;
-            nn::ops::softmax(&xs, D::Minus1)?
-                .matmul(&value)?
-                .to_dtype(in_dtype)?
+            let xs = {
+                let _enter = self.span_softmax.enter();
+                nn::ops::softmax(&xs, D::Minus1)?
+            };
+            xs.matmul(&value)?.to_dtype(in_dtype)?
         };
         self.reshape_batch_dim_to_heads(&xs)
     }
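For the new xa-softmax span (and the existing xa and xa-attn spans) to be visible, a tracing subscriber has to be installed before the pipeline runs. Below is a minimal sketch assuming the tracing-chrome and tracing-subscriber setup that the candle examples typically use; the output file handling and the way the trace is opened are illustrative, not part of this commit.

// Minimal sketch: install a Chrome-trace subscriber so the spans entered in
// CrossAttention::attention ("xa", "xa-attn", "xa-softmax") are recorded.
// Assumes the tracing-chrome and tracing-subscriber crates; the details here
// are illustrative and not part of the diff above.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run; the trace file is flushed when it drops.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... run the diffusion pipeline here; each softmax call then shows up as its
    // own "xa-softmax" slice nested inside "xa-attn" in the emitted trace JSON,
    // which can be inspected with chrome://tracing or https://ui.perfetto.dev.
}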