Fix T5 kv cache (#899)

* Fix T5 kv cache

* Add argument for decoder prompt

* Fix range
Author: Juarez Bochi
Date: 2023-09-19 12:36:15 -07:00
Committed by: GitHub
Parent: d7e48234d4
Commit: 8696f64bae
2 changed files with 26 additions and 7 deletions


@@ -46,12 +46,16 @@ struct Args {
     // Enable/disable decoding.
     #[arg(long, default_value = "false")]
-    use_cache: bool,
+    disable_cache: bool,
     /// Use this prompt, otherwise compute sentence similarities.
     #[arg(long)]
     prompt: Option<String>,
+    /// If set along with --decode, will use this prompt to initialize the decoder.
+    #[arg(long)]
+    decoder_prompt: Option<String>,
     /// L2 normalization for embeddings.
     #[arg(long, default_value = "true")]
     normalize_embeddings: bool,
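Note on the new flag: a bare clap `bool` field defaults to `false`, so replacing `use_cache` with `disable_cache` makes the KV cache opt-out rather than opt-in. A minimal standalone sketch of that pattern (struct and field layout here are illustrative, not the example's full `Args`):

```rust
// Sketch, assuming clap 4 with the "derive" feature.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Disable the key-value cache during decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// Optional prompt used to initialize the decoder.
    #[arg(long)]
    decoder_prompt: Option<String>,
}

fn main() {
    let args = Args::parse();
    // Nothing passed -> disable_cache is false -> the cache stays on.
    let use_cache = !args.disable_cache;
    println!("use_cache = {use_cache}, decoder_prompt = {:?}", args.decoder_prompt);
}
```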
@@ -116,7 +120,7 @@ impl T5ModelBuilder {
         };
         let config = std::fs::read_to_string(config_filename)?;
         let mut config: t5::Config = serde_json::from_str(&config)?;
-        config.use_cache = args.use_cache;
+        config.use_cache = !args.disable_cache;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
         Ok((
             Self {
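For context, `config.use_cache` is what lets the decoder reuse cached keys/values instead of recomputing attention over all previously generated tokens; the override is now negated to match the new flag. A small self-contained sketch of the load-then-override pattern, using a stand-in `DemoConfig` instead of `t5::Config` so it runs on its own:

```rust
// Sketch (assumption, not the commit's code); requires serde (derive) and serde_json.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct DemoConfig {
    #[serde(default)]
    use_cache: bool,
}

fn main() -> anyhow::Result<()> {
    let disable_cache = false; // would come from `args.disable_cache`
    let raw = r#"{ "use_cache": true }"#; // would come from config.json on disk
    let mut config: DemoConfig = serde_json::from_str(raw)?;
    // The fix: the CLI flag now *disables* the cache, so the override is negated.
    config.use_cache = !disable_cache;
    println!("{config:?}");
    Ok(())
}
```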
@@ -170,6 +174,16 @@ fn main() -> Result<()> {
     } else {
         let mut model = builder.build_conditional_generation()?;
         let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+        if let Some(decoder_prompt) = &args.decoder_prompt {
+            print!("{decoder_prompt}");
+            output_token_ids.extend(
+                tokenizer
+                    .encode(decoder_prompt.to_string(), false)
+                    .map_err(E::msg)?
+                    .get_ids()
+                    .to_vec(),
+            );
+        }
         let temperature = if args.temperature <= 0. {
             None
         } else {
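The added block seeds the decoder with the tokenized `--decoder-prompt` right after the initial pad token. A standalone sketch of the same tokenizers-crate calls (the tokenizer path, pad id, and prompt text below are placeholders):

```rust
// Sketch, assuming the `tokenizers` and `anyhow` crates.
use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(E::msg)?;
    let pad_token_id: u32 = 0; // T5 typically uses 0; read it from the config in practice
    let decoder_prompt = "translate English to German: Hello";

    // Decoding always starts from the pad token; the prompt ids follow it.
    let mut output_token_ids = vec![pad_token_id];
    output_token_ids.extend(
        tokenizer
            .encode(decoder_prompt, false) // no special tokens for a decoder prefix
            .map_err(E::msg)?
            .get_ids()
            .to_vec(),
    );
    println!("decoder starts from {} ids", output_token_ids.len());
    Ok(())
}
```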
@@ -195,11 +209,11 @@ fn main() -> Result<()> {
             let logits = if args.repeat_penalty == 1. {
                 logits
             } else {
-                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+                let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
                 candle_transformers::utils::apply_repeat_penalty(
                     &logits,
                     args.repeat_penalty,
-                    &tokens[start_at..],
+                    &output_token_ids[start_at..],
                 )?
             };
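The range fix: the repeat penalty should look at the last `repeat_last_n` *generated* ids, i.e. a window over `output_token_ids`, rather than the encoder-side `tokens`. A minimal reproduction of that windowing with dummy logits (a sketch, assuming the `candle-core` and `candle-transformers` crates as an external user sees them):

```rust
// Sketch of the repeat-penalty window; values are dummies.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let repeat_penalty = 1.1f32;
    let repeat_last_n = 4usize;
    let output_token_ids: Vec<u32> = vec![0, 37, 42, 42, 7, 42];

    // Dummy logits over a tiny vocabulary.
    let vocab_size: usize = 50;
    let logits = Tensor::ones(vocab_size, DType::F32, &Device::Cpu)?;

    // Only the last `repeat_last_n` generated ids are penalized.
    let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
    let logits = candle_transformers::utils::apply_repeat_penalty(
        &logits,
        repeat_penalty,
        &output_token_ids[start_at..],
    )?;
    println!("penalized logits shape: {:?}", logits.shape());
    Ok(())
}
```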
@@ -217,8 +231,8 @@ fn main() -> Result<()> {
         let dt = start.elapsed();
         println!(
             "\n{} tokens generated ({:.2} token/s)\n",
-            tokens.len(),
-            tokens.len() as f64 / dt.as_secs_f64(),
+            output_token_ids.len(),
+            output_token_ids.len() as f64 / dt.as_secs_f64(),
         );
     }
 }
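Likewise, the throughput report now counts `output_token_ids`, the vector that actually grows during generation. A trivial sketch of that measurement, with a sleep standing in for the per-step decode:

```rust
// Sketch only: timing a fake generation loop and reporting tokens per second.
use std::time::{Duration, Instant};

fn main() {
    let start = Instant::now();
    let mut output_token_ids: Vec<u32> = vec![0]; // starts from the pad token
    for t in 1..=16u32 {
        std::thread::sleep(Duration::from_millis(10)); // simulate one decode step
        output_token_ids.push(t); // stand-in for the sampled token id
    }
    let dt = start.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        output_token_ids.len(),
        output_token_ids.len() as f64 / dt.as_secs_f64(),
    );
}
```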