mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
@ -1,8 +1,6 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
|
|
||||||
members = [
|
members = [
|
||||||
"candle-core",
|
"candle-core",
|
||||||
"candle-hub",
|
"candle-hub",
|
||||||
"candle-kernels",
|
"candle-kernels",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ mod var_store;
|
|||||||
mod weights;
|
mod weights;
|
||||||
|
|
||||||
const CONTEXT_SIZE: usize = 512;
|
const CONTEXT_SIZE: usize = 512;
|
||||||
|
const USE_KV_CACHE: bool = false;
|
||||||
const START_PROMPT: &str = r"
|
const START_PROMPT: &str = r"
|
||||||
EDWARD:
|
EDWARD:
|
||||||
I wonder how our princely father 'scaped,
|
I wonder how our princely father 'scaped,
|
||||||
@ -218,13 +219,16 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Cache {
|
struct Cache {
|
||||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||||
device: Device,
|
device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Cache {
|
impl Cache {
|
||||||
fn new(device: &Device) -> Self {
|
fn new(config: &Config, device: &Device) -> Self {
|
||||||
Self {
|
Self {
|
||||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -249,7 +253,6 @@ struct CausalSelfAttention {
|
|||||||
c_attn: Linear,
|
c_attn: Linear,
|
||||||
c_proj: Linear,
|
c_proj: Linear,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
// n_embd: usize,
|
|
||||||
cache: Cache,
|
cache: Cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -265,6 +268,11 @@ impl CausalSelfAttention {
|
|||||||
|
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
let mut dims = x.dims().to_vec();
|
let mut dims = x.dims().to_vec();
|
||||||
|
let freqs_cis = if dims[1] < CONTEXT_SIZE {
|
||||||
|
freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
|
||||||
|
} else {
|
||||||
|
freqs_cis.clone()
|
||||||
|
};
|
||||||
let v = dims.pop().unwrap();
|
let v = dims.pop().unwrap();
|
||||||
dims.push(v / 2);
|
dims.push(v / 2);
|
||||||
dims.push(2);
|
dims.push(2);
|
||||||
@ -285,7 +293,7 @@ impl CausalSelfAttention {
|
|||||||
Ok(rope)
|
Ok(rope)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
||||||
let (t, c) = x.shape().r2()?;
|
let (t, c) = x.shape().r2()?;
|
||||||
let qkv = self.c_attn.forward(x)?;
|
let qkv = self.c_attn.forward(x)?;
|
||||||
let qkv = qkv.to_dtype(DType::F32)?;
|
let qkv = qkv.to_dtype(DType::F32)?;
|
||||||
@ -296,9 +304,31 @@ impl CausalSelfAttention {
|
|||||||
let target_dim = [t, self.n_head, c / self.n_head];
|
let target_dim = [t, self.n_head, c / self.n_head];
|
||||||
let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||||
let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||||
let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||||
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
||||||
let k = self.apply_rotary_emb(&k, freqs_cis)?;
|
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
|
||||||
|
|
||||||
|
if USE_KV_CACHE {
|
||||||
|
let mut cache = self.cache.kvs.lock().unwrap();
|
||||||
|
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||||
|
k = Tensor::cat(&[cache_k, &k], 1)?;
|
||||||
|
v = Tensor::cat(&[cache_v, &v], 1)?;
|
||||||
|
let k_seq_len = k.dims()[1];
|
||||||
|
if k_seq_len > CONTEXT_SIZE {
|
||||||
|
k = k
|
||||||
|
.narrow(1, k_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
|
||||||
|
.contiguous()?
|
||||||
|
}
|
||||||
|
let v_seq_len = v.dims()[1];
|
||||||
|
if v_seq_len > CONTEXT_SIZE {
|
||||||
|
v = v
|
||||||
|
.narrow(1, v_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
|
||||||
|
.contiguous()?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache[block_idx] = Some((k.clone(), v.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
let k_shape = k.shape();
|
let k_shape = k.shape();
|
||||||
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
|
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
|
||||||
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
|
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
|
||||||
@ -330,8 +360,11 @@ impl Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
||||||
let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?;
|
let x = (self
|
||||||
|
.attn
|
||||||
|
.forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
|
||||||
|
+ x)?;
|
||||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
|
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
|
||||||
Ok(x)
|
Ok(x)
|
||||||
}
|
}
|
||||||
@ -358,8 +391,8 @@ impl Llama {
|
|||||||
// TODO: Support for mini-batches? (i.e. r2)
|
// TODO: Support for mini-batches? (i.e. r2)
|
||||||
let t = x.shape().r1()?;
|
let t = x.shape().r1()?;
|
||||||
let mut x = self.wte.forward(x)?;
|
let mut x = self.wte.forward(x)?;
|
||||||
for block in self.blocks.iter() {
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
x = block.forward(&x, freqs_cis)?;
|
x = block.forward(&x, freqs_cis, block_idx)?;
|
||||||
}
|
}
|
||||||
let x = self.ln_f.forward(&x)?;
|
let x = self.ln_f.forward(&x)?;
|
||||||
let x = x.narrow(0, t - 1, 1)?;
|
let x = x.narrow(0, t - 1, 1)?;
|
||||||
@ -400,7 +433,7 @@ struct Args {
|
|||||||
|
|
||||||
/// Use npy instead of safetensors
|
/// Use npy instead of safetensors
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
npy: bool,
|
npy: Option<String>,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -426,33 +459,35 @@ async fn main() -> Result<()> {
|
|||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = Cache::new(&device);
|
let cache = Cache::new(&config, &device);
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let (llama, tokenizer_filename) = if args.npy {
|
let (llama, tokenizer_filename) = match args.npy {
|
||||||
println!("building the model (NPY)");
|
Some(npy) => {
|
||||||
(
|
println!("building the model (NPY)");
|
||||||
Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?,
|
let weights = Llama::load_npy(&device, &npy, &cache, &config)?;
|
||||||
std::path::Path::new("llama-tokenizer.json").to_path_buf(),
|
let token_path = std::path::Path::new("llama-tokenizer.json").to_path_buf();
|
||||||
)
|
(weights, token_path)
|
||||||
} else {
|
|
||||||
let api = Api::new()?;
|
|
||||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
|
||||||
println!("building the model");
|
|
||||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
|
||||||
let mut filenames = vec![];
|
|
||||||
for rfilename in [
|
|
||||||
"model-00001-of-00002.safetensors",
|
|
||||||
"model-00002-of-00002.safetensors",
|
|
||||||
] {
|
|
||||||
let filename = api.get(&repo, rfilename).await?;
|
|
||||||
filenames.push(filename);
|
|
||||||
}
|
}
|
||||||
|
None => {
|
||||||
|
let api = Api::new()?;
|
||||||
|
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||||
|
println!("building the model");
|
||||||
|
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||||
|
let mut filenames = vec![];
|
||||||
|
for rfilename in [
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
] {
|
||||||
|
let filename = api.get(&repo, rfilename).await?;
|
||||||
|
filenames.push(filename);
|
||||||
|
}
|
||||||
|
|
||||||
println!("building the model (SF)");
|
println!("building the model (SF)");
|
||||||
(
|
(
|
||||||
Llama::load(&device, &filenames, &cache, &config)?,
|
Llama::load(&device, &filenames, &cache, &config)?,
|
||||||
tokenizer_filename,
|
tokenizer_filename,
|
||||||
)
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
println!("Loaded in {:?}", start.elapsed());
|
println!("Loaded in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
@ -470,7 +505,12 @@ async fn main() -> Result<()> {
|
|||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
let context_size = if USE_KV_CACHE && index > 0 {
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
CONTEXT_SIZE
|
||||||
|
};
|
||||||
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?;
|
let input = Tensor::new(ctxt, &device)?;
|
||||||
let logits = llama.forward(&input, &freqs_cis)?;
|
let logits = llama.forward(&input, &freqs_cis)?;
|
||||||
|
|
||||||
|
@ -10,6 +10,14 @@ pub enum Error {
|
|||||||
got: DType,
|
got: DType,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
#[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
|
||||||
|
NarrowInvalidArgs {
|
||||||
|
shape: Shape,
|
||||||
|
dim: usize,
|
||||||
|
start: usize,
|
||||||
|
len: usize,
|
||||||
|
},
|
||||||
|
|
||||||
#[error("{op} only supports contiguous tensors")]
|
#[error("{op} only supports contiguous tensors")]
|
||||||
RequiresContiguous { op: &'static str },
|
RequiresContiguous { op: &'static str },
|
||||||
|
|
||||||
|
@ -349,21 +349,34 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||||
/// ranges from `start` to `start + length`.
|
/// ranges from `start` to `start + len`.
|
||||||
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
|
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||||
let op = if self.track_op() {
|
let dims = self.dims();
|
||||||
Some(Op::Narrow(self.clone(), dim, start, length))
|
if dim >= dims.len() || start + len > dims[dim] {
|
||||||
|
Err(Error::NarrowInvalidArgs {
|
||||||
|
shape: self.shape().clone(),
|
||||||
|
dim,
|
||||||
|
start,
|
||||||
|
len,
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
if start == 0 && dims[dim] == len {
|
||||||
|
Ok(self.clone())
|
||||||
} else {
|
} else {
|
||||||
None
|
let op = if self.track_op() {
|
||||||
};
|
Some(Op::Narrow(self.clone(), dim, start, len))
|
||||||
let tensor_ = Tensor_ {
|
} else {
|
||||||
id: TensorId::new(),
|
None
|
||||||
storage: self.storage.clone(),
|
};
|
||||||
layout: self.layout().narrow(dim, start, length)?,
|
let tensor_ = Tensor_ {
|
||||||
op,
|
id: TensorId::new(),
|
||||||
is_variable: false,
|
storage: self.storage.clone(),
|
||||||
};
|
layout: self.layout().narrow(dim, start, len)?,
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
op,
|
||||||
|
is_variable: false,
|
||||||
|
};
|
||||||
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
||||||
|
Reference in New Issue
Block a user