mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Merge branch 'main' into ivarflakstad/metal-prng
This commit is contained in:
@ -66,7 +66,7 @@ We also provide a some command line based examples using state of the art models
|
|||||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
||||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||||
pre-trained on 1T tokens of English and code datasets.
|
pre-trained on 1T tokens of English and code datasets.
|
||||||
- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
|
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
|
||||||
implementation of the Mamba state space model.
|
implementation of the Mamba state space model.
|
||||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||||
better performance than all publicly available 13b models as of 2023-09-28.
|
better performance than all publicly available 13b models as of 2023-09-28.
|
||||||
@ -109,6 +109,9 @@ We also provide a some command line based examples using state of the art models
|
|||||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||||
using self-supervision (can be used for imagenet classification, depth
|
using self-supervision (can be used for imagenet classification, depth
|
||||||
evaluation, segmentation).
|
evaluation, segmentation).
|
||||||
|
- [VGG](./candle-examples/examples/vgg/),
|
||||||
|
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||||
|
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||||
generate captions for an image.
|
generate captions for an image.
|
||||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||||
@ -204,7 +207,7 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- Image to text.
|
- Image to text.
|
||||||
- BLIP.
|
- BLIP.
|
||||||
- Computer Vision Models.
|
- Computer Vision Models.
|
||||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
|
||||||
- yolo-v3, yolo-v8.
|
- yolo-v3, yolo-v8.
|
||||||
- Segment-Anything Model (SAM).
|
- Segment-Anything Model (SAM).
|
||||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||||
|
@ -1,4 +1,9 @@
|
|||||||
mod benchmarks;
|
mod benchmarks;
|
||||||
|
|
||||||
use criterion::criterion_main;
|
use criterion::criterion_main;
|
||||||
criterion_main!(benchmarks::matmul::benches, benchmarks::random::benches);
|
criterion_main!(
|
||||||
|
benchmarks::affine::benches,
|
||||||
|
benchmarks::matmul::benches,
|
||||||
|
benchmarks::random::benches,
|
||||||
|
benchmarks::where_cond::benches
|
||||||
|
);
|
||||||
|
43
candle-core/benches/benchmarks/affine.rs
Normal file
43
candle-core/benches/benchmarks/affine.rs
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
fn run(a: &Tensor) {
|
||||||
|
a.affine(12.34, 56.78).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||||
|
let b = 1;
|
||||||
|
let m = 1024;
|
||||||
|
let k = 1024;
|
||||||
|
|
||||||
|
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
||||||
|
|
||||||
|
let flops = b * m * k * dtype.size_in_bytes();
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group(device.bench_name(name));
|
||||||
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
|
group.bench_function("iter", move |b| {
|
||||||
|
b.iter_custom(|iters| {
|
||||||
|
let start = Instant::now();
|
||||||
|
for _i in 0..iters {
|
||||||
|
run(black_box(&tensor));
|
||||||
|
}
|
||||||
|
device.sync().unwrap();
|
||||||
|
start.elapsed()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
|
let handler = BenchDeviceHandler::new().unwrap();
|
||||||
|
for device in handler.devices {
|
||||||
|
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
|
||||||
|
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
|
||||||
|
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, criterion_benchmark);
|
@ -1,5 +1,7 @@
|
|||||||
|
pub(crate) mod affine;
|
||||||
pub(crate) mod matmul;
|
pub(crate) mod matmul;
|
||||||
pub(crate) mod random;
|
pub(crate) mod random;
|
||||||
|
pub(crate) mod where_cond;
|
||||||
|
|
||||||
use candle_core::{Device, Result};
|
use candle_core::{Device, Result};
|
||||||
|
|
||||||
|
64
candle-core/benches/benchmarks/where_cond.rs
Normal file
64
candle-core/benches/benchmarks/where_cond.rs
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
|
||||||
|
a.where_cond(b, c).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn create_cond_arr<const N: usize>() -> [u8; N] {
|
||||||
|
let mut arr = [0u8; N];
|
||||||
|
let mut i = 0;
|
||||||
|
while i < N {
|
||||||
|
arr[i] = (i % 2) as u8;
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
arr
|
||||||
|
}
|
||||||
|
|
||||||
|
const B: usize = 1;
|
||||||
|
const M: usize = 1024;
|
||||||
|
const K: usize = 1024;
|
||||||
|
const SIZE: usize = B * M * K;
|
||||||
|
|
||||||
|
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
||||||
|
|
||||||
|
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||||
|
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
|
||||||
|
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
|
||||||
|
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
|
||||||
|
|
||||||
|
let elements = B * M * K;
|
||||||
|
// E.g. 2 f32 tensors + 1 u8 tensor
|
||||||
|
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group(device.bench_name(name));
|
||||||
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
|
group.bench_function("iter", move |b| {
|
||||||
|
b.iter_custom(|iters| {
|
||||||
|
let start = Instant::now();
|
||||||
|
for _i in 0..iters {
|
||||||
|
run(
|
||||||
|
black_box(&tensor),
|
||||||
|
black_box(&on_true),
|
||||||
|
black_box(&on_false),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
device.sync().unwrap();
|
||||||
|
start.elapsed()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
|
let device = BenchDeviceHandler::new().unwrap();
|
||||||
|
for d in device.devices {
|
||||||
|
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
|
||||||
|
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
|
||||||
|
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, criterion_benchmark);
|
@ -355,6 +355,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "affine_f32",
|
DType::F32 => "affine_f32",
|
||||||
DType::F16 => "affine_f16",
|
DType::F16 => "affine_f16",
|
||||||
|
DType::BF16 => "affine_bf16",
|
||||||
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
|
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_affine(
|
candle_metal_kernels::call_affine(
|
||||||
@ -373,6 +374,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "affine_f32_strided",
|
DType::F32 => "affine_f32_strided",
|
||||||
DType::F16 => "affine_f16_strided",
|
DType::F16 => "affine_f16_strided",
|
||||||
|
DType::BF16 => "affine_bf16_strided",
|
||||||
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
|
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_affine_strided(
|
candle_metal_kernels::call_affine_strided(
|
||||||
@ -808,6 +810,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
let name = match (self.dtype, t.dtype()) {
|
let name = match (self.dtype, t.dtype()) {
|
||||||
(DType::U8, DType::F32) => "where_u8_f32",
|
(DType::U8, DType::F32) => "where_u8_f32",
|
||||||
|
(DType::U8, DType::BF16) => "where_u8_bf16",
|
||||||
(DType::U8, DType::F16) => "where_u8_f16",
|
(DType::U8, DType::F16) => "where_u8_f16",
|
||||||
(DType::U8, DType::I64) => "where_u8_i64",
|
(DType::U8, DType::I64) => "where_u8_i64",
|
||||||
(DType::U8, DType::U32) => "where_u8_u32",
|
(DType::U8, DType::U32) => "where_u8_u32",
|
||||||
|
@ -2578,11 +2578,21 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns log(sum(exp(tensor), dim)).
|
/// Returns log(sum(exp(tensor), dim)).
|
||||||
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||||
let exp = self.exp()?;
|
let exp = self.exp()?;
|
||||||
let sum = exp.sum(sum_dims)?;
|
let sum = exp.sum(sum_dims)?;
|
||||||
sum.log()
|
sum.log()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Pointwise pow operation.
|
||||||
|
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||||
|
rhs.mul(&self.log()?)?.exp()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Broadcasting version of `pow`.
|
||||||
|
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||||
|
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn logsumexp() -> Result<()> {
|
fn log_sum_exp() -> Result<()> {
|
||||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||||
let output = input.logsumexp(D::Minus1)?;
|
let output = input.log_sum_exp(D::Minus1)?;
|
||||||
// The expectations obtained from pytorch.
|
// The expectations obtained from pytorch.
|
||||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
||||||
assert_close(&output, &expected, 0.00001)?;
|
assert_close(&output, &expected, 0.00001)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pow() -> Result<()> {
|
||||||
|
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||||
|
let rhs = (&lhs - 2.)?;
|
||||||
|
let res = lhs.pow(&rhs)?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&res, 4)?,
|
||||||
|
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -27,11 +27,5 @@ fn main() -> Result<()> {
|
|||||||
bindings.write(kdir.rust_target).unwrap()
|
bindings.write(kdir.rust_target).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
{
|
|
||||||
for kdir in KERNEL_DIRS.iter() {
|
|
||||||
let _file = std::fs::File::create(kdir.rust_target)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1 +0,0 @@
|
|||||||
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
|
|
||||||
|
@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
|
|||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||||
|
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
|
||||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
|
|||||||
|
|
||||||
enum Model {
|
enum Model {
|
||||||
MixFormer(MixFormer),
|
MixFormer(MixFormer),
|
||||||
|
Phi(Phi),
|
||||||
Quantized(QMixFormer),
|
Quantized(QMixFormer),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,6 +86,7 @@ impl TextGeneration {
|
|||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
let logits = match &mut self.model {
|
let logits = match &mut self.model {
|
||||||
Model::MixFormer(m) => m.forward(&input)?,
|
Model::MixFormer(m) => m.forward(&input)?,
|
||||||
|
Model::Phi(m) => m.forward(&input)?,
|
||||||
Model::Quantized(m) => m.forward(&input)?,
|
Model::Quantized(m) => m.forward(&input)?,
|
||||||
};
|
};
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
@ -117,7 +120,7 @@ impl TextGeneration {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
||||||
enum WhichModel {
|
enum WhichModel {
|
||||||
#[value(name = "1")]
|
#[value(name = "1")]
|
||||||
V1,
|
V1,
|
||||||
@ -125,6 +128,9 @@ enum WhichModel {
|
|||||||
V1_5,
|
V1_5,
|
||||||
#[value(name = "2")]
|
#[value(name = "2")]
|
||||||
V2,
|
V2,
|
||||||
|
// TODO: Make this the default once it has been battle tested.
|
||||||
|
#[value(name = "2-new")]
|
||||||
|
V2New,
|
||||||
PuffinPhiV2,
|
PuffinPhiV2,
|
||||||
PhiHermes,
|
PhiHermes,
|
||||||
}
|
}
|
||||||
@ -169,7 +175,7 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
#[arg(long, default_value = "1.5")]
|
#[arg(long, default_value = "2")]
|
||||||
model: WhichModel,
|
model: WhichModel,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -230,7 +236,7 @@ fn main() -> Result<()> {
|
|||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||||
WhichModel::V2 => "microsoft/phi-2".to_string(),
|
WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(),
|
||||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
"lmz/candle-quantized-phi".to_string()
|
"lmz/candle-quantized-phi".to_string()
|
||||||
}
|
}
|
||||||
@ -247,7 +253,8 @@ fn main() -> Result<()> {
|
|||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => "refs/pr/2".to_string(),
|
WhichModel::V1 => "refs/pr/2".to_string(),
|
||||||
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
||||||
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
||||||
|
WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
"main".to_string()
|
"main".to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -258,7 +265,9 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer_filename = match args.tokenizer {
|
let tokenizer_filename = match args.tokenizer {
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
|
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => {
|
||||||
|
repo.get("tokenizer.json")?
|
||||||
|
}
|
||||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
repo.get("tokenizer-puffin-phi-v2.json")?
|
repo.get("tokenizer-puffin-phi-v2.json")?
|
||||||
}
|
}
|
||||||
@ -271,14 +280,14 @@ fn main() -> Result<()> {
|
|||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
|
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
|
||||||
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
|
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
|
||||||
WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
|
WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?],
|
||||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
||||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||||
WhichModel::V2 => candle_examples::hub_load_safetensors(
|
WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors(
|
||||||
&repo,
|
&repo,
|
||||||
"model.safetensors.index.json",
|
"model.safetensors.index.json",
|
||||||
)?,
|
)?,
|
||||||
@ -292,25 +301,35 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = match args.model {
|
let config = || match args.model {
|
||||||
WhichModel::V1 => Config::v1(),
|
WhichModel::V1 => Config::v1(),
|
||||||
WhichModel::V1_5 => Config::v1_5(),
|
WhichModel::V1_5 => Config::v1_5(),
|
||||||
WhichModel::V2 => Config::v2(),
|
WhichModel::V2 | WhichModel::V2New => Config::v2(),
|
||||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||||
};
|
};
|
||||||
let (model, device) = if args.quantized {
|
let (model, device) = if args.model == WhichModel::V2New {
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let config_filename = repo.get("config.json")?;
|
||||||
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
|
let config: PhiConfig = serde_json::from_str(&config)?;
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
|
let phi = Phi::new(&config, vb)?;
|
||||||
|
(Model::Phi(phi), device)
|
||||||
|
} else if args.quantized {
|
||||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
|
||||||
|
let config = config();
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
|
WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?,
|
||||||
_ => QMixFormer::new(&config, vb)?,
|
_ => QMixFormer::new(&config, vb)?,
|
||||||
};
|
};
|
||||||
(Model::Quantized(model), Device::Cpu)
|
(Model::Quantized(model), Device::Cpu)
|
||||||
} else {
|
} else {
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let config = config();
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
|
WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?,
|
||||||
_ => MixFormer::new(&config, vb)?,
|
_ => MixFormer::new(&config, vb)?,
|
||||||
};
|
};
|
||||||
(Model::MixFormer(model), device)
|
(Model::MixFormer(model), device)
|
||||||
@ -393,6 +412,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
|
|||||||
m.clear_kv_cache();
|
m.clear_kv_cache();
|
||||||
m.forward(&input)?
|
m.forward(&input)?
|
||||||
}
|
}
|
||||||
|
Model::Phi(m) => {
|
||||||
|
m.clear_kv_cache();
|
||||||
|
m.forward(&input)?
|
||||||
|
}
|
||||||
Model::Quantized(m) => {
|
Model::Quantized(m) => {
|
||||||
m.clear_kv_cache();
|
m.clear_kv_cache();
|
||||||
m.forward(&input)?
|
m.forward(&input)?
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
# candle-repvgg
|
# candle-repvgg
|
||||||
|
|
||||||
A candle implementation of inference using a pre-trained [repvgg](https://arxiv.org/abs/2101.03697).
|
[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697).
|
||||||
This uses a classification head trained on the ImageNet dataset and returns the
|
|
||||||
|
This candle implementation uses a pre-trained RepVGG network for inference. The
|
||||||
|
classification head has been trained on the ImageNet dataset and returns the
|
||||||
probabilities for the top-5 classes.
|
probabilities for the top-5 classes.
|
||||||
|
|
||||||
## Running an example
|
## Running an example
|
||||||
|
@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
|
|||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#define AFFINE(FN_NAME, TYPENAME) \
|
#define AFFINE(FN_NAME, T) \
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
constant float &mul, \
|
constant float &mul, \
|
||||||
constant float &add, \
|
constant float &add, \
|
||||||
device const TYPENAME *input, \
|
device const T *input, \
|
||||||
device TYPENAME *output, \
|
device T *output, \
|
||||||
uint id [[ thread_position_in_grid ]] \
|
uint id [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (id >= dim) { \
|
if (id >= dim) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
output[id] = TYPENAME(float(input[id]) * mul + add); \
|
output[id] = T(fma(float(input[id]), mul, add)); \
|
||||||
} \
|
} \
|
||||||
kernel void FN_NAME##_strided( \
|
kernel void FN_NAME##_strided( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
|
|||||||
constant size_t *strides, \
|
constant size_t *strides, \
|
||||||
constant float &mul, \
|
constant float &mul, \
|
||||||
constant float &add, \
|
constant float &add, \
|
||||||
device const TYPENAME *input, \
|
device const T *input, \
|
||||||
device TYPENAME *output, \
|
device T *output, \
|
||||||
uint id [[ thread_position_in_grid ]] \
|
uint id [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (id >= dim) { \
|
if (id >= dim) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
|
output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define POWF(FN_NAME, TYPENAME) \
|
#define POWF(FN_NAME, TYPENAME) \
|
||||||
|
@ -17,29 +17,45 @@ METAL_FUNC uint get_strided_index(
|
|||||||
return strided_i;
|
return strided_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename T, typename ID>
|
||||||
|
METAL_FUNC void where_cond(
|
||||||
|
constant size_t &numel,
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *dims,
|
||||||
|
constant size_t *strides,
|
||||||
|
constant size_t *strides_t,
|
||||||
|
constant size_t *strides_f,
|
||||||
|
device const ID *ids,
|
||||||
|
device const T *t,
|
||||||
|
device const T *f,
|
||||||
|
device T *out,
|
||||||
|
uint i [[ thread_position_in_grid ]]
|
||||||
|
) {
|
||||||
|
if (i >= numel){
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||||
|
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
||||||
|
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
||||||
|
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
|
||||||
|
}
|
||||||
|
|
||||||
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
|
#define WHERE_OP(T, ID, FN_NAME) \
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
constant size_t &numel, \
|
constant size_t &numel, \
|
||||||
constant size_t &num_dims, \
|
constant size_t &num_dims, \
|
||||||
constant size_t *dims, \
|
constant size_t *dims, \
|
||||||
constant size_t *strides, \
|
constant size_t *strides, \
|
||||||
constant size_t *strides_t, \
|
constant size_t *strides_t, \
|
||||||
constant size_t *strides_f, \
|
constant size_t *strides_f, \
|
||||||
device const ID_TYPENAME *ids, \
|
device const ID *ids, \
|
||||||
device const TYPENAME *t, \
|
device const T *t, \
|
||||||
device const TYPENAME *f, \
|
device const T *f, \
|
||||||
device TYPENAME *out ,\
|
device T *out, \
|
||||||
uint i [[ thread_position_in_grid ]] \
|
uint i [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (i >= numel){ \
|
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
|
||||||
return; \
|
} \
|
||||||
} \
|
|
||||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
|
||||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
|
||||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
|
||||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
|
|
||||||
} \
|
|
||||||
|
|
||||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||||
@ -54,10 +70,14 @@ kernel void FN_NAME( \
|
|||||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||||
|
|
||||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
WHERE_OP(half, uint8_t, where_u8_f16)
|
||||||
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||||
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 220
|
#if __METAL_VERSION__ >= 220
|
||||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
|
||||||
|
#endif
|
@ -6,6 +6,7 @@ use serde::Deserialize;
|
|||||||
pub enum Activation {
|
pub enum Activation {
|
||||||
#[default]
|
#[default]
|
||||||
Gelu,
|
Gelu,
|
||||||
|
#[serde(alias = "gelu_new")]
|
||||||
NewGelu,
|
NewGelu,
|
||||||
Relu,
|
Relu,
|
||||||
Relu2,
|
Relu2,
|
||||||
|
@ -254,6 +254,12 @@ pub fn simple_eval(
|
|||||||
let output = input0.broadcast_div(input1)?;
|
let output = input0.broadcast_div(input1)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"Pow" => {
|
||||||
|
let input0 = get(&node.input[0])?;
|
||||||
|
let input1 = get(&node.input[1])?;
|
||||||
|
let output = input0.broadcast_pow(input1)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
"Equal" => {
|
"Equal" => {
|
||||||
let input0 = get(&node.input[0])?;
|
let input0 = get(&node.input[0])?;
|
||||||
let input1 = get(&node.input[1])?;
|
let input1 = get(&node.input[1])?;
|
||||||
|
@ -17,6 +17,7 @@ pub mod mixformer;
|
|||||||
pub mod mixtral;
|
pub mod mixtral;
|
||||||
pub mod mpt;
|
pub mod mpt;
|
||||||
pub mod persimmon;
|
pub mod persimmon;
|
||||||
|
pub mod phi;
|
||||||
pub mod quantized_blip;
|
pub mod quantized_blip;
|
||||||
pub mod quantized_blip_text;
|
pub mod quantized_blip_text;
|
||||||
pub mod quantized_llama;
|
pub mod quantized_llama;
|
||||||
|
363
candle-transformers/src/models/phi.rs
Normal file
363
candle-transformers/src/models/phi.rs
Normal file
@ -0,0 +1,363 @@
|
|||||||
|
use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear};
|
||||||
|
/// Phi model.
|
||||||
|
/// https://huggingface.co/microsoft/phi-2
|
||||||
|
/// There is an alternative implementation of the phi model in mixformers.rs.
|
||||||
|
/// This corresponds to the model update made with the following commit:
|
||||||
|
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
|
||||||
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{Activation, VarBuilder};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub(crate) vocab_size: usize,
|
||||||
|
pub(crate) hidden_size: usize,
|
||||||
|
pub(crate) intermediate_size: usize,
|
||||||
|
pub(crate) num_hidden_layers: usize,
|
||||||
|
pub(crate) num_attention_heads: usize,
|
||||||
|
pub(crate) num_key_value_heads: Option<usize>,
|
||||||
|
pub(crate) hidden_act: Activation,
|
||||||
|
pub(crate) max_position_embeddings: usize,
|
||||||
|
pub(crate) layer_norm_eps: f64,
|
||||||
|
pub(crate) tie_word_embeddings: bool,
|
||||||
|
pub(crate) rope_theta: f32,
|
||||||
|
pub(crate) partial_rotary_factor: f64,
|
||||||
|
pub(crate) qk_layernorm: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
fn num_key_value_heads(&self) -> usize {
|
||||||
|
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn head_dim(&self) -> usize {
|
||||||
|
self.hidden_size / self.num_attention_heads
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
dim: usize,
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||||
|
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((cfg.max_position_embeddings, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||||
|
Ok(Self {
|
||||||
|
dim,
|
||||||
|
sin: emb.sin()?,
|
||||||
|
cos: emb.cos()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
|
||||||
|
let xs_rot = xs.i((.., .., .., ..self.dim))?;
|
||||||
|
let xs_pass = xs.i((.., .., .., self.dim..))?;
|
||||||
|
let xs12 = xs_rot.chunk(2, D::Minus1)?;
|
||||||
|
let (xs1, xs2) = (&xs12[0], &xs12[1]);
|
||||||
|
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
|
||||||
|
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
|
||||||
|
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MLP {
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
act: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MLP {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
|
||||||
|
let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
|
||||||
|
Ok(Self {
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
// This does not match the mixformers implementation where Gelu is used rather than
|
||||||
|
// GeluNew.
|
||||||
|
act: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MLP {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
dense: Linear,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
q_layernorm: Option<LayerNorm>,
|
||||||
|
k_layernorm: Option<LayerNorm>,
|
||||||
|
rotary_emb: RotaryEmbedding,
|
||||||
|
softmax_scale: f64,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..size)
|
||||||
|
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (size, size), device)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||||
|
let shape = mask.shape();
|
||||||
|
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||||
|
let m = mask.where_cond(&on_true, on_false)?;
|
||||||
|
Ok(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads();
|
||||||
|
let head_dim = cfg.head_dim();
|
||||||
|
let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||||
|
let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||||
|
let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?;
|
||||||
|
// Alternative rope scalings are not supported.
|
||||||
|
let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?;
|
||||||
|
let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
|
||||||
|
let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
|
||||||
|
let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
|
||||||
|
(Some(q_layernorm), Some(k_layernorm))
|
||||||
|
} else {
|
||||||
|
(None, None)
|
||||||
|
};
|
||||||
|
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
dense,
|
||||||
|
kv_cache: None,
|
||||||
|
q_layernorm,
|
||||||
|
k_layernorm,
|
||||||
|
rotary_emb,
|
||||||
|
softmax_scale,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "attention"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
|
||||||
|
let n_rep = self.num_heads / self.num_kv_heads;
|
||||||
|
if n_rep == 1 {
|
||||||
|
Ok(xs)
|
||||||
|
} else {
|
||||||
|
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
|
||||||
|
xs.unsqueeze(2)?
|
||||||
|
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
|
||||||
|
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||||
|
let query_states = self.q_proj.forward(xs)?;
|
||||||
|
let key_states = self.k_proj.forward(xs)?;
|
||||||
|
let value_states = self.v_proj.forward(xs)?;
|
||||||
|
|
||||||
|
let query_states = match &self.q_layernorm {
|
||||||
|
None => query_states,
|
||||||
|
Some(ln) => query_states.apply(ln)?,
|
||||||
|
};
|
||||||
|
let key_states = match &self.k_layernorm {
|
||||||
|
None => key_states,
|
||||||
|
Some(ln) => key_states.apply(ln)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let query_states = query_states
|
||||||
|
.reshape((b_size, seq_len, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let key_states = key_states
|
||||||
|
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let value_states = value_states
|
||||||
|
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
// Rotary embeddings.
|
||||||
|
let seqlen_offset = match &self.kv_cache {
|
||||||
|
None => 0,
|
||||||
|
Some((prev_k, _)) => prev_k.dim(2)?,
|
||||||
|
};
|
||||||
|
let query_states = self
|
||||||
|
.rotary_emb
|
||||||
|
.apply_rotary_emb(&query_states, seqlen_offset)?;
|
||||||
|
let key_states = self
|
||||||
|
.rotary_emb
|
||||||
|
.apply_rotary_emb(&key_states, seqlen_offset)?;
|
||||||
|
|
||||||
|
// KV cache.
|
||||||
|
let (key_states, value_states) = match &self.kv_cache {
|
||||||
|
None => (key_states, value_states),
|
||||||
|
Some((prev_k, prev_v)) => {
|
||||||
|
let k = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||||
|
let v = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||||
|
|
||||||
|
// Repeat kv.
|
||||||
|
let key_states = self.repeat_kv(key_states)?.contiguous()?;
|
||||||
|
let value_states = self.repeat_kv(value_states)?.contiguous()?;
|
||||||
|
|
||||||
|
let attn_weights = (query_states
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.contiguous()?
|
||||||
|
.matmul(&key_states.to_dtype(DType::F32)?.t()?)?
|
||||||
|
* self.softmax_scale)?;
|
||||||
|
let attn_weights = match mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => masked_fill(
|
||||||
|
&attn_weights,
|
||||||
|
&mask.broadcast_left((b_size, self.num_heads))?,
|
||||||
|
f32::NEG_INFINITY,
|
||||||
|
)?,
|
||||||
|
};
|
||||||
|
let attn_weights =
|
||||||
|
candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?;
|
||||||
|
let attn_output = attn_weights.matmul(&value_states)?;
|
||||||
|
let attn_output = attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_size, seq_len, ()))?;
|
||||||
|
attn_output.apply(&self.dense)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct DecoderLayer {
|
||||||
|
self_attn: Attention,
|
||||||
|
mlp: MLP,
|
||||||
|
input_layernorm: LayerNorm,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderLayer {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
|
||||||
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
let input_layernorm = layer_norm(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.layer_norm_eps,
|
||||||
|
vb.pp("input_layernorm"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
input_layernorm,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "block"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let residual = xs;
|
||||||
|
let xs = xs.apply(&self.input_layernorm)?;
|
||||||
|
let attn_outputs = self.self_attn.forward(&xs, mask)?;
|
||||||
|
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||||
|
attn_outputs + feed_forward_hidden_states + residual
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embed_tokens: Embedding,
|
||||||
|
layers: Vec<DecoderLayer>,
|
||||||
|
final_layernorm: LayerNorm,
|
||||||
|
lm_head: Linear,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_m = vb.pp("model");
|
||||||
|
let embed_tokens =
|
||||||
|
Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||||
|
let final_layernorm = layer_norm(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.layer_norm_eps,
|
||||||
|
vb_m.pp("final_layernorm"),
|
||||||
|
)?;
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_m = vb_m.pp("layers");
|
||||||
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
|
let layer = DecoderLayer::new(cfg, vb_m.pp(layer_idx))?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
final_layernorm,
|
||||||
|
lm_head,
|
||||||
|
span: tracing::span!(tracing::Level::TRACE, "model"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let (_b_size, seq_len) = xs.dims2()?;
|
||||||
|
let mut xs = xs.apply(&self.embed_tokens)?;
|
||||||
|
let mask = if seq_len <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(get_mask(seq_len, xs.device())?)
|
||||||
|
};
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, mask.as_ref())?;
|
||||||
|
}
|
||||||
|
xs.apply(&self.final_layernorm)?
|
||||||
|
.narrow(1, seq_len - 1, 1)?
|
||||||
|
.apply(&self.lm_head)?
|
||||||
|
.squeeze(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.layers.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user