mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Fix T5 kv cache (#899)
* Fix T5 kv cache
* Add argument for decoder prompt
* Fix range
This commit is contained in:
@ -46,12 +46,16 @@ struct Args {
|
|||||||
|
|
||||||
// Enable/disable the KV cache.
|
// Disable the KV cache (the cache is enabled unless this flag is set).
|
||||||
#[arg(long, default_value = "false")]
|
#[arg(long, default_value = "false")]
|
||||||
use_cache: bool,
|
disable_cache: bool,
|
||||||
|
|
||||||
/// Use this prompt, otherwise compute sentence similarities.
|
/// Use this prompt, otherwise compute sentence similarities.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
prompt: Option<String>,
|
prompt: Option<String>,
|
||||||
|
|
||||||
|
/// If set along with --decode, will use this prompt to initialize the decoder.
|
||||||
|
#[arg(long)]
|
||||||
|
decoder_prompt: Option<String>,
|
||||||
|
|
||||||
/// L2 normalization for embeddings.
|
/// L2 normalization for embeddings.
|
||||||
#[arg(long, default_value = "true")]
|
#[arg(long, default_value = "true")]
|
||||||
normalize_embeddings: bool,
|
normalize_embeddings: bool,
|
||||||
@ -116,7 +120,7 @@ impl T5ModelBuilder {
|
|||||||
};
|
};
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||||
config.use_cache = args.use_cache;
|
config.use_cache = !args.disable_cache;
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
Ok((
|
Ok((
|
||||||
Self {
|
Self {
|
||||||
@ -170,6 +174,16 @@ fn main() -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
let mut model = builder.build_conditional_generation()?;
|
let mut model = builder.build_conditional_generation()?;
|
||||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||||
|
if let Some(decoder_prompt) = &args.decoder_prompt {
|
||||||
|
print!("{decoder_prompt}");
|
||||||
|
output_token_ids.extend(
|
||||||
|
tokenizer
|
||||||
|
.encode(decoder_prompt.to_string(), false)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec(),
|
||||||
|
);
|
||||||
|
}
|
||||||
let temperature = if args.temperature <= 0. {
|
let temperature = if args.temperature <= 0. {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
@ -195,11 +209,11 @@ fn main() -> Result<()> {
|
|||||||
let logits = if args.repeat_penalty == 1. {
|
let logits = if args.repeat_penalty == 1. {
|
||||||
logits
|
logits
|
||||||
} else {
|
} else {
|
||||||
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
|
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
&logits,
|
&logits,
|
||||||
args.repeat_penalty,
|
args.repeat_penalty,
|
||||||
&tokens[start_at..],
|
&output_token_ids[start_at..],
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -217,8 +231,8 @@ fn main() -> Result<()> {
|
|||||||
let dt = start.elapsed();
|
let dt = start.elapsed();
|
||||||
println!(
|
println!(
|
||||||
"\n{} tokens generated ({:.2} token/s)\n",
|
"\n{} tokens generated ({:.2} token/s)\n",
|
||||||
tokens.len(),
|
output_token_ids.len(),
|
||||||
tokens.len() as f64 / dt.as_secs_f64(),
|
output_token_ids.len() as f64 / dt.as_secs_f64(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -348,9 +348,14 @@ impl T5Attention {
|
|||||||
None => (scores, None),
|
None => (scores, None),
|
||||||
Some(relative_attention_bias) => {
|
Some(relative_attention_bias) => {
|
||||||
// This only handles the bidirectional case.
|
// This only handles the bidirectional case.
|
||||||
|
let kv_len = k.dim(2)?;
|
||||||
|
let (q_start, q_end) = match self.use_cache {
|
||||||
|
true => ((kv_len - q_len) as u32, kv_len as u32),
|
||||||
|
false => (0_u32, kv_len as u32),
|
||||||
|
};
|
||||||
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
|
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
|
||||||
let max_exact = num_buckets / 2;
|
let max_exact = num_buckets / 2;
|
||||||
let relative_position = (0..q_len as u32)
|
let relative_position = (q_start..q_end)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
(0..kv_len as u32)
|
(0..kv_len as u32)
|
||||||
.map(|j| {
|
.map(|j| {
|
||||||
|
Reference in New Issue
Block a user