Add topk sampling. (#1923)

This commit is contained in:
Laurent Mazare
2024-03-23 15:26:09 +01:00
committed by GitHub
parent fdfe8fd129
commit a62a97340c
2 changed files with 88 additions and 24 deletions

View File

@ -27,3 +27,30 @@ fn sample_with_top_p() -> Result<()> {
assert_eq!(token, 2);
Ok(())
}
#[test]
fn sample_with_top_k() -> Result<()> {
let mut logits_process = LogitsProcessor::from_sampling(
42,
candle_transformers::generation::Sampling::TopK {
k: 1,
temperature: 1.0,
},
);
let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
let token = logits_process.sample(&logits)?;
assert_eq!(token, 3);
let mut logits_process = LogitsProcessor::from_sampling(
42,
candle_transformers::generation::Sampling::TopK {
k: 2,
temperature: 1.0,
},
);
let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
let token = logits_process.sample(&logits)?;
assert_eq!(token, 3);
let token = logits_process.sample(&logits)?;
assert_eq!(token, 2);
Ok(())
}