Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 19:58:35 +00:00)

Compare commits: 0.8.4...fix_mkl_fe (9 commits)
Commits in this comparison:

2e273ddf31
cb02b389d5
0d4097031c
10853b803c
f3d472952f
67b85f79f1
0b24f7f0a4
3afb04925a
cbf5fc80c2
@@ -51,7 +51,7 @@ half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_di
 hound = "3.5.1"
 image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
 imageproc = { version = "0.24.0", default-features = false }
-intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
+intel-mkl-src = { version = "0.8.1" }
 libc = { version = "0.2.147" }
 log = "0.4"
 memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
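The workspace dependency drops the hard-coded `mkl-static-lp64-iomp` feature, leaving MKL linkage to crates that opt in. A minimal sketch of that opt-in pattern, which the candle examples use through an `mkl` cargo feature (the feature name and printouts here are illustrative, not the exact candle wiring):

```rust
// Hypothetical minimal sketch: intel-mkl-src is only pulled in when the crate is
// built with an "mkl" feature, so default builds do not need an MKL install.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

fn main() {
    #[cfg(feature = "mkl")]
    println!("built with the MKL backend");
    #[cfg(not(feature = "mkl"))]
    println!("built without MKL");
}
```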
@@ -45,6 +45,7 @@ pub enum OpCode {
     BinFloat = b'G',
     Append = b'a',
     Appends = b'e',
+    Long1 = 0x8a,
 }

 // Avoid using FromPrimitive so as not to drag another dependency.
@@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
             b'G' => Ok(Self::BinFloat),
             b'a' => Ok(Self::Append),
             b'e' => Ok(Self::Appends),
+            0x8a => Ok(Self::Long1),
             value => Err(value),
         }
     }
@@ -106,6 +108,7 @@ pub enum Object {
         class_name: String,
     },
     Int(i32),
+    Long(i64),
     Float(f64),
     Unicode(String),
     Bool(bool),
@@ -170,6 +173,14 @@ impl Object {
         }
     }

+    pub fn int_or_long(self) -> OResult<i64> {
+        match self {
+            Self::Int(t) => Ok(t as i64),
+            Self::Long(t) => Ok(t),
+            _ => Err(self),
+        }
+    }
+
     pub fn tuple(self) -> OResult<Vec<Self>> {
         match self {
             Self::Tuple(t) => Ok(t),
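A small standalone sketch (a simplified stand-in for the pickle `Object` type, not the real candle enum) of what the new accessor buys: an integer field may arrive as a 32-bit `Int` or as a `Long` produced by the LONG1 opcode, and both should collapse to the same `i64`:

```rust
// Simplified stand-in enum; the real Object has many more variants and returns
// OResult<i64>, where Err hands the unexpected object back to the caller.
#[derive(Debug)]
enum Obj {
    Int(i32),
    Long(i64),
}

impl Obj {
    fn int_or_long(self) -> Result<i64, Obj> {
        match self {
            Obj::Int(t) => Ok(t as i64),
            Obj::Long(t) => Ok(t),
        }
    }
}

fn main() {
    assert_eq!(Obj::Int(7).int_or_long().unwrap(), 7);
    assert_eq!(Obj::Long(1i64 << 40).int_or_long().unwrap(), 1i64 << 40);
    println!("ok");
}
```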
@@ -590,6 +601,15 @@ impl Stack {
                 let obj = self.new_obj(class, args)?;
                 self.push(obj)
             }
+            OpCode::Long1 => {
+                let n_bytes = r.read_u8()?;
+                let mut v = 0;
+                // Decode the next n bytes in little endian
+                for i in 0..n_bytes {
+                    v |= (r.read_u8()? as i64) << (i * 8);
+                }
+                self.push(Object::Long(v))
+            }
         }
         Ok(false)
     }
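The LONG1 opcode stores a one-byte length followed by that many payload bytes in little-endian order. A standalone sketch of the accumulation loop above, assuming the length byte has already been consumed and the value is small and non-negative (the `decode_long1_le` helper is illustrative, not part of candle):

```rust
fn decode_long1_le(bytes: &[u8]) -> i64 {
    let mut v = 0i64;
    for (i, b) in bytes.iter().enumerate() {
        // Each successive byte contributes the next 8 more-significant bits.
        v |= (*b as i64) << (i * 8);
    }
    v
}

fn main() {
    // 300 = 0x012c, serialized little endian as [0x2c, 0x01].
    assert_eq!(decode_long1_le(&[0x2c, 0x01]), 300);
    println!("ok");
}
```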
@@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
     let mut args = args.tuple()?;
     let stride = Vec::<usize>::try_from(args.remove(3))?;
     let size = Vec::<usize>::try_from(args.remove(2))?;
-    let offset = args.remove(1).int()? as usize;
+    let offset = args.remove(1).int_or_long()? as usize;
     let storage = args.remove(0).persistent_load()?;
     let mut storage = storage.tuple()?;
-    let storage_size = storage.remove(4).int()? as usize;
+    let storage_size = storage.remove(4).int_or_long()? as usize;
     let path = storage.remove(2).unicode()?;
     let (_module_name, class_name) = storage.remove(1).class()?;
     let dtype = match class_name.as_str() {
@@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
             crate::bail!("unsupported storage type {other}")
         }
     };
-    let layout = Layout::new(crate::Shape::from(size), stride, offset);
+    let layout = Layout::new(
+        crate::Shape::from(size),
+        stride,
+        offset * dtype.size_in_bytes(),
+    );
     Ok((layout, dtype, path, storage_size))
 }

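The layout now scales the pickled storage offset by the element size, i.e. it turns an offset counted in elements into an offset counted in bytes. As plain arithmetic (the `byte_offset` helper and the element sizes are illustrative, not candle API):

```rust
fn byte_offset(element_offset: usize, size_in_bytes: usize) -> usize {
    element_offset * size_in_bytes
}

fn main() {
    assert_eq!(byte_offset(10, 4), 40); // e.g. 10 f32 elements
    assert_eq!(byte_offset(10, 2), 20); // e.g. 10 f16 elements
    println!("ok");
}
```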
@@ -50,6 +50,8 @@ enum Which {
     InstructV2_9B,
     #[value(name = "3-1b")]
     BaseV3_1B,
+    #[value(name = "3-1b-it")]
+    InstructV3_1B,
 }

 enum Model {
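The new variant is exposed on the command line through clap's `ValueEnum` derive, where `#[value(name = ...)]` maps a CLI string to an enum variant. A minimal sketch with only the two gemma-3 variants and a hypothetical `Args` struct (not the full example's argument set):

```rust
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "3-1b")]
    BaseV3_1B,
    #[value(name = "3-1b-it")]
    InstructV3_1B,
}

#[derive(Parser)]
struct Args {
    /// Which model variant to run.
    #[arg(long, value_enum)]
    which: Which,
}

fn main() {
    let args = Args::parse_from(["gemma", "--which", "3-1b-it"]);
    println!("selected {:?}", args.which);
}
```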
@@ -272,6 +274,7 @@ fn main() -> Result<()> {
                 Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
                 Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
                 Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
+                Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
             },
         };
         let repo = api.repo(Repo::with_revision(
@@ -292,13 +295,10 @@ fn main() -> Result<()> {
                 .split(',')
                 .map(std::path::PathBuf::from)
                 .collect::<Vec<_>>(),
-            None => {
-                if args.which == Which::BaseV3_1B {
-                    vec![repo.get("model.safetensors")?]
-                } else {
-                    candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
-                }
-            }
+            None => match args.which {
+                Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
+                _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+            },
         };
         println!("retrieved the files in {:?}", start.elapsed());
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@@ -331,7 +331,7 @@ fn main() -> Result<()> {
             let model = Model2::new(args.use_flash_attn, &config, vb)?;
             Model::V2(model)
         }
-        Which::BaseV3_1B => {
+        Which::BaseV3_1B | Which::InstructV3_1B => {
             let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model3::new(args.use_flash_attn, &config, vb)?;
             Model::V3(model)
@@ -7,6 +7,7 @@ extern crate accelerate_src;

 use clap::{Parser, ValueEnum};
 use rand::prelude::*;
+use rand::rng;

 use candle::{DType, Result, Tensor, D};
 use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
@@ -138,7 +139,7 @@ fn training_loop_cnn(
     let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
     for epoch in 1..args.epochs {
         let mut sum_loss = 0f32;
-        batch_idxs.shuffle(&mut thread_rng());
+        batch_idxs.shuffle(&mut rng());
         for batch_idx in batch_idxs.iter() {
            let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
            let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
@@ -5,7 +5,7 @@ use candle_nn::{
     func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
     VarBuilder, VarMap,
 };
-use rand::{distributions::Uniform, thread_rng, Rng};
+use rand::{distr::Uniform, rng, Rng};

 use super::gym_env::GymEnv;

@@ -103,8 +103,8 @@ impl ReplayBuffer {
         if self.size < batch_size {
             Ok(None)
         } else {
-            let transitions: Vec<&Transition> = thread_rng()
-                .sample_iter(Uniform::from(0..self.size))
+            let transitions: Vec<&Transition> = rng()
+                .sample_iter(Uniform::try_from(0..self.size).map_err(Error::wrap)?)
                 .take(batch_size)
                 .map(|i| self.buffer.get(i).unwrap())
                 .collect();
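These hunks are part of a migration to the rand 0.9 API: `rand::rng()` replaces `thread_rng()`, `Rng::random` replaces `Rng::gen`, the `distributions` module becomes `rand::distr`, and `Uniform` construction is fallible (hence the `map_err(Error::wrap)?`). A small sketch of the new calls in isolation (values are illustrative):

```rust
use rand::distr::Uniform;
use rand::{rng, Rng};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // rng() is the rand 0.9 name for the thread-local generator.
    let seed: u64 = rng().random();
    // Uniform construction can fail (e.g. on an empty range), so it returns a Result.
    let dist = Uniform::try_from(0..100usize)?;
    let idxs: Vec<usize> = rng().sample_iter(dist).take(4).collect();
    println!("seed={seed} idxs={idxs:?}");
    Ok(())
}
```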
@@ -498,11 +498,11 @@ pub fn run() -> Result<()> {
         OuNoise::new(MU, THETA, SIGMA, size_action)?,
     )?;

-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();

     for episode in 0..MAX_EPISODES {
         // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;

         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
@@ -538,7 +538,7 @@ pub fn run() -> Result<()> {
     agent.train = false;
     for episode in 0..10 {
         // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
             let mut action = 2.0 * agent.actions(&state)?;
@@ -1,9 +1,8 @@
 use std::collections::VecDeque;

-use rand::distributions::Uniform;
-use rand::{thread_rng, Rng};
+use rand::{distr::Uniform, rng, Rng};

-use candle::{DType, Device, Module, Result, Tensor};
+use candle::{DType, Device, Error, Module, Result, Tensor};
 use candle_nn::loss::mse;
 use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};

@@ -65,8 +64,8 @@ pub fn run() -> Result<()> {
         // fed to the model so that it performs a backward pass.
         if memory.len() > BATCH_SIZE {
             // Sample randomly from the memory.
-            let batch = thread_rng()
-                .sample_iter(Uniform::from(0..memory.len()))
+            let batch = rng()
+                .sample_iter(Uniform::try_from(0..memory.len()).map_err(Error::wrap)?)
                 .take(BATCH_SIZE)
                 .map(|i| memory.get(i).unwrap().clone())
                 .collect::<Vec<_>>();
@@ -4,7 +4,7 @@ use candle_nn::{
     linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
     ParamsAdamW, VarBuilder, VarMap,
 };
-use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
+use rand::{distr::Distribution, rngs::ThreadRng, Rng};

 fn new_model(
     input_shape: &[usize],
@@ -39,7 +39,7 @@ fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
 }

 fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
-    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
+    let distribution = rand::distr::weighted::WeightedIndex::new(probs).map_err(Error::wrap)?;
     let mut rng = rng;
     Ok(distribution.sample(&mut rng))
 }
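In rand 0.9, `WeightedIndex` moved from `rand::distributions` to `rand::distr::weighted`, which is the only change in this hunk and in the whisper hunks further down. A standalone sketch of sampling from it under the new path (the weights are illustrative):

```rust
use rand::distr::weighted::WeightedIndex;
use rand::distr::Distribution;
use rand::rng;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let probs = vec![0.1f32, 0.2, 0.7];
    let dist = WeightedIndex::new(&probs)?;
    // Draw one index; index 2 is the most likely given the weights above.
    let idx = dist.sample(&mut rng());
    println!("sampled index {idx}");
    Ok(())
}
```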
@@ -65,10 +65,10 @@ pub fn run() -> Result<()> {

     let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;

-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();

     for epoch_idx in 0..100 {
-        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut state = env.reset(rng.random::<u64>())?;
         let mut steps: Vec<Step<i64>> = vec![];

         loop {
@@ -84,7 +84,7 @@ pub fn run() -> Result<()> {
             steps.push(step.copy_with_obs(&state));

             if step.terminated || step.truncated {
-                state = env.reset(rng.gen::<u64>())?;
+                state = env.reset(rng.random::<u64>())?;
                 if steps.len() > 5000 {
                     break;
                 }
@@ -9,7 +9,7 @@ use candle::{Device, IndexOp, Tensor};
 use candle_nn::{ops::softmax, VarBuilder};
 use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use rand::{distributions::Distribution, SeedableRng};
+use rand::{distr::Distribution, SeedableRng};
 use tokenizers::Tokenizer;

 mod multilingual;
@@ -204,7 +204,7 @@ impl Decoder {
         let next_token = if t > 0f64 {
             let prs = softmax(&(&logits / t)?, 0)?;
             let logits_v: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+            let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
             distr.sample(&mut self.rng) as u32
         } else {
             let logits_v: Vec<f32> = logits.to_vec1()?;
@@ -14,7 +14,9 @@ use candle::{Device, IndexOp, Tensor};
 use candle_nn::{ops::softmax, VarBuilder};
 use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use rand::{distributions::Distribution, SeedableRng};
+use rand::distr::weighted::WeightedIndex;
+use rand::distr::Distribution;
+use rand::SeedableRng;
 use tokenizers::Tokenizer;

 mod multilingual;
@@ -208,7 +210,7 @@ impl Decoder {
         let next_token = if t > 0f64 {
             let prs = softmax(&(&logits / t)?, 0)?;
             let logits_v: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+            let distr = WeightedIndex::new(&logits_v)?;
             distr.sample(&mut self.rng) as u32
         } else {
             let logits_v: Vec<f32> = logits.to_vec1()?;
@@ -88,19 +88,26 @@ fn main() -> Result<()> {
         .arg("--use_fast_math")
         .arg("--verbose");

+    let mut is_target_msvc = false;
     if let Ok(target) = std::env::var("TARGET") {
         if target.contains("msvc") {
+            is_target_msvc = true;
             builder = builder.arg("-D_USE_MATH_DEFINES");
         }
     }

+    if !is_target_msvc {
+        builder = builder.arg("-Xcompiler").arg("-fPIC");
+    }
+
     let out_file = build_dir.join("libflashattention.a");
     builder.build_lib(out_file);

     println!("cargo:rustc-link-search={}", build_dir.display());
     println!("cargo:rustc-link-lib=flashattention");
     println!("cargo:rustc-link-lib=dylib=cudart");
+    if !is_target_msvc {
     println!("cargo:rustc-link-lib=dylib=stdc++");
+    }
     Ok(())
 }
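The build script now records whether the target is MSVC and skips the unix-only bits (`-Xcompiler -fPIC` for the device code and linking `libstdc++`) on that platform. A stripped-down sketch of the same branching in a cargo build script (the diagnostics are illustrative; the real script drives a CUDA builder as shown above):

```rust
// build.rs sketch: cargo provides the target triple in the TARGET env var.
fn main() {
    let target = std::env::var("TARGET").unwrap_or_default();
    let is_target_msvc = target.contains("msvc");

    if is_target_msvc {
        // MSVC: no -fPIC and no libstdc++; M_PI and friends need _USE_MATH_DEFINES.
        println!("cargo:warning=MSVC target detected, skipping -fPIC and stdc++");
    } else {
        // GNU-style toolchains: link the C++ runtime used by the CUDA objects.
        println!("cargo:rustc-link-lib=dylib=stdc++");
    }
}
```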
@@ -11,6 +11,7 @@ pub struct Cache {
     all_data: Option<Tensor>,
     dim: usize,
     current_seq_len: usize,
+    grow_by: usize,
     max_seq_len: usize,
 }

@@ -20,6 +21,7 @@ impl Cache {
             all_data: None,
             dim,
             current_seq_len: 0,
+            grow_by: max_seq_len,
             max_seq_len,
         }
     }
@@ -65,11 +67,11 @@ impl Cache {
         };
         let ad = self.all_data.as_mut().unwrap();
         if self.current_seq_len + seq_len > self.max_seq_len {
-            candle::bail!(
-                "kv-cache: above max-seq-len {}+{seq_len}>{}",
-                self.current_seq_len,
-                self.max_seq_len
-            )
+            let mut shape = src.dims().to_vec();
+            shape[self.dim] = self.grow_by;
+            let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
+            *ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
+            self.max_seq_len += self.grow_by;
         }
         ad.slice_set(src, self.dim, self.current_seq_len)?;
         self.current_seq_len += seq_len;
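Instead of bailing out when the cache is full, the append path now grows the backing tensor by `grow_by` positions along the cache dimension and bumps `max_seq_len`. A self-contained sketch of that growth step using candle's public tensor API (the `candle` import path matches the hunks above; the shapes and grow size are illustrative):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let dim = 1usize; // the sequence dimension of the cache
    let grow_by = 4usize;

    // Existing cache buffer: (batch = 1, seq = 4, hidden = 2), already full.
    let mut ad = Tensor::zeros((1usize, 4usize, 2usize), DType::F32, &dev)?;
    // Incoming keys/values that no longer fit: (1, 2, 2).
    let src = Tensor::ones((1usize, 2usize, 2usize), DType::F32, &dev)?;

    // Grow: concatenate a zero block of `grow_by` positions along `dim`.
    let mut shape = src.dims().to_vec();
    shape[dim] = grow_by;
    let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
    ad = Tensor::cat(&[&ad, &next_ad], dim)?;

    // Write the new block at the old end of the cache.
    ad.slice_set(&src, dim, 4)?;
    println!("grown cache shape: {:?}", ad.dims());
    Ok(())
}
```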