mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
More scaffolding, now need to implement matmul (for precompute_cos_sin to work).
This commit is contained in:
@ -223,10 +223,9 @@ impl Device {
|
|||||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
Device::Metal(_device) => {
|
Device::Metal(device) => {
|
||||||
// let storage = device.rand_normal(shape, dtype, mean, std)?;
|
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||||
// Ok(Storage::Metal(storage))
|
Ok(Storage::Metal(storage))
|
||||||
bail!("Metal rand_normal not implemented")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -250,10 +249,9 @@ impl Device {
|
|||||||
let storage = device.ones_impl(shape, dtype)?;
|
let storage = device.ones_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
Device::Metal(_device) => {
|
Device::Metal(device) => {
|
||||||
// let storage = device.ones_impl(shape, dtype)?;
|
let storage = device.ones_impl(shape, dtype)?;
|
||||||
// Ok(Storage::Metal(storage))
|
Ok(Storage::Metal(storage))
|
||||||
bail!("Metal ones not implemented")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -268,10 +266,9 @@ impl Device {
|
|||||||
let storage = device.zeros_impl(shape, dtype)?;
|
let storage = device.zeros_impl(shape, dtype)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
Device::Metal(_device) => {
|
Device::Metal(device) => {
|
||||||
// let storage = device.zeros_impl(shape, dtype)?;
|
let storage = device.zeros_impl(shape, dtype)?;
|
||||||
// Ok(Storage::Metal(storage))
|
Ok(Storage::Metal(storage))
|
||||||
bail!("Metal zeros not implemented")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -284,11 +281,10 @@ impl Device {
|
|||||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
Device::Metal(_device) => {
|
Device::Metal(device) => {
|
||||||
// let storage = array.to_cpu_storage();
|
let storage = array.to_cpu_storage();
|
||||||
// let storage = device.storage_from_cpu_storage(&storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
// Ok(Storage::Metal(storage))
|
Ok(Storage::Metal(storage))
|
||||||
bail!("Metal storage not implemented")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,8 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
|||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
pub use candle_metal;
|
pub use candle_metal;
|
||||||
use metal;
|
use metal;
|
||||||
|
use core::mem;
|
||||||
|
use half::{f16, bf16};
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -43,8 +45,10 @@ impl MetalDevice {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MetalStorage {
|
pub struct MetalStorage {
|
||||||
pub buffer: metal::Buffer,
|
buffer: metal::Buffer,
|
||||||
pub device: metal::Device,
|
device: MetalDevice,
|
||||||
|
dtype: DType
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BackendStorage for MetalStorage {
|
impl BackendStorage for MetalStorage {
|
||||||
@ -55,11 +59,11 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dtype(&self) -> DType {
|
fn dtype(&self) -> DType {
|
||||||
todo!()
|
self.dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
fn device(&self) -> &Self::Device {
|
fn device(&self) -> &Self::Device {
|
||||||
todo!()
|
&self.device
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
@ -86,8 +90,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||||
todo!()
|
todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||||
@ -182,12 +186,19 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
&self,
|
&self,
|
||||||
_: &Self,
|
rhs: &Self,
|
||||||
_: (usize, usize, usize, usize),
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
_: &Layout,
|
lhs_l: &Layout,
|
||||||
_: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
let elem_count = b * m * n;
|
||||||
|
let dev = &self.device;
|
||||||
|
match (self.dtype, rhs.dtype){
|
||||||
|
(DType::F32, DType::F32) => {
|
||||||
|
todo!("MATMUL {b} {m} {n} {k}");
|
||||||
|
}
|
||||||
|
_ => todo!("Unimplemented matmul for this pair")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||||
@ -223,8 +234,60 @@ impl BackendDevice for MetalDevice {
|
|||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||||
todo!("Storage")
|
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
|
||||||
|
let buffer = match storage {
|
||||||
|
CpuStorage::U8(storage) => {
|
||||||
|
self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<u8>()) as u64,
|
||||||
|
option
|
||||||
|
)
|
||||||
|
}
|
||||||
|
CpuStorage::U32(storage) => {
|
||||||
|
self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<u32>()) as u64,
|
||||||
|
option
|
||||||
|
)
|
||||||
|
}
|
||||||
|
CpuStorage::I64(storage) => {
|
||||||
|
self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<i64>()) as u64,
|
||||||
|
option
|
||||||
|
)
|
||||||
|
}
|
||||||
|
CpuStorage::BF16(storage) => {
|
||||||
|
self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<bf16>()) as u64,
|
||||||
|
option
|
||||||
|
)
|
||||||
|
}
|
||||||
|
CpuStorage::F16(storage) => {
|
||||||
|
self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f16>()) as u64,
|
||||||
|
option
|
||||||
|
)
|
||||||
|
}
|
||||||
|
CpuStorage::F32(storage) => {
|
||||||
|
self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f32>()) as u64,
|
||||||
|
option
|
||||||
|
)
|
||||||
|
}
|
||||||
|
CpuStorage::F64(storage) => {
|
||||||
|
self.device.new_buffer_with_data(
|
||||||
|
storage.as_ptr() as *const core::ffi::c_void,
|
||||||
|
(storage.len() * mem::size_of::<f64>()) as u64,
|
||||||
|
option
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(Self::Storage{buffer, device: self.clone(), dtype: storage.dtype()})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
|
@ -308,7 +308,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::L70b
|
| Which::L70b
|
||||||
| Which::L70bChat => 8,
|
| Which::L70bChat => 8,
|
||||||
};
|
};
|
||||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
println!("model built");
|
println!("model built");
|
||||||
|
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||||||
|
|
||||||
use candle::quantized::QTensor;
|
use candle::quantized::QTensor;
|
||||||
use candle::quantized::{ggml_file, gguf_file};
|
use candle::quantized::{ggml_file, gguf_file};
|
||||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{Embedding, Module};
|
use candle_nn::{Embedding, Module};
|
||||||
|
|
||||||
pub const MAX_SEQ_LEN: usize = 4096;
|
pub const MAX_SEQ_LEN: usize = 4096;
|
||||||
@ -181,28 +181,31 @@ pub struct ModelWeights {
|
|||||||
span_output: tracing::Span,
|
span_output: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
|
fn precomput_freqs_cis(head_dim: usize, freq_base: f32, device: &Device) -> Result<(Tensor, Tensor)> {
|
||||||
let theta: Vec<_> = (0..head_dim)
|
let theta: Vec<_> = (0..head_dim)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
|
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
|
let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
|
||||||
.to_dtype(DType::F32)?
|
let idx_theta = Tensor::new(range.as_slice(), device)?.reshape((MAX_SEQ_LEN, 1))?.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
.reshape((MAX_SEQ_LEN, 1))?
|
// TODO This change avoids allocating on Metal and then casting since allocating directly on
|
||||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
// CPU as f32 seems just as fast
|
||||||
|
// let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||||
|
// .to_dtype(DType::F32)?
|
||||||
|
// .reshape((MAX_SEQ_LEN, 1))?
|
||||||
|
// .matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
let cos = idx_theta.cos()?;
|
let cos = idx_theta.cos()?;
|
||||||
let sin = idx_theta.sin()?;
|
let sin = idx_theta.sin()?;
|
||||||
Ok((cos, sin))
|
Ok((cos, sin))
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelWeights {
|
impl ModelWeights {
|
||||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, device: &Device) -> Result<Self> {
|
||||||
let cpu = &Device::Cpu;
|
|
||||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
|
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., device)?;
|
||||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||||
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
||||||
let output = ct.remove("output.weight")?;
|
let output = ct.remove("output.weight")?;
|
||||||
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
||||||
@ -276,7 +279,7 @@ impl ModelWeights {
|
|||||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||||
.and_then(|m| m.to_f32())
|
.and_then(|m| m.to_f32())
|
||||||
.unwrap_or(10000f32);
|
.unwrap_or(10000f32);
|
||||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
|
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
||||||
|
|
||||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||||
@ -331,14 +334,14 @@ impl ModelWeights {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mask(&mut self, t: usize) -> Result<Tensor> {
|
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
||||||
if let Some(mask) = self.masks.get(&t) {
|
if let Some(mask) = self.masks.get(&t) {
|
||||||
Ok(mask.clone())
|
Ok(mask.clone())
|
||||||
} else {
|
} else {
|
||||||
let mask: Vec<_> = (0..t)
|
let mask: Vec<_> = (0..t)
|
||||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||||
.collect();
|
.collect();
|
||||||
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
|
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||||
self.masks.insert(t, mask.clone());
|
self.masks.insert(t, mask.clone());
|
||||||
Ok(mask)
|
Ok(mask)
|
||||||
}
|
}
|
||||||
@ -346,7 +349,7 @@ impl ModelWeights {
|
|||||||
|
|
||||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (_b_sz, seq_len) = x.dims2()?;
|
let (_b_sz, seq_len) = x.dims2()?;
|
||||||
let mask = self.mask(seq_len)?;
|
let mask = self.mask(seq_len, x.device())?;
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
|
Reference in New Issue
Block a user