Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00.

Compare commits: wasm-llama ... faster-gem (14 commits)
c6ae9f565e
620f83cf66
3fa3623135
f7b2a0391d
8b6f5be1cc
df6667ba88
a79286885c
74845a4dcd
aa76b783eb
25564357f7
634700d84a
e635f18eda
52414ba5c8
186c308d51
@@ -1,6 +1,7 @@
 [workspace]
 members = [
     "candle-core",
+    "candle-datasets",
     "candle-examples",
     "candle-nn",
     "candle-pyo3",
@@ -98,8 +98,9 @@ Cheatsheet:
 - [candle-nn](./candle-nn/): Facilities to build real models
 - [candle-examples](./candle-examples/): Real-world like examples on how to use the library in real settings
 - [candle-kernels](./candle-kernels/): CUDA custom kernels
+- [candle-datasets](./candle-datasets/): Datasets and data loaders.
+- [candle-transformers](./candle-transformers): Transformer related utilities.
+- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.

 ## FAQ

@@ -5,6 +5,8 @@ use anyhow::Result;
 use candle_core::{Device, Tensor};

 fn main() -> Result<()> {
+    let mut file = std::fs::File::open("ggml.bin")?;
+    let data = candle_core::ggml::Content::read(&mut file, &Device::Cpu)?;
     let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
     let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
     let c = a.matmul(&b)?;
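For a closer look at what Content::read returns (the type is added in candle-core/src/ggml.rs below), a minimal sketch along the same lines, assuming a valid ggml.bin in the working directory:

use candle_core::{ggml, Device};

fn main() -> anyhow::Result<()> {
    let mut file = std::fs::File::open("ggml.bin")?;
    let content = ggml::Content::read(&mut file, &Device::Cpu)?;
    // `hparams` mirrors the llama.cpp header fields read from the file.
    println!("n_vocab: {}, n_layer: {}", content.hparams.n_vocab, content.hparams.n_layer);
    // Tensors come back dequantized: the quantized dtypes are expanded to f32.
    for (name, tensor) in content.tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}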
@@ -1010,12 +1010,18 @@ impl Map2 for MatMul {
         };
         let c_skip: usize = m * n;

-        let dst_shape: Shape = (m, n).into();
-        let dst_strides = dst_shape.stride_contiguous();
-        let dst_rs = dst_strides[0];
-        let dst_cs = dst_strides[1];

         let mut dst = vec![T::zero(); b * m * n];

+        let (dst_rs, dst_cs) = if m == 1 {
+            (1, 1)
+        } else if n == 1 {
+            (1, 1)
+        } else {
+            let dst_shape: Shape = (m, n).into();
+            let dst_strides = dst_shape.stride_contiguous();
+            (dst_strides[0], dst_strides[1])
+        };

         let num_threads = crate::utils::get_num_threads();
         let parallelism = if num_threads > 1 {
             Parallelism::Rayon(num_threads)
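As I read the new special case: when m == 1 or n == 1 the destination is effectively a vector, the stride of the length-1 dimension is never actually stepped, and reporting (1, 1) presents dst as unit-stride on both axes, which should let gemm take its faster matrix-vector path. A standalone check of the strides involved (stride_contiguous here is a local re-implementation for illustration, not the candle function):

fn stride_contiguous(dims: &[usize]) -> Vec<usize> {
    // Row-major strides: the last dim has stride 1, every earlier dim
    // strides over the product of the dims that follow it.
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn main() {
    assert_eq!(stride_contiguous(&[1, 4]), vec![4, 1]); // (1, n): the row stride is never stepped
    assert_eq!(stride_contiguous(&[4, 1]), vec![1, 1]); // (m, 1): already (1, 1)
    println!("ok");
}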
@@ -101,6 +101,13 @@ impl Device {
         }
     }

+    pub fn is_cpu(&self) -> bool {
+        match self {
+            Self::Cpu => true,
+            Self::Cuda(_) => false,
+        }
+    }
+
     pub fn is_cuda(&self) -> bool {
         match self {
             Self::Cpu => false,
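The new predicate mirrors the existing is_cuda and is what the llama2-c weight loader further down keys its CPU-only transpose hack on. A minimal sketch of that kind of use (wants_col_major_weights is a hypothetical helper, not part of this diff):

use candle_core::Device;

// Pick a weight layout without matching on the enum at every call site.
fn wants_col_major_weights(device: &Device) -> bool {
    // Hypothetical policy echoing the weight-loader hack below: only the
    // plain-CPU gemm path benefits from pre-transposed weights.
    device.is_cpu()
}

fn main() {
    assert!(wants_col_major_weights(&Device::Cpu));
}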
candle-core/src/ggml.rs (new file, 582 lines)
@@ -0,0 +1,582 @@
//! Support for the GGML file format.

use crate::{DType, Device, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use half::f16;

// Default to QK_K 256 rather than 64.
pub const QK_K: usize = 256;
pub const K_SCALE_SIZE: usize = 12;

pub const QK4_0: usize = 32;
pub const QK4_1: usize = 32;
pub const QK5_0: usize = 32;
pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;

#[repr(C)]
struct BlockQ4_0 {
    d: f16,
    qs: [u8; QK4_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);

#[repr(C)]
struct BlockQ4_1 {
    d: f16,
    m: f16,
    qs: [u8; QK4_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);

#[repr(C)]
struct BlockQ5_0 {
    d: f16,
    qh: [u8; 4],
    qs: [u8; QK5_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);

#[repr(C)]
struct BlockQ5_1 {
    d: f16,
    m: f16,
    qh: [u8; 4],
    qs: [u8; QK5_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);

#[repr(C)]
struct BlockQ8_0 {
    d: f16,
    qs: [u8; QK8_0],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);

#[repr(C)]
struct BlockQ8_1 {
    d: f16,
    s: f16,
    qs: [u8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);

#[repr(C)]
struct BlockQ2K {
    scales: [u8; QK_K / 16],
    qs: [u8; QK_K / 4],
    d: f16,
    dmin: f16,
}
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());

#[repr(C)]
struct BlockQ3K {
    hmask: [u8; QK_K / 8],
    qs: [u8; QK_K / 4],
    scales: [u8; 12],
    d: f16,
}
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());

// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
struct BlockQ4K {
    d: f16,
    dmin: f16,
    scales: [u8; K_SCALE_SIZE],
    qs: [u8; QK_K / 2],
}
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());

#[repr(C)]
struct BlockQ5K {
    d: f16,
    dmin: f16,
    scales: [u8; K_SCALE_SIZE],
    qh: [u8; QK_K / 8],
    qs: [u8; QK_K / 2],
}
const _: () =
    assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());

#[repr(C)]
struct BlockQ6K {
    ql: [u8; QK_K / 2],
    qh: [u8; QK_K / 4],
    scales: [i8; QK_K / 16],
    d: f16,
}
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
    let k = ys.len();
    if k % QK_K != 0 {
        crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
    }
    let mut ys_index = 0;
    for x in xs {
        let d = x.d.to_f32();
        let min = x.dmin.to_f32();
        let q = &x.qs;

        let mut is = 0;
        for n in (0..QK_K).step_by(128) {
            // Step by 32 over q.
            let q = &q[n / 4..];
            let mut shift = 0;
            for _j in 0..4 {
                let sc = x.scales[is];
                is += 1;
                let dl = d * (sc & 0xF) as f32;
                let ml = min * (sc >> 4) as f32;
                for q in &q[..16] {
                    let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
                    ys[ys_index] = y;
                    ys_index += 1;
                }

                let sc = x.scales[is];
                is += 1;
                let dl = d * (sc & 0xF) as f32;
                let ml = min * (sc >> 4) as f32;
                for q in &q[16..32] {
                    let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
                    ys[ys_index] = y;
                    ys_index += 1;
                }

                shift += 2;
            }
        }
    }
    Ok(())
}

fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    if j < 4 {
        let d = q[j] & 63;
        let m = q[j + 4] & 63;
        (d, m)
    } else {
        let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
        (d, m)
    }
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
    let k = ys.len();
    if k % QK_K != 0 {
        crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
    }
    let mut ys_index = 0;
    for x in xs.iter() {
        let d = x.d.to_f32();
        let min = x.dmin.to_f32();
        let q = &x.qs;
        let mut is = 0;
        for j in (0..QK_K).step_by(64) {
            let q = &q[j / 2..j / 2 + 32];
            let (sc, m) = get_scale_min_k4(is, &x.scales);
            let d1 = d * sc as f32;
            let m1 = min * m as f32;
            let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
            let d2 = d * sc as f32;
            let m2 = min * m as f32;
            for q in q {
                let y = d1 * (q & 0xF) as f32 - m1;
                ys[ys_index] = y;
                ys_index += 1;
            }
            for q in q {
                let y = d2 * (q >> 4) as f32 - m2;
                ys[ys_index] = y;
                ys_index += 1;
            }
            is += 2;
        }
    }
    Ok(())
}

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn dequantize_row_q3k(_xs: &[BlockQ3K], _ys: &mut [f32]) -> Result<()> {
    todo!()
}

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> {
    let k = ys.len();
    if k % QK_K != 0 {
        crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}")
    }
    let mut ys_index = 0;
    for x in xs.iter() {
        let d = x.d.to_f32();
        let min = x.dmin.to_f32();
        let ql = &x.qs;
        let qh = &x.qh;
        let mut is = 0;
        let mut u1 = 1;
        let mut u2 = 2;
        for j in (0..QK_K).step_by(64) {
            let ql = &ql[j / 2..j / 2 + 32];
            let (sc, m) = get_scale_min_k4(is, &x.scales);
            let d1 = d * sc as f32;
            let m1 = min * m as f32;
            let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
            let d2 = d * sc as f32;
            let m2 = min * m as f32;
            for (ql, qh) in ql.iter().zip(qh) {
                let to_add = if qh & u1 != 0 { 16 } else { 1 };
                let y = d1 * ((ql & 0xF) + to_add) as f32 - m1;
                ys[ys_index] = y;
                ys_index += 1;
            }
            for (ql, qh) in ql.iter().zip(qh) {
                let to_add = if qh & u2 != 0 { 16 } else { 1 };
                let y = d2 * ((ql >> 4) + to_add) as f32 - m2;
                ys[ys_index] = y;
                ys_index += 1;
            }
            is += 2;
            u1 <<= 2;
            u2 <<= 2;
        }
    }
    Ok(())
}

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
    let k = ys.len();
    if k % QK_K != 0 {
        crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
    }
    for x in xs.iter() {
        let d = x.d.to_f32();
        let ql = &x.ql;
        let qh = &x.qh;
        let sc = &x.scales;
        for n in (0..QK_K).step_by(128) {
            let idx = n / 128;
            let ys = &mut ys[n..];
            let sc = &sc[8 * idx..];
            let ql = &ql[64 * idx..];
            let qh = &qh[32 * idx..];
            for l in 0..32 {
                let is = l / 16;
                let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
                let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
                let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
                let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
                ys[l] = d * sc[is] as f32 * q1 as f32;
                ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
                ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
                ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
            }
        }
    }
    Ok(())
}

// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
    Ggjt,
    Ggla,
    Ggmf,
    Ggml,
    Ggsn,
}

impl TryFrom<u32> for Magic {
    type Error = crate::Error;
    fn try_from(value: u32) -> Result<Self> {
        let magic = match value {
            0x67676a74 => Self::Ggjt,
            0x67676c61 => Self::Ggla,
            0x67676d66 => Self::Ggmf,
            0x67676d6c => Self::Ggml,
            0x6767736e => Self::Ggsn,
            _ => crate::bail!("unknown magic {value:08x}"),
        };
        Ok(magic)
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
    GgmlUnversioned,
    GgmfV1,
    GgjtV1,
    GgjtV2,
    GgjtV3,
}

impl VersionedMagic {
    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
        let magic = reader.read_u32::<LittleEndian>()?;
        let magic = Magic::try_from(magic)?;
        if magic == Magic::Ggml {
            return Ok(Self::GgmlUnversioned);
        }
        let version = reader.read_u32::<LittleEndian>()?;
        let versioned_magic = match (magic, version) {
            (Magic::Ggmf, 1) => Self::GgmfV1,
            (Magic::Ggjt, 1) => Self::GgjtV1,
            (Magic::Ggjt, 2) => Self::GgjtV2,
            (Magic::Ggjt, 3) => Self::GgjtV3,
            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
        };
        Ok(versioned_magic)
    }

    fn align32(&self) -> bool {
        match self {
            Self::GgmlUnversioned | Self::GgmfV1 => false,
            Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
    pub n_vocab: u32,
    pub n_embd: u32,
    pub n_mult: u32,
    pub n_head: u32,
    pub n_layer: u32,
    pub n_rot: u32,
    pub ftype: u32,
}

impl HParams {
    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
        let n_vocab = reader.read_u32::<LittleEndian>()?;
        let n_embd = reader.read_u32::<LittleEndian>()?;
        let n_mult = reader.read_u32::<LittleEndian>()?;
        let n_head = reader.read_u32::<LittleEndian>()?;
        let n_layer = reader.read_u32::<LittleEndian>()?;
        let n_rot = reader.read_u32::<LittleEndian>()?;
        let ftype = reader.read_u32::<LittleEndian>()?;
        Ok(Self {
            n_vocab,
            n_embd,
            n_mult,
            n_head,
            n_layer,
            n_rot,
            ftype,
        })
    }
}

#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
    pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}

impl Vocab {
    fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
        let mut token_score_pairs = Vec::with_capacity(n_vocab);
        for _index in 0..n_vocab {
            let len = reader.read_u32::<LittleEndian>()? as usize;
            let mut word = vec![0u8; len];
            reader.read_exact(&mut word)?;
            let score = reader.read_f32::<LittleEndian>()?;
            token_score_pairs.push((word, score))
        }
        Ok(Self { token_score_pairs })
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
}

impl GgmlDType {
    fn from_u32(u: u32) -> Result<Self> {
        let dtype = match u {
            0 => Self::F32,
            1 => Self::F16,
            2 => Self::Q4_0,
            3 => Self::Q4_1,
            6 => Self::Q5_0,
            7 => Self::Q5_1,
            8 => Self::Q8_0,
            9 => Self::Q8_1,
            10 => Self::Q2K,
            11 => Self::Q3K,
            12 => Self::Q4K,
            13 => Self::Q5K,
            14 => Self::Q6K,
            _ => crate::bail!("unknown dtype for tensor {u}"),
        };
        Ok(dtype)
    }

    fn type_size(&self) -> usize {
        match self {
            Self::F32 => 4,
            Self::F16 => 2,
            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
        }
    }

    fn blck_size(&self) -> usize {
        match self {
            Self::F32 => 1,
            Self::F16 => 1,
            Self::Q4_0 => QK4_0,
            Self::Q4_1 => QK4_1,
            Self::Q5_0 => QK5_0,
            Self::Q5_1 => QK5_1,
            Self::Q8_0 => QK8_0,
            Self::Q8_1 => QK8_1,
            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
        }
    }
}

#[derive(Debug)]
pub struct Content {
    pub magic: VersionedMagic,
    pub hparams: HParams,
    pub vocab: Vocab,
    pub tensors: Vec<(String, Tensor)>,
}

fn read_one_tensor<R: std::io::Seek + std::io::Read>(
    reader: &mut R,
    magic: VersionedMagic,
    device: &Device,
) -> Result<(String, Tensor)> {
    let n_dims = reader.read_u32::<LittleEndian>()?;
    let name_len = reader.read_u32::<LittleEndian>()?;
    let dtype = reader.read_u32::<LittleEndian>()?;
    let dtype = GgmlDType::from_u32(dtype)?;
    let mut dims = vec![0u32; n_dims as usize];
    reader.read_u32_into::<LittleEndian>(&mut dims)?;
    let mut name = vec![0u8; name_len as usize];
    reader.read_exact(&mut name)?;
    let name = String::from_utf8_lossy(&name).into_owned();

    if magic.align32() {
        let pos = reader.stream_position()?;
        reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
    }
    let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
    let tensor_elems = dims.iter().product::<usize>();
    let size_in_bytes = tensor_elems * dtype.type_size() / dtype.blck_size();
    println!("{name} {dtype:?} {dims:?}");
    // TODO: Mmap version to avoid copying the data around?
    let mut raw_data = vec![0u8; size_in_bytes];
    reader.read_exact(&mut raw_data)?;
    let tensor = match dtype {
        GgmlDType::F32 => Tensor::from_raw_buffer(&raw_data, DType::F32, &dims, device)?,
        GgmlDType::F16 => Tensor::from_raw_buffer(&raw_data, DType::F16, &dims, device)?,
        GgmlDType::Q2K => {
            let mut f32_data = vec![0f32; tensor_elems];
            let raw_data_ptr = raw_data.as_ptr();
            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ2K>();
            let raw_data =
                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ2K, n_blocks) };
            dequantize_row_q2k(raw_data, &mut f32_data)?;
            // Maybe we should use bf16 instead?
            Tensor::from_vec(f32_data, dims, device)?
        }
        GgmlDType::Q3K => {
            let mut f32_data = vec![0f32; tensor_elems];
            let raw_data_ptr = raw_data.as_ptr();
            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ3K>();
            let raw_data =
                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ3K, n_blocks) };
            dequantize_row_q3k(raw_data, &mut f32_data)?;
            Tensor::from_vec(f32_data, dims, device)?
        }
        GgmlDType::Q4K => {
            let mut f32_data = vec![0f32; tensor_elems];
            let raw_data_ptr = raw_data.as_ptr();
            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ4K>();
            let raw_data =
                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ4K, n_blocks) };
            dequantize_row_q4k(raw_data, &mut f32_data)?;
            Tensor::from_vec(f32_data, dims, device)?
        }
        GgmlDType::Q5K => {
            let mut f32_data = vec![0f32; tensor_elems];
            let raw_data_ptr = raw_data.as_ptr();
            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ5K>();
            let raw_data =
                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ5K, n_blocks) };
            dequantize_row_q5k(raw_data, &mut f32_data)?;
            Tensor::from_vec(f32_data, dims, device)?
        }
        GgmlDType::Q6K => {
            let mut f32_data = vec![0f32; tensor_elems];
            let raw_data_ptr = raw_data.as_ptr();
            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>();
            let raw_data =
                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) };
            dequantize_row_q6k(raw_data, &mut f32_data)?;
            Tensor::from_vec(f32_data, dims, device)?
        }
        _ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
    };
    Ok((name, tensor))
}

impl Content {
    pub fn read<R: std::io::Seek + std::io::Read>(
        reader: &mut R,
        device: &Device,
    ) -> Result<Content> {
        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
        let last_position = reader.seek(std::io::SeekFrom::End(0))?;
        reader.seek(std::io::SeekFrom::Start(0))?;
        let magic = VersionedMagic::read(reader)?;
        let hparams = HParams::read(reader)?;
        let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
        let mut tensors = vec![];

        while reader.stream_position()? != last_position {
            let (name, tensor) = read_one_tensor(reader, magic, device)?;
            tensors.push((name, tensor))
        }
        Ok(Self {
            magic,
            hparams,
            vocab,
            tensors,
        })
    }
}
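To make the size bookkeeping in read_one_tensor concrete: a Q4_K block packs 256 weights into 144 bytes (2 + 2 + 12 + 128), so bytes = elems * type_size / blck_size. A standalone sketch of that arithmetic (constants re-declared locally for illustration):

// Hypothetical standalone check of the GGML size arithmetic used in
// read_one_tensor: bytes = elems * type_size / block_size.
const QK_K: usize = 256; // weights per k-quant block
const Q4K_TYPE_SIZE: usize = 2 + 2 + 12 + QK_K / 2; // d + dmin + scales + qs = 144

fn main() {
    let dims = [4096usize, 4096]; // e.g. a square projection matrix
    let elems: usize = dims.iter().product();
    let bytes = elems * Q4K_TYPE_SIZE / QK_K;
    assert_eq!(bytes, 9_437_184); // ~9 MiB, versus 64 MiB stored as f32
    println!("{bytes} bytes");
}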
@@ -45,6 +45,7 @@ pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
 pub mod error;
+pub mod ggml;
 mod indexer;
 pub mod layout;
 #[cfg(feature = "mkl")]
candle-datasets/Cargo.toml (new file, 20 lines)
@@ -0,0 +1,20 @@
[package]
name = "candle-datasets"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"

[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.0" }
hf-hub = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
candle-datasets/src/lib.rs (new file, 6 lines)
@@ -0,0 +1,6 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
pub mod nlp;
pub mod vision;

pub use batcher::Batcher;
candle-datasets/src/nlp/mod.rs (new file, 1 line)
@@ -0,0 +1 @@
pub mod tinystories;
candle-datasets/src/nlp/tinystories.rs (new file, 122 lines)
@@ -0,0 +1,122 @@
//! Helper functions for the tinystories dataset. This uses the pre-tokenized version as generated
//! by the tools from https://github.com/karpathy/llama2.c
use candle::{Device, Result, Tensor};

pub struct Dataset {
    valid_tokens: Vec<memmap2::Mmap>,
    train_tokens: Vec<memmap2::Mmap>,
}

fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
    let file = std::fs::File::open(p)?;
    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
    Ok(mmap)
}

impl Dataset {
    pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
        let dir = dir.as_ref();
        let mut bin_files = vec![];
        for file in std::fs::read_dir(dir)?.flatten() {
            let file = file.path();
            if let Some(extension) = file.extension() {
                if extension == "bin" {
                    bin_files.push(file)
                }
            }
        }
        if bin_files.len() < 2 {
            candle::bail!("found less than two bin files in {:?}", dir)
        }
        bin_files.sort();
        let valid_tokens = mmap_file(&bin_files[0])?;
        let train_tokens = bin_files[1..]
            .iter()
            .map(mmap_file)
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            valid_tokens: vec![valid_tokens],
            train_tokens,
        })
    }

    pub fn train_tokens(&self) -> usize {
        self.train_tokens.len()
    }

    pub fn valid_tokens(&self) -> usize {
        self.valid_tokens.len()
    }
}

pub struct DatasetRandomIter<'a> {
    all_tokens: &'a [memmap2::Mmap],
    tokens: Vec<&'a memmap2::Mmap>,
    current_tokens: &'a memmap2::Mmap,
    indexes_in_bytes: Vec<usize>,
    seq_len: usize,
    device: Device,
}

impl<'a> DatasetRandomIter<'a> {
    pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let all_tokens = if valid {
            &ds.valid_tokens
        } else {
            &ds.train_tokens
        };
        let mut tokens = all_tokens.iter().collect::<Vec<_>>();
        tokens.shuffle(&mut thread_rng());
        let current_tokens = tokens.pop().unwrap();
        let seq_len_in_bytes = seq_len * 2;
        let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
            .step_by(seq_len_in_bytes)
            .collect::<Vec<_>>();
        indexes_in_bytes.shuffle(&mut thread_rng());
        Self {
            all_tokens,
            tokens,
            current_tokens,
            indexes_in_bytes,
            seq_len,
            device,
        }
    }
}

impl<'a> Iterator for DatasetRandomIter<'a> {
    type Item = Result<(Tensor, Tensor)>;

    fn next(&mut self) -> Option<Self::Item> {
        use byteorder::{LittleEndian, ReadBytesExt};
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let seq_len = self.seq_len;
        if self.indexes_in_bytes.is_empty() {
            if self.tokens.is_empty() {
                self.tokens = self.all_tokens.iter().collect();
                self.tokens.shuffle(&mut thread_rng());
            }
            self.current_tokens = self.tokens.pop().unwrap();
            let seq_len_in_bytes = self.seq_len * 2;
            self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
                .step_by(seq_len_in_bytes)
                .collect::<Vec<_>>();
            self.indexes_in_bytes.shuffle(&mut thread_rng());
        }
        let start_idx = self.indexes_in_bytes.pop().unwrap();
        let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
        let mut tokens = vec![0u16; bytes.len() / 2];
        if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
            return Some(Err(err.into()));
        }
        let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
        let inputs = Tensor::new(&tokens[..seq_len], &self.device);
        let targets = Tensor::new(&tokens[1..], &self.device);
        Some(candle::error::zip(inputs, targets))
    }
}
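A minimal end-to-end sketch of the new dataset types, assuming a directory holding at least two llama2.c pre-tokenized .bin shards (the path, seq_len, and batch size below are placeholders):

use candle::Device;
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Dataset::new expects at least two *.bin shards; after sorting, the
    // first becomes the validation set and the rest the training set.
    let dataset = Dataset::new("data/tinystories")?; // placeholder path
    let iter = DatasetRandomIter::new(&dataset, false, 256, device);
    // Each item is an (inputs, targets) pair of u32 token tensors, the
    // targets shifted one position ahead of the inputs.
    let batches = candle_datasets::Batcher::new_r2(iter).batch_size(32);
    for batch in batches.take(1) {
        let (inp, tgt) = batch?;
        println!("inputs {:?} targets {:?}", inp.shape(), tgt.shape());
    }
    Ok(())
}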
@@ -11,6 +11,7 @@ readme = "README.md"

 [dependencies]
 candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.1.0" }
 candle-nn = { path = "../candle-nn", version = "0.1.0" }
 candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
 candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
@@ -111,6 +111,10 @@ struct Args {
     #[arg(long)]
     use_f32: bool,

+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
     #[arg(long)]
     model_id: Option<String>,

@@ -123,8 +127,18 @@ struct Args {

 fn main() -> Result<()> {
     use tokenizers::Tokenizer;
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;

     let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };

     let device = candle_examples::device(args.cpu)?;
     let config = if args.v1 {
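Once this is wired up, running the example with --tracing should drop a trace-timestamp.json (chrome trace format) in the working directory; the spans added to Linear, RmsNorm, the attention blocks, and the MLP below then show up there as nested slices, viewable in chrome://tracing or Perfetto.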
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

@@ -47,6 +47,21 @@ impl Config {
     }
 }

+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
 #[derive(Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
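The wrapper above is specific to candle_nn::Linear, but the wrap-and-span pattern generalizes to anything with a forward-like call; an illustrative sketch (Traced is my name, not part of this diff):

use candle::{Result, Tensor};

// Illustrative only: wrap any forward-like callable in a TRACE-level span.
struct Traced<F: Fn(&Tensor) -> Result<Tensor>> {
    inner: F,
    span: tracing::Span,
}

impl<F: Fn(&Tensor) -> Result<Tensor>> Traced<F> {
    fn new(inner: F, name: &'static str) -> Self {
        // One span per layer instance; entering it is close to free when no
        // subscriber is installed, so the wrapper costs little without --tracing.
        let span = tracing::span!(tracing::Level::TRACE, "layer", layer = name);
        Traced { inner, span }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Every forward call is recorded as one slice in the chrome trace.
        let _enter = self.span.enter();
        (self.inner)(x)
    }
}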
@@ -106,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 }

 fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+    Ok(Linear { inner, span })
 }

 fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
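Two things change in linear() above: the weight now comes from candle_nn::linear_no_bias rather than a raw vb.get, so it goes through candle-nn's usual variable-builder path, and the TRACE span is created once at construction and entered on every forward by the Linear wrapper.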
@@ -118,15 +134,18 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 struct RmsNorm {
     scale: Tensor,
     eps: f64,
+    span: tracing::Span,
 }

 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps })
+        Ok(Self { scale, eps, span })
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
@@ -155,6 +174,8 @@ struct CausalSelfAttention {
     head_dim: usize,
     cache: Cache,
     use_flash_attn: bool,
+    span: tracing::Span,
+    span_rot: tracing::Span,
 }

 #[cfg(feature = "flash-attn")]
@@ -175,6 +196,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {

 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let _enter = self.span_rot.enter();
         let (b_sz, _, seq_len, n_embd) = x.dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
@@ -188,6 +210,7 @@ impl CausalSelfAttention {
     }

     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
         let k = self.k_proj.forward(x)?;
@@ -269,6 +292,8 @@ impl CausalSelfAttention {
     }

     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "attn");
+        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
         let size_in = cfg.hidden_size;
         let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
         let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
@@ -286,6 +311,8 @@ impl CausalSelfAttention {
             head_dim: cfg.hidden_size / cfg.n_head,
             cache: cache.clone(),
             use_flash_attn: cfg.use_flash_attn,
+            span,
+            span_rot,
         })
     }
 }
@@ -301,15 +328,18 @@ struct Mlp {
     c_fc1: Linear,
     c_fc2: Linear,
     c_proj: Linear,
+    span: tracing::Span,
 }

 impl Mlp {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }

     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
         let h_size = cfg.hidden_size;
         let i_size = cfg.intermediate_size;
         let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
@@ -319,6 +349,7 @@ impl Mlp {
             c_fc1,
             c_fc2,
             c_proj,
+            span,
         })
     }
 }
@@ -328,10 +359,12 @@ struct Block {
     attn: CausalSelfAttention,
     rms_2: RmsNorm,
     mlp: Mlp,
+    span: tracing::Span,
 }

 impl Block {
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = x;
         let x = self.rms_1.forward(x)?;
         let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
@@ -341,6 +374,7 @@ impl Block {
     }

     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "block");
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
         let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -354,6 +388,7 @@ impl Block {
             attn,
             rms_2,
             mlp,
+            span,
         })
     }
 }
@@ -27,7 +27,7 @@ struct InferenceCmd {
     #[arg(long, default_value = "")]
     prompt: String,

-    /// Config file in binary format.
+    /// Config file in binary or safetensors format.
     #[arg(long)]
     config: Option<String>,

@@ -200,7 +200,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
             Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
         }
     });
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
         let logits = model.forward(&inp, 0)?;
@@ -225,11 +225,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

     let device = candle_examples::device(common_args.cpu)?;

-    let mut file = std::fs::File::open(config_path)?;
-    let config = Config::from_reader(&mut file)?;
-    println!("{config:?}");
-    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
-    let vb = weights.var_builder(&config, &device)?;
+    let is_safetensors = config_path
+        .extension()
+        .map_or(false, |v| v == "safetensors");
+    let (vb, config) = if is_safetensors {
+        let config = Config::tiny();
+        let tensors = candle::safetensors::load(config_path, &device)?;
+        let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
+        (vb, config)
+    } else {
+        let mut file = std::fs::File::open(config_path)?;
+        let config = Config::from_reader(&mut file)?;
+        println!("{config:?}");
+        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+        let vb = weights.var_builder(&config, &device)?;
+        (vb, config)
+    };
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;

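A note on the branch above: a .safetensors checkpoint carries only named weight tensors, not the llama2.c binary header, so the example falls back to a fixed Config::tiny() rather than parsing a config; the binary path still reads the header through Config::from_reader. Presumably this means safetensors inference is only correct for checkpoints trained with that tiny configuration.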
@@ -1,118 +1,6 @@
-#![allow(dead_code)]
-#![allow(unused)]
 use crate::model::{Cache, Config, Llama};
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Result};
+use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
-
-pub struct Dataset {
-    valid_tokens: Vec<memmap2::Mmap>,
-    train_tokens: Vec<memmap2::Mmap>,
-}

(The rest of the removed hunk, roughly a hundred lines, is the mmap_file
helper, the impl Dataset block, and the DatasetRandomIter struct with its
impls; it is the same code now living in candle-datasets/src/nlp/tinystories.rs
above, minus the pub visibilities and the new train_tokens()/valid_tokens()
counters.)

 fn valid_loss(
     dataset: &Dataset,
@@ -121,7 +9,7 @@ fn valid_loss(
     device: &Device,
 ) -> Result<f64> {
     let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     let mut sum_ce = 0f64;
     let mut cnt = 0usize;
     for inp_tgt in batch_iter.take(50) {
@@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let dataset = Dataset::new(&args.pretokenized_dir)?;
     println!(
         "loaded dataset, train: {} files, valid: {} files",
-        dataset.train_tokens.len(),
-        dataset.valid_tokens.len()
+        dataset.train_tokens(),
+        dataset.valid_tokens()
     );
     let varmap = candle_nn::VarMap::new();
     let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
     let config = Config::tiny();
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

     let cache = Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
@@ -104,7 +104,15 @@ impl TransformerWeights {
         })
     }

-    pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+    pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
+        // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
+        // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
+        // second matrix back. We detect this case here and as a temporary hack make the weight
+        // matrix column major rather than row major. This ends up speeding up text generation from
+        // 120 token/s to 220 token/s on a Ryzen 2600X.
+        let tr = device.is_cpu() && !candle::utils::has_mkl();
+        let tr = false;
+        let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
         let mut ws = std::collections::HashMap::new();
         let mut insert = |name: &str, t: Tensor| {
             ws.insert(name.to_string(), t);
@@ -115,36 +123,36 @@ impl TransformerWeights {
             "model.embed_tokens.weight",
             self.token_embedding_table.clone(),
         );
-        insert("lm_head.weight", self.token_embedding_table.clone());
+        insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
         insert("model.norm.weight", self.rms_final_weight.clone());
         for layer in 0..cfg.n_layers {
             ws.insert(
                 format!("model.layers.{layer}.self_attn.q_proj.weight"),
-                self.wq.i(layer)?,
+                tr(self.wq.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.k_proj.weight"),
-                self.wk.i(layer)?,
+                tr(self.wk.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.v_proj.weight"),
-                self.wv.i(layer)?,
+                tr(self.wv.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.o_proj.weight"),
-                self.wo.i(layer)?,
+                tr(self.wo.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.gate_proj.weight"),
-                self.w1.i(layer)?,
+                tr(self.w1.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.down_proj.weight"),
-                self.w2.i(layer)?,
+                tr(self.w2.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.up_proj.weight"),
-                self.w3.i(layer)?,
+                tr(self.w3.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.input_layernorm.weight"),
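Note that as committed, the detection result is immediately shadowed by let tr = false;, so the column-major path is compiled in but switched off. For reference, a sketch of what the tr closure does when enabled: t() swaps the two dims by adjusting strides only, contiguous() materializes that transposed layout, and the final t() restores the logical shape while leaving the storage column major:

use candle::{Device, Result, Tensor};

// Illustrative only: keep the logical (rows, cols) shape but store the data
// column major, as the disabled `tr` closure above would.
fn to_col_major(x: Tensor) -> Result<Tensor> {
    // Afterwards x.dims() is unchanged, but element (i, j) lives at storage
    // offset j * rows + i instead of i * cols + j.
    x.t()?.contiguous()?.t()
}

fn main() -> Result<()> {
    let w = Tensor::randn(0f32, 1., (4, 3), &Device::Cpu)?;
    let w_cm = to_col_major(w.clone())?;
    assert_eq!(w.dims(), w_cm.dims()); // same logical shape
    Ok(())
}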
@@ -63,7 +63,7 @@ struct TrainingArgs {
 }

 fn training_loop<M: Model>(
-    m: candle_nn::vision::Dataset,
+    m: candle_datasets::vision::Dataset,
     args: &TrainingArgs,
 ) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
@@ -140,7 +140,7 @@ struct Args {
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     // Load the dataset
-    let m = candle_nn::vision::mnist::load_dir("data")?;
+    let m = candle_datasets::vision::mnist::load_dir("data")?;
     println!("train-images: {:?}", m.train_images.shape());
     println!("train-labels: {:?}", m.train_labels.shape());
     println!("test-images: {:?}", m.test_images.shape());
@@ -2,7 +2,6 @@
 // error type if needed or add some specialized cases on the candle-core side.
 pub mod activation;
 pub mod conv;
-pub mod dataset;
 pub mod embedding;
 pub mod init;
 pub mod layer_norm;
@@ -11,7 +10,6 @@ pub mod loss;
 pub mod ops;
 pub mod optim;
 pub mod var_builder;
-pub mod vision;

 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
@@ -111,7 +111,10 @@ impl Model {
             .to_vec();
         link.respond(id, Ok(WorkerOutput::Generated(prompt)));

-        for index in 0..self.config.seq_len - 10 {
+        for index in 0.. {
+            if tokens.len() >= self.config.seq_len {
+                break;
+            }
             let context_size = if self.cache.use_kv_cache && index > 0 {
                 1
             } else {
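The rewritten loop bounds generation by the total number of tokens rather than by iteration count, so the self.config.seq_len cap now holds regardless of how many tokens the prompt already contributed, and the old arbitrary seq_len - 10 slack is gone.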