mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
@ -1,8 +1,6 @@
|
||||
[workspace]
|
||||
|
||||
members = [
|
||||
"candle-core",
|
||||
"candle-hub",
|
||||
"candle-kernels",
|
||||
]
|
||||
|
||||
|
@ -24,6 +24,7 @@ mod var_store;
|
||||
mod weights;
|
||||
|
||||
const CONTEXT_SIZE: usize = 512;
|
||||
const USE_KV_CACHE: bool = false;
|
||||
const START_PROMPT: &str = r"
|
||||
EDWARD:
|
||||
I wonder how our princely father 'scaped,
|
||||
@ -218,13 +219,16 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
#[derive(Clone)]
|
||||
struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
fn new(device: &Device) -> Self {
|
||||
fn new(config: &Config, device: &Device) -> Self {
|
||||
Self {
|
||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
@ -249,7 +253,6 @@ struct CausalSelfAttention {
|
||||
c_attn: Linear,
|
||||
c_proj: Linear,
|
||||
n_head: usize,
|
||||
// n_embd: usize,
|
||||
cache: Cache,
|
||||
}
|
||||
|
||||
@ -265,6 +268,11 @@ impl CausalSelfAttention {
|
||||
|
||||
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||
let mut dims = x.dims().to_vec();
|
||||
let freqs_cis = if dims[1] < CONTEXT_SIZE {
|
||||
freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
|
||||
} else {
|
||||
freqs_cis.clone()
|
||||
};
|
||||
let v = dims.pop().unwrap();
|
||||
dims.push(v / 2);
|
||||
dims.push(2);
|
||||
@ -285,7 +293,7 @@ impl CausalSelfAttention {
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
||||
let (t, c) = x.shape().r2()?;
|
||||
let qkv = self.c_attn.forward(x)?;
|
||||
let qkv = qkv.to_dtype(DType::F32)?;
|
||||
@ -296,9 +304,31 @@ impl CausalSelfAttention {
|
||||
let target_dim = [t, self.n_head, c / self.n_head];
|
||||
let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||
let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||
let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||
let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
||||
let k = self.apply_rotary_emb(&k, freqs_cis)?;
|
||||
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
|
||||
|
||||
if USE_KV_CACHE {
|
||||
let mut cache = self.cache.kvs.lock().unwrap();
|
||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||
k = Tensor::cat(&[cache_k, &k], 1)?;
|
||||
v = Tensor::cat(&[cache_v, &v], 1)?;
|
||||
let k_seq_len = k.dims()[1];
|
||||
if k_seq_len > CONTEXT_SIZE {
|
||||
k = k
|
||||
.narrow(1, k_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
|
||||
.contiguous()?
|
||||
}
|
||||
let v_seq_len = v.dims()[1];
|
||||
if v_seq_len > CONTEXT_SIZE {
|
||||
v = v
|
||||
.narrow(1, v_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
|
||||
.contiguous()?
|
||||
}
|
||||
}
|
||||
cache[block_idx] = Some((k.clone(), v.clone()))
|
||||
}
|
||||
|
||||
let k_shape = k.shape();
|
||||
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
|
||||
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
|
||||
@ -330,8 +360,11 @@ impl Block {
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||
let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?;
|
||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
||||
let x = (self
|
||||
.attn
|
||||
.forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
|
||||
+ x)?;
|
||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
|
||||
Ok(x)
|
||||
}
|
||||
@ -358,8 +391,8 @@ impl Llama {
|
||||
// TODO: Support for mini-batches? (i.e. r2)
|
||||
let t = x.shape().r1()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, freqs_cis)?;
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, freqs_cis, block_idx)?;
|
||||
}
|
||||
let x = self.ln_f.forward(&x)?;
|
||||
let x = x.narrow(0, t - 1, 1)?;
|
||||
@ -400,7 +433,7 @@ struct Args {
|
||||
|
||||
/// Use npy instead of safetensors
|
||||
#[arg(long)]
|
||||
npy: bool,
|
||||
npy: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
@ -426,33 +459,35 @@ async fn main() -> Result<()> {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
let config = Config::config_7b();
|
||||
let cache = Cache::new(&device);
|
||||
let cache = Cache::new(&config, &device);
|
||||
let start = std::time::Instant::now();
|
||||
let (llama, tokenizer_filename) = if args.npy {
|
||||
println!("building the model (NPY)");
|
||||
(
|
||||
Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?,
|
||||
std::path::Path::new("llama-tokenizer.json").to_path_buf(),
|
||||
)
|
||||
} else {
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||
println!("building the model");
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = api.get(&repo, rfilename).await?;
|
||||
filenames.push(filename);
|
||||
let (llama, tokenizer_filename) = match args.npy {
|
||||
Some(npy) => {
|
||||
println!("building the model (NPY)");
|
||||
let weights = Llama::load_npy(&device, &npy, &cache, &config)?;
|
||||
let token_path = std::path::Path::new("llama-tokenizer.json").to_path_buf();
|
||||
(weights, token_path)
|
||||
}
|
||||
None => {
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||
println!("building the model");
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = api.get(&repo, rfilename).await?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
|
||||
println!("building the model (SF)");
|
||||
(
|
||||
Llama::load(&device, &filenames, &cache, &config)?,
|
||||
tokenizer_filename,
|
||||
)
|
||||
println!("building the model (SF)");
|
||||
(
|
||||
Llama::load(&device, &filenames, &cache, &config)?,
|
||||
tokenizer_filename,
|
||||
)
|
||||
}
|
||||
};
|
||||
println!("Loaded in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
@ -470,7 +505,12 @@ async fn main() -> Result<()> {
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..args.sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
||||
let context_size = if USE_KV_CACHE && index > 0 {
|
||||
1
|
||||
} else {
|
||||
CONTEXT_SIZE
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?;
|
||||
let logits = llama.forward(&input, &freqs_cis)?;
|
||||
|
||||
|
@ -10,6 +10,14 @@ pub enum Error {
|
||||
got: DType,
|
||||
},
|
||||
|
||||
#[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
|
||||
NarrowInvalidArgs {
|
||||
shape: Shape,
|
||||
dim: usize,
|
||||
start: usize,
|
||||
len: usize,
|
||||
},
|
||||
|
||||
#[error("{op} only supports contiguous tensors")]
|
||||
RequiresContiguous { op: &'static str },
|
||||
|
||||
|
@ -349,21 +349,34 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + length`.
|
||||
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Narrow(self.clone(), dim, start, length))
|
||||
/// ranges from `start` to `start + len`.
|
||||
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||
let dims = self.dims();
|
||||
if dim >= dims.len() || start + len > dims[dim] {
|
||||
Err(Error::NarrowInvalidArgs {
|
||||
shape: self.shape().clone(),
|
||||
dim,
|
||||
start,
|
||||
len,
|
||||
})?
|
||||
}
|
||||
if start == 0 && dims[dim] == len {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout().narrow(dim, start, length)?,
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Narrow(self.clone(), dim, start, len))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout().narrow(dim, start, len)?,
|
||||
op,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
||||
|
Reference in New Issue
Block a user