mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add a KV cache to T5. (#873)
* Add a KV cache to T5. * Suggest using release mode. * Use the kv cache in decoding. * Add a comment.
This commit is contained in:
@ -3,7 +3,7 @@
|
||||
## Encoder-decoder example:
|
||||
|
||||
```bash
|
||||
$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
|
||||
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
|
||||
...
|
||||
Running on CPU, to run on GPU, build this example with `--features cuda`
|
||||
Eine schöne Kerze.
|
||||
@ -13,7 +13,7 @@ Running on CPU, to run on GPU, build this example with `--features cuda`
|
||||
## Sentence embedding example:
|
||||
|
||||
```bash
|
||||
$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle."
|
||||
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
|
||||
...
|
||||
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
|
||||
[-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
|
||||
|
@ -48,10 +48,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
@ -131,6 +127,7 @@ impl T5ModelBuilder {
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
|
||||
let device = &builder.device;
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
@ -142,32 +139,32 @@ fn main() -> Result<()> {
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
|
||||
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
if !args.decode {
|
||||
let model = builder.build_encoder()?;
|
||||
for idx in 0..args.n {
|
||||
let start = std::time::Instant::now();
|
||||
let ys = model.forward(&input_token_ids)?;
|
||||
if idx == 0 {
|
||||
println!("{ys}");
|
||||
}
|
||||
println!("Took {:?}", start.elapsed());
|
||||
}
|
||||
let mut model = builder.build_encoder()?;
|
||||
let start = std::time::Instant::now();
|
||||
let ys = model.forward(&input_token_ids)?;
|
||||
println!("{ys}");
|
||||
println!("Took {:?}", start.elapsed());
|
||||
} else {
|
||||
let model = builder.build_conditional_generation()?;
|
||||
let mut model = builder.build_conditional_generation()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, None, None);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
for _index in 0.. {
|
||||
for index in 0.. {
|
||||
if output_token_ids.len() > 512 {
|
||||
break;
|
||||
}
|
||||
let decoder_token_ids =
|
||||
Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
|
||||
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
|
||||
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
|
||||
} else {
|
||||
let last_token = *output_token_ids.last().unwrap();
|
||||
Tensor::new(&[last_token], device)?.unsqueeze(0)?
|
||||
};
|
||||
let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
|
||||
let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
|
||||
if (next_token_id as usize) == builder.config.eos_token_id {
|
||||
if next_token_id as usize == builder.config.eos_token_id {
|
||||
break;
|
||||
}
|
||||
output_token_ids.push(next_token_id);
|
||||
@ -186,7 +183,7 @@ fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let model = builder.build_encoder()?;
|
||||
let mut model = builder.build_encoder()?;
|
||||
let sentences = [
|
||||
"The cat sits outside",
|
||||
"A man is playing guitar",
|
||||
|
Reference in New Issue
Block a user