mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add topk sampling. (#1923)
This commit is contained in:
@ -27,3 +27,30 @@ fn sample_with_top_p() -> Result<()> {
|
||||
assert_eq!(token, 2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sample_with_top_k() -> Result<()> {
|
||||
let mut logits_process = LogitsProcessor::from_sampling(
|
||||
42,
|
||||
candle_transformers::generation::Sampling::TopK {
|
||||
k: 1,
|
||||
temperature: 1.0,
|
||||
},
|
||||
);
|
||||
let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
|
||||
let token = logits_process.sample(&logits)?;
|
||||
assert_eq!(token, 3);
|
||||
let mut logits_process = LogitsProcessor::from_sampling(
|
||||
42,
|
||||
candle_transformers::generation::Sampling::TopK {
|
||||
k: 2,
|
||||
temperature: 1.0,
|
||||
},
|
||||
);
|
||||
let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
|
||||
let token = logits_process.sample(&logits)?;
|
||||
assert_eq!(token, 3);
|
||||
let token = logits_process.sample(&logits)?;
|
||||
assert_eq!(token, 2);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user