Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00.
@@ -47,7 +47,7 @@ cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand"
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.4.1"
-half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 hound = "3.5.1"
 image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
 imageproc = { version = "0.24.0", default-features = false }
@@ -58,8 +58,8 @@ memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
 parquet = { version = "51.0.0" }
-rand = "0.8.5"
-rand_distr = "0.4.3"
+rand = "0.9.0"
+rand_distr = "0.5.1"
 rayon = "1.7.0"
 safetensors = "0.4.1"
 serde = { version = "1.0.171", features = ["derive"] }
@@ -2482,15 +2482,15 @@ impl BackendDevice for CpuDevice {
         use rand::prelude::*;

         let elem_count = shape.elem_count();
-        let mut rng = rand::thread_rng();
+        let mut rng = rand::rng();
         match dtype {
             DType::U8 | DType::U32 | DType::I64 => {
                 Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
             }
             DType::BF16 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let uniform =
-                    rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
+                let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
+                    .map_err(Error::wrap)?;
                 for _i in 0..elem_count {
                     data.push(rng.sample::<bf16, _>(uniform))
                 }
@@ -2498,8 +2498,8 @@ impl BackendDevice for CpuDevice {
             }
             DType::F16 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let uniform =
-                    rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
+                let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
+                    .map_err(Error::wrap)?;
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f16, _>(uniform))
                 }
@@ -2507,7 +2507,8 @@ impl BackendDevice for CpuDevice {
             }
             DType::F32 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
+                let uniform =
+                    rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f32, _>(uniform))
                 }
@@ -2515,7 +2516,7 @@ impl BackendDevice for CpuDevice {
             }
             DType::F64 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let uniform = rand::distributions::Uniform::new(min, max);
+                let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(uniform))
                 }
@@ -2528,7 +2529,7 @@ impl BackendDevice for CpuDevice {
         use rand::prelude::*;

         let elem_count = shape.elem_count();
-        let mut rng = rand::thread_rng();
+        let mut rng = rand::rng();
         match dtype {
             DType::U8 | DType::U32 | DType::I64 => {
                 Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
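Aside: the rand_uniform / rand_normal hunks above track two rand 0.9 changes: rand::rng() replaces rand::thread_rng(), and Uniform::new is now fallible (hence the .map_err(Error::wrap)? calls). A minimal standalone sketch of that pattern, illustrative only and not taken from the candle sources:

    use rand::distr::Uniform;
    use rand::Rng;

    fn main() {
        // rand 0.9: rand::rng() replaces rand::thread_rng().
        let mut rng = rand::rng();
        // rand 0.9: Uniform::new returns a Result instead of panicking on an invalid range.
        let uniform = Uniform::new(0.0f32, 1.0f32).expect("min < max");
        let data: Vec<f32> = (0..4).map(|_| rng.sample(uniform)).collect();
        println!("{data:?}");
    }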
@@ -880,10 +880,10 @@ fn get_random_tensors(
    let mut rng = StdRng::seed_from_u64(314159265358979);

    let lhs = (0..m * k)
-        .map(|_| rng.gen::<f32>() - 0.5)
+        .map(|_| rng.random::<f32>() - 0.5)
        .collect::<Vec<_>>();
    let rhs = (0..n * k)
-        .map(|_| rng.gen::<f32>() - 0.5)
+        .map(|_| rng.random::<f32>() - 0.5)
        .collect::<Vec<_>>();

    let lhs = Tensor::from_vec(lhs, (m, k), device)?;
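Aside: rng.gen::<T>() is renamed to rng.random::<T>() in rand 0.9 (gen is a reserved keyword in Rust 2024). A small sketch of the seeded test-data pattern used above, illustrative only with an arbitrary element count:

    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    fn main() {
        // Seeded generator so the "random" tensor contents are reproducible across runs.
        let mut rng = StdRng::seed_from_u64(314159265358979);
        // rand 0.9: Rng::random() replaces Rng::gen().
        let lhs: Vec<f32> = (0..6).map(|_| rng.random::<f32>() - 0.5).collect();
        println!("{lhs:?}");
    }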
@@ -60,8 +60,8 @@ pub struct DatasetRandomIter<'a> {

 impl<'a> DatasetRandomIter<'a> {
     pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
+        use rand::rng;
         use rand::seq::SliceRandom;
-        use rand::thread_rng;

         let all_tokens = if valid {
             &ds.valid_tokens
@@ -69,13 +69,13 @@ impl<'a> DatasetRandomIter<'a> {
             &ds.train_tokens
         };
         let mut tokens = all_tokens.iter().collect::<Vec<_>>();
-        tokens.shuffle(&mut thread_rng());
+        tokens.shuffle(&mut rng());
         let current_tokens = tokens.pop().unwrap();
         let seq_len_in_bytes = seq_len * 2;
         let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
             .step_by(seq_len_in_bytes)
             .collect::<Vec<_>>();
-        indexes_in_bytes.shuffle(&mut thread_rng());
+        indexes_in_bytes.shuffle(&mut rng());
         Self {
             all_tokens,
             tokens,
@@ -92,21 +92,21 @@ impl Iterator for DatasetRandomIter<'_> {

     fn next(&mut self) -> Option<Self::Item> {
         use byteorder::{LittleEndian, ReadBytesExt};
+        use rand::rng;
         use rand::seq::SliceRandom;
-        use rand::thread_rng;

         let seq_len = self.seq_len;
         if self.indexes_in_bytes.is_empty() {
             if self.tokens.is_empty() {
                 self.tokens = self.all_tokens.iter().collect();
-                self.tokens.shuffle(&mut thread_rng());
+                self.tokens.shuffle(&mut rng());
             }
             self.current_tokens = self.tokens.pop().unwrap();
             let seq_len_in_bytes = self.seq_len * 2;
             self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
                 .step_by(seq_len_in_bytes)
                 .collect::<Vec<_>>();
-            self.indexes_in_bytes.shuffle(&mut thread_rng());
+            self.indexes_in_bytes.shuffle(&mut rng());
         }
         let start_idx = self.indexes_in_bytes.pop().unwrap();
         let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
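Aside: shuffling still goes through rand::seq::SliceRandom; only the thread-local generator changed. A minimal sketch, illustrative only:

    use rand::seq::SliceRandom;

    fn main() {
        let mut items: Vec<u32> = (0..10).collect();
        // rand 0.9: pass rand::rng() where rand 0.8 used rand::thread_rng().
        items.shuffle(&mut rand::rng());
        println!("{items:?}");
    }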
@@ -16,7 +16,7 @@ use candle_transformers::models::quantized_metavoice::transformer as qtransforme
 use candle::{DType, IndexOp, Tensor};
 use candle_nn::VarBuilder;
 use hf_hub::api::sync::Api;
-use rand::{distributions::Distribution, SeedableRng};
+use rand::{distr::Distribution, SeedableRng};

 pub const ENCODEC_NTOKENS: u32 = 1024;

@@ -250,7 +250,7 @@ fn main() -> Result<()> {
                 let logits = logits.i(step)?.to_dtype(DType::F32)?;
                 let logits = &(&logits / 1.0)?;
                 let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
-                let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
+                let distr = rand::distr::weighted::WeightedIndex::new(prs.as_slice())?;
                 let sample = distr.sample(&mut rng) as u32;
                 codes_.push(sample)
             }
@@ -617,7 +617,7 @@ fn run(args: Args) -> Result<()> {
     let mut scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
     // If a seed is not given, generate a random seed and print it
-    let seed = seed.unwrap_or(rand::thread_rng().gen_range(0u64..u64::MAX));
+    let seed = seed.unwrap_or(rand::rng().random_range(0u64..u64::MAX));
     println!("Using seed {seed}");
     device.set_seed(seed)?;
     let use_guide_scale = guidance_scale > 1.0;
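Aside: gen_range becomes random_range in rand 0.9, with the range syntax unchanged. A one-liner sketch of the seed selection above, illustrative only:

    use rand::Rng;

    fn main() {
        // rand 0.9: random_range replaces gen_range.
        let seed: u64 = rand::rng().random_range(0u64..u64::MAX);
        println!("Using seed {seed}");
    }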
@@ -83,7 +83,7 @@ fn rms_norml(device: &Device) -> Result<()> {
     let (b_size, seq_len, head_dim) = (24, 70, 64);
     let el_count = b_size * seq_len * head_dim;
     let mut rng = StdRng::seed_from_u64(299792458);
-    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
     let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
     let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
     let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
@@ -130,7 +130,7 @@ fn layer_norml(device: &Device) -> Result<()> {
     let (b_size, seq_len, head_dim) = (24, 70, 64);
     let el_count = b_size * seq_len * head_dim;
     let mut rng = StdRng::seed_from_u64(299792458);
-    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
     let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
     let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
     let beta = Tensor::zeros(head_dim, candle::DType::F32, device)?;
@@ -161,12 +161,12 @@ fn ropei(device: &Device) -> Result<()> {
     let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
     let el_count = b_size * num_head * seq_len * head_dim;
     let mut rng = StdRng::seed_from_u64(299792458);
-    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
     let cos: Vec<f32> = (0..seq_len * head_dim / 2)
-        .map(|_| rng.gen::<f32>())
+        .map(|_| rng.random::<f32>())
         .collect();
     let sin: Vec<f32> = (0..seq_len * head_dim / 2)
-        .map(|_| rng.gen::<f32>())
+        .map(|_| rng.random::<f32>())
         .collect();
     let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
     let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
@@ -188,12 +188,12 @@ fn rope(device: &Device) -> Result<()> {
     let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
     let el_count = b_size * num_head * seq_len * head_dim;
     let mut rng = StdRng::seed_from_u64(299792458);
-    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
     let cos: Vec<f32> = (0..seq_len * head_dim / 2)
-        .map(|_| rng.gen::<f32>())
+        .map(|_| rng.random::<f32>())
         .collect();
     let sin: Vec<f32> = (0..seq_len * head_dim / 2)
-        .map(|_| rng.gen::<f32>())
+        .map(|_| rng.random::<f32>())
         .collect();
     let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
     let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
@@ -215,12 +215,12 @@ fn rope_thd(device: &Device) -> Result<()> {
     let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
     let el_count = b_size * num_head * seq_len * head_dim;
     let mut rng = StdRng::seed_from_u64(299792458);
-    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
     let cos: Vec<f32> = (0..seq_len * head_dim / 2)
-        .map(|_| rng.gen::<f32>())
+        .map(|_| rng.random::<f32>())
         .collect();
     let sin: Vec<f32> = (0..seq_len * head_dim / 2)
-        .map(|_| rng.gen::<f32>())
+        .map(|_| rng.random::<f32>())
         .collect();
     let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
     let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
@@ -4,7 +4,7 @@
 //! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
 //! and combinations thereof.
 use candle::{Context, DType, Error, Result, Tensor};
-use rand::{distributions::Distribution, SeedableRng};
+use rand::{distr::Distribution, SeedableRng};

 #[derive(Clone, PartialEq, Debug)]
 pub enum Sampling {
@@ -50,7 +50,7 @@ impl LogitsProcessor {
     }

     fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
-        let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+        let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
        let next_token = distr.sample(&mut self.rng) as u32;
        Ok(next_token)
    }
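Aside: WeightedIndex now lives under rand::distr::weighted, and the Distribution trait under rand::distr. A standalone sketch of multinomial sampling in that style, illustrative only, with made-up probabilities and seed:

    use rand::distr::weighted::WeightedIndex;
    use rand::distr::Distribution;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    fn main() {
        // Example probabilities; in the generation code these come from a softmax over logits.
        let prs = vec![0.1f32, 0.2, 0.7];
        let distr = WeightedIndex::new(&prs).expect("non-empty, non-negative weights");
        let mut rng = StdRng::seed_from_u64(42);
        let next_token = distr.sample(&mut rng) as u32;
        println!("sampled index: {next_token}");
    }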
@@ -3,7 +3,7 @@ use anyhow::Error as E;
 use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
 use candle_nn::{ops::softmax, VarBuilder};
 pub use candle_transformers::models::whisper::{self as m, Config};
-use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
+use rand::{distr::Distribution, rngs::StdRng, SeedableRng};
 use serde::{Deserialize, Serialize};
 use tokenizers::Tokenizer;
 use wasm_bindgen::prelude::*;
@@ -221,7 +221,7 @@ impl Decoder {
        let next_token = if t > 0f64 {
            let prs = softmax(&(&logits / t)?, 0)?;
            let logits_v: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+            let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
            distr.sample(&mut self.rng) as u32
        } else {
            let logits_v: Vec<f32> = logits.to_vec1()?;