Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Fix T5 kv cache (#899)

* Fix T5 kv cache
* Add argument for decoder prompt
* Fix range

@@ -46,12 +46,16 @@ struct Args {
     // Enable/disable decoding.
     #[arg(long, default_value = "false")]
-    use_cache: bool,
+    disable_cache: bool,
 
     /// Use this prompt, otherwise compute sentence similarities.
     #[arg(long)]
    prompt: Option<String>,
 
+    /// If set along with --decode, will use this prompt to initialize the decoder.
+    #[arg(long)]
+    decoder_prompt: Option<String>,
+
     /// L2 normalization for embeddings.
     #[arg(long, default_value = "true")]
     normalize_embeddings: bool,
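For readers unfamiliar with clap's derive API, here is a minimal, self-contained sketch (not the example's full Args struct) of how the two new fields surface on the command line; clap derives the kebab-case flags --disable-cache and --decoder-prompt from the field names:

// Illustrative only: a stripped-down Args with just the fields touched by this commit.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Disable the key/value cache during decoding (the cache is now on by default).
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// If set along with --decode, will use this prompt to initialize the decoder.
    #[arg(long)]
    decoder_prompt: Option<String>,
}

fn main() {
    let args = Args::parse();
    println!("disable_cache={} decoder_prompt={:?}", args.disable_cache, args.decoder_prompt);
}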
@@ -116,7 +120,7 @@ impl T5ModelBuilder {
         };
         let config = std::fs::read_to_string(config_filename)?;
         let mut config: t5::Config = serde_json::from_str(&config)?;
-        config.use_cache = args.use_cache;
+        config.use_cache = !args.disable_cache;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
         Ok((
             Self {
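The renamed flag inverts the default: the KV cache is now enabled unless --disable-cache is passed. A small, hypothetical sketch of the same pattern, parsing a JSON config with serde and then overriding one field from the CLI (only use_cache is modeled here):

use serde::Deserialize;

// Hypothetical single-field stand-in for t5::Config, just to show the override.
#[derive(Deserialize, Debug)]
struct Config {
    #[serde(default)]
    use_cache: bool,
}

fn main() -> anyhow::Result<()> {
    let disable_cache = false; // would come from Args::disable_cache
    let mut config: Config = serde_json::from_str(r#"{ "use_cache": false }"#)?;
    // The CLI flag wins over whatever the checkpoint's config.json says.
    config.use_cache = !disable_cache;
    println!("{config:?}"); // Config { use_cache: true }
    Ok(())
}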
@@ -170,6 +174,16 @@ fn main() -> Result<()> {
     } else {
         let mut model = builder.build_conditional_generation()?;
         let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+        if let Some(decoder_prompt) = &args.decoder_prompt {
+            print!("{decoder_prompt}");
+            output_token_ids.extend(
+                tokenizer
+                    .encode(decoder_prompt.to_string(), false)
+                    .map_err(E::msg)?
+                    .get_ids()
+                    .to_vec(),
+            );
+        }
         let temperature = if args.temperature <= 0. {
             None
         } else {
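T5's decoder is seeded with the pad token; the new block appends the tokenized --decoder-prompt so generation continues from that prefix. A hedged, standalone sketch of the same seeding logic (seed_decoder is a made-up helper, not part of the example):

use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

// Build the initial decoder token ids: the pad token, then any user-supplied prefix.
// e.g. let ids = seed_decoder(&tokenizer, 0, Some("Die Katze"))?;
fn seed_decoder(
    tokenizer: &Tokenizer,
    pad_token_id: u32,
    decoder_prompt: Option<&str>,
) -> Result<Vec<u32>> {
    let mut output_token_ids = vec![pad_token_id];
    if let Some(prompt) = decoder_prompt {
        let encoded = tokenizer.encode(prompt, false).map_err(E::msg)?;
        output_token_ids.extend(encoded.get_ids().iter().copied());
    }
    Ok(output_token_ids)
}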
@@ -195,11 +209,11 @@ fn main() -> Result<()> {
             let logits = if args.repeat_penalty == 1. {
                 logits
             } else {
-                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+                let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
                 candle_transformers::utils::apply_repeat_penalty(
                     &logits,
                     args.repeat_penalty,
-                    &tokens[start_at..],
+                    &output_token_ids[start_at..],
                 )?
             };
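The repeat penalty now looks at the last repeat_last_n generated decoder tokens (output_token_ids) rather than the encoder-side tokens. For reference, a hedged sketch of what such a penalty does; candle_transformers::utils::apply_repeat_penalty applies this idea to a Tensor of logits:

use std::collections::HashSet;

// CTRL-style repetition penalty on raw logits: token ids seen in `context` become
// less likely, by dividing positive logits and multiplying negative ones by `penalty`.
fn repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
    let mut seen = HashSet::new();
    for &token_id in context {
        if !seen.insert(token_id) {
            continue; // penalize each distinct token id only once
        }
        if let Some(logit) = logits.get_mut(token_id as usize) {
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0, -1.0, 0.5];
    repeat_penalty_sketch(&mut logits, 1.5, &[0, 1]);
    println!("{logits:?}"); // [1.3333334, -1.5, 0.5]
}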
@@ -217,8 +231,8 @@ fn main() -> Result<()> {
         let dt = start.elapsed();
         println!(
             "\n{} tokens generated ({:.2} token/s)\n",
-            tokens.len(),
-            tokens.len() as f64 / dt.as_secs_f64(),
+            output_token_ids.len(),
+            output_token_ids.len() as f64 / dt.as_secs_f64(),
         );
     }
 }
@@ -348,9 +348,14 @@ impl T5Attention {
             None => (scores, None),
             Some(relative_attention_bias) => {
                 // This only handles the bidirectional case.
+                let kv_len = k.dim(2)?;
+                let (q_start, q_end) = match self.use_cache {
+                    true => ((kv_len - q_len) as u32, kv_len as u32),
+                    false => (0_u32, kv_len as u32),
+                };
                 let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                 let max_exact = num_buckets / 2;
-                let relative_position = (0..q_len as u32)
+                let relative_position = (q_start..q_end)
                     .map(|i| {
                         (0..kv_len as u32)
                             .map(|j| {
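This is the core of the KV-cache fix: with the cache enabled only the newest q_len query positions run through attention while k/v hold all kv_len positions seen so far, so the query rows correspond to absolute positions kv_len - q_len .. kv_len rather than 0 .. q_len. A standalone sketch (not candle's code) of the signed relative positions, key position minus query position, that T5 then maps into bias buckets:

// Compute key_pos - query_pos for every (query, key) pair. With use_cache the query
// rows are shifted to the end of the sequence; the old 0..q_len range would have
// computed every decoding step's row as if the query sat at the start of the sequence.
fn relative_positions(q_len: usize, kv_len: usize, use_cache: bool) -> Vec<Vec<i64>> {
    let (q_start, q_end) = if use_cache {
        ((kv_len - q_len) as i64, kv_len as i64)
    } else {
        (0, kv_len as i64)
    };
    (q_start..q_end)
        .map(|i| (0..kv_len as i64).map(|j| j - i).collect())
        .collect()
}

fn main() {
    // One new token decoded with 8 cached key/value positions: a single row -7..=0.
    println!("{:?}", relative_positions(1, 8, true));  // [[-7, -6, -5, -4, -3, -2, -1, 0]]
    // Without the cache (q_len == kv_len) the full square matrix is rebuilt each step.
    println!("{:?}", relative_positions(2, 2, false)); // [[0, 1], [-1, 0]]
}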