Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)

Matmul (no batch, no strided, f32, f32 only) sort of done.
```diff
@@ -49,11 +49,11 @@ mod device;
 pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
-#[cfg(feature = "metal")]
-pub mod metal_backend;
 pub mod error;
 mod indexer;
 pub mod layout;
+#[cfg(feature = "metal")]
+pub mod metal_backend;
 #[cfg(feature = "accelerate")]
 mod metal_backend;
 #[cfg(feature = "mkl")]
```

```diff
@@ -3,9 +3,13 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 pub use candle_metal;
-use metal;
 use core::mem;
-use half::{f16, bf16};
+use half::{bf16, f16};
+use metal;
+use metal::mps::matrix::{MatrixMultiplication, Matrix, MatrixDescriptor};
+use metal::mps::{Float32, MPSDataType};
+use metal::MTLResourceOptions;
+use crate::bail;
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
```

```diff
@@ -17,6 +21,7 @@ pub enum MetalError {
 #[derive(Clone)]
 pub struct MetalDevice {
     device: metal::Device,
+    command_queue: metal::CommandQueue,
 }
 
 impl std::fmt::Debug for MetalDevice {
```

```diff
@@ -47,8 +52,7 @@ impl MetalDevice {
 pub struct MetalStorage {
     buffer: metal::Buffer,
     device: MetalDevice,
-    dtype: DType
-
+    dtype: DType,
 }
 
 impl BackendStorage for MetalStorage {
```

```diff
@@ -192,12 +196,77 @@ impl BackendStorage for MetalStorage {
         rhs_l: &Layout,
     ) -> Result<Self> {
         let elem_count = b * m * n;
-        let dev = &self.device;
-        match (self.dtype, rhs.dtype){
+        match (self.dtype, rhs.dtype) {
             (DType::F32, DType::F32) => {
-                todo!("MATMUL {b} {m} {n} {k}");
+                if b != 1 {
+                    bail!("Didn't implemented strided matmul yet");
                 }
-            _ => todo!("Unimplemented matmul for this pair")
+                if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
+                    bail!("Didn't implemented non contiguous matmul yet");
+                }
+                let out_buffer = self.device.new_buffer(
+                    (elem_count * mem::size_of::<f32>()) as u64,
+                    MTLResourceOptions::empty(),
+                );
+                let m: u64 = m.try_into().expect("usize should fit u64");
+                let n: u64 = n.try_into().expect("usize should fit u64");
+                let k: u64 = k.try_into().expect("usize should fit u64");
+                // Create descriptors
+                let left_descriptor =
+                    MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
+                let right_descriptor =
+                    MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
+                let result_descriptor =
+                    MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
+
+                // Create matrix objects
+                let left_matrix =
+                    Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
+                        .expect("Failed to create left matrix");
+                let right_matrix =
+                    Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
+                        .expect("Failed to create left matrix");
+
+                let result_matrix =
+                    Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
+                        .expect("Failed to create left matrix");
+
+                let transpose_left = false;
+                let transpose_right = false;
+                let alpha = 1.0;
+                let beta = 0.0;
+
+
+                // Create kernel
+                let matrix_multiplication = MatrixMultiplication::init(
+                    &self.device,
+                    transpose_left,
+                    transpose_right,
+                    m,
+                    n,
+                    k,
+                    alpha,
+                    beta,
+                )
+                .expect("Failed to create matrix multiplication kernel");
+
+                let buffer = self.device.command_queue.new_command_buffer();
+                // Encode kernel to command buffer
+                matrix_multiplication.encode_to_command_buffer(
+                    &buffer,
+                    &left_matrix,
+                    &right_matrix,
+                    &result_matrix,
+                );
+                buffer.commit();
+                Ok(Self {
+                    buffer: out_buffer,
+                    device: self.device.clone(),
+                    dtype: self.dtype(),
+                })
+
+            }
+            _ => todo!("Unimplemented matmul for this pair"),
         }
     }
 
```

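The hunk above is the core of the commit: a single, contiguous f32 x f32 product is handed to Metal Performance Shaders, while batched (`b != 1`) or non-contiguous inputs bail out and every other dtype pair still hits a `todo!`, which is what the commit title means by "sort of done". The sketch below condenses that MPS path into a standalone function; it is not part of the commit, the names (`mps_matmul_f32`, `lhs`, `rhs`) are illustrative, and the `metal::mps` calls simply mirror the ones used in the diff, so the exact signatures depend on the `metal` crate version this branch builds against.

```rust
// Sketch only: one (m, k) x (k, n) f32 matmul through MPS, mirroring the diff above.
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType};
use metal::MTLResourceOptions;

fn mps_matmul_f32(
    device: &metal::Device,
    command_queue: &metal::CommandQueue,
    lhs: &metal::Buffer, // m * k contiguous f32 values
    rhs: &metal::Buffer, // k * n contiguous f32 values
    (m, n, k): (u64, u64, u64),
) -> metal::Buffer {
    // Output buffer for the m * n f32 result.
    let out = device.new_buffer(m * n * core::mem::size_of::<f32>() as u64, MTLResourceOptions::empty());
    // Row-major descriptors: the row stride in bytes is `columns * size_of::<f32>()`.
    let lhs_desc = MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
    let rhs_desc = MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
    let out_desc = MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
    // Wrap the raw buffers as MPS matrix objects.
    let lhs_mat = Matrix::init_with_buffer_descriptor(lhs, &lhs_desc).expect("lhs matrix");
    let rhs_mat = Matrix::init_with_buffer_descriptor(rhs, &rhs_desc).expect("rhs matrix");
    let out_mat = Matrix::init_with_buffer_descriptor(&out, &out_desc).expect("out matrix");
    // out = 1.0 * lhs @ rhs + 0.0 * out, no transposes (same alpha/beta as the diff).
    let matmul = MatrixMultiplication::init(device, false, false, m, n, k, 1.0, 0.0)
        .expect("matmul kernel");
    let cmd = command_queue.new_command_buffer();
    matmul.encode_to_command_buffer(&cmd, &lhs_mat, &rhs_mat, &out_mat);
    cmd.commit();
    // The diff returns right after commit; a caller that reads `out` on the CPU side
    // would additionally have to wait for the command buffer to complete.
    out
}
```
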
```diff
@@ -211,7 +280,8 @@ impl BackendDevice for MetalDevice {
 
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
-        Ok(Self{device })
+        let command_queue = device.new_command_queue();
+        Ok(Self { device, command_queue })
     }
 
     fn set_seed(&self, _seed: u64) -> Result<()> {
```

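A small illustrative snippet, not from the commit: with the `metal` crate, the command queue is a long-lived object owned by the device, while command buffers are cheap and created per submission. Storing the queue in `MetalDevice::new` is what lets the matmul hunk above grab `self.device.command_queue` for each call. The function name below is hypothetical.

```rust
// Illustrative sketch: one queue per device, one command buffer per piece of work.
fn submit_empty_work(ordinal: usize) {
    let device = metal::Device::all().swap_remove(ordinal); // same selection as MetalDevice::new
    let command_queue = device.new_command_queue(); // created once, reused afterwards
    let cmd = command_queue.new_command_buffer(); // created for every submission
    // ... encode kernels here ...
    cmd.commit();
    cmd.wait_until_completed();
}
```
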
```diff
@@ -237,57 +307,47 @@ impl BackendDevice for MetalDevice {
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
         let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
         let buffer = match storage {
-            CpuStorage::U8(storage) => {
-                self.device.new_buffer_with_data(
+            CpuStorage::U8(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
                 (storage.len() * mem::size_of::<u8>()) as u64,
-                option
-                )
-            }
+                option,
+            ),
-            CpuStorage::U32(storage) => {
-                self.device.new_buffer_with_data(
+            CpuStorage::U32(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
                 (storage.len() * mem::size_of::<u32>()) as u64,
-                option
-                )
-            }
+                option,
+            ),
-            CpuStorage::I64(storage) => {
-                self.device.new_buffer_with_data(
+            CpuStorage::I64(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
                 (storage.len() * mem::size_of::<i64>()) as u64,
-                option
-                )
-            }
+                option,
+            ),
-            CpuStorage::BF16(storage) => {
-                self.device.new_buffer_with_data(
+            CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
                 (storage.len() * mem::size_of::<bf16>()) as u64,
-                option
-                )
-            }
+                option,
+            ),
-            CpuStorage::F16(storage) => {
-                self.device.new_buffer_with_data(
+            CpuStorage::F16(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
                 (storage.len() * mem::size_of::<f16>()) as u64,
-                option
-                )
-            }
+                option,
+            ),
-            CpuStorage::F32(storage) => {
-                self.device.new_buffer_with_data(
+            CpuStorage::F32(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
                 (storage.len() * mem::size_of::<f32>()) as u64,
-                option
-                )
-            }
+                option,
+            ),
-            CpuStorage::F64(storage) => {
-                self.device.new_buffer_with_data(
+            CpuStorage::F64(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
                 (storage.len() * mem::size_of::<f64>()) as u64,
-                option
-                )
-            }
+                option,
+            ),
         };
-        Ok(Self::Storage{buffer, device: self.clone(), dtype: storage.dtype()})
+        Ok(Self::Storage {
+            buffer,
+            device: self.clone(),
+            dtype: storage.dtype(),
+        })
     }
 
     fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
```

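The seven arms above all follow the same copy-this-slice-into-a-fresh-buffer pattern. A possible follow-up refactor, not part of this commit, is a generic helper; `buffer_from_slice` is a hypothetical name, and the `new_buffer_with_data` call is the same one the diff already uses.

```rust
// Hypothetical refactor sketch: one generic helper instead of seven near-identical arms.
fn buffer_from_slice<T>(device: &metal::Device, data: &[T]) -> metal::Buffer {
    let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
    device.new_buffer_with_data(
        data.as_ptr() as *const core::ffi::c_void,
        (data.len() * core::mem::size_of::<T>()) as u64,
        option,
    )
}
```

Each arm would then reduce to something like `CpuStorage::F32(s) => buffer_from_slice(&self.device.device, s)`.
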
```diff
@@ -1,5 +1,5 @@
 #![allow(clippy::redundant_closure_call)]
-use crate::{CpuStorage, CudaStorage, MetalStorage, Layout, Result, Shape, Tensor};
+use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
 use half::{bf16, f16};
 use num_traits::float::Float;
 
```

```diff
@@ -176,7 +176,11 @@ pub trait CustomOp1 {
 
     /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
     /// offsets etc so the associated layout should be used to access it.
-    fn metal_fwd(&self, _storage: &MetalStorage, _layout: &Layout) -> Result<(MetalStorage, Shape)> {
+    fn metal_fwd(
+        &self,
+        _storage: &MetalStorage,
+        _layout: &Layout,
+    ) -> Result<(MetalStorage, Shape)> {
         Err(crate::Error::Metal(
             format!("no cuda implementation for {}", self.name()).into(),
         ))
```

```diff
@@ -1,7 +1,7 @@
 //! Support for the GGML file format.
 
 use super::{k_quants, GgmlDType};
-use crate::{Result, Device};
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
 
```

```diff
@@ -148,16 +148,36 @@ pub fn qtensor_from_ggml(
     match ggml_dtype {
         GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
         GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::Q4_0 => {
+            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4_1 => {
+            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_0 => {
+            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_1 => {
+            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q8_0 => {
+            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q2K => {
+            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q3K => {
+            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4K => {
+            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5K => {
+            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q6K => {
+            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
+        }
         _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
```

```diff
@@ -204,7 +224,10 @@ pub struct Content {
 }
 
 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R, device: &Device) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(
+        reader: &mut R,
+        device: &Device,
+    ) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
```

```diff
@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
 
 use super::{GgmlDType, QTensor};
-use crate::{Result, Device};
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;
 
```

```diff
@@ -70,7 +70,12 @@ impl TensorInfo {
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec(), device)
+        super::ggml_file::qtensor_from_ggml(
+            self.ggml_dtype,
+            &raw_data,
+            self.shape.dims().to_vec(),
+            device,
+        )
     }
 }
 
```

```diff
@@ -178,7 +178,7 @@ impl QTensor {
         Ok(Self {
             data: Box::new(data),
             shape,
-            device: device.clone()
+            device: device.clone(),
         })
     }
 
```

```diff
@@ -201,7 +201,7 @@ impl QTensor {
         Ok(Self {
             data: Box::new(data),
             shape: shape.clone(),
-            device: device.clone()
+            device: device.clone(),
         })
     }
 
```

```diff
@@ -1,6 +1,6 @@
 use crate::backend::BackendStorage;
 use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
-use crate::{CpuStorage, CudaStorage, MetalStorage, DType, Device, Error, Layout, Result, Shape};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
 
 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.
```

```diff
@@ -659,7 +659,9 @@ impl Storage {
         match (self, dst) {
             (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
             (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
-            (Self::Metal(src), Self::Metal(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
+            (Self::Metal(src), Self::Metal(dst)) => {
+                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
+            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
```

```diff
@@ -2,25 +2,27 @@ pub mod coco_classes;
 pub mod imagenet;
 pub mod token_output_stream;
 
-use candle::{Device, Result, Tensor};
 use candle::utils::{cuda_is_available, metal_is_available};
+use candle::{Device, Result, Tensor};
 
 pub fn device(cpu: bool) -> Result<Device> {
     if cpu {
         Ok(Device::Cpu)
     } else {
-        if cuda_is_available(){
+        if cuda_is_available() {
             Ok(Device::new_cuda(0)?)
-        }else if metal_is_available(){
+        } else if metal_is_available() {
             Ok(Device::new_metal(0)?)
-        }else{
+        } else {
-            #[cfg(all(target_os="macos", target_arch="aarch64"))]
+            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
             {
                 println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
             }
-            #[cfg(not(all(target_os="macos", target_arch="aarch64")))]
+            #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
             {
-                println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
+                println!(
+                    "Running on CPU, to run on GPU, build this example with `--features cuda`"
+                );
             }
             Ok(Device::Cpu)
         }
```

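For context, this is roughly how the reformatted helper gets used from an example binary; the snippet is a sketch, and the `run` wrapper plus its `cpu` flag are assumptions rather than code from the diff.

```rust
use candle::{DType, Result, Tensor};

fn run(cpu: bool) -> Result<()> {
    // Prefers CUDA, then Metal, and otherwise falls back to CPU with the printed hint.
    let device = candle_examples::device(cpu)?;
    let t = Tensor::zeros((2, 3), DType::F32, &device)?;
    println!("{t}");
    Ok(())
}
```
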
```diff
@@ -181,14 +181,20 @@ pub struct ModelWeights {
     span_output: tracing::Span,
 }
 
-fn precomput_freqs_cis(head_dim: usize, freq_base: f32, device: &Device) -> Result<(Tensor, Tensor)> {
+fn precomput_freqs_cis(
+    head_dim: usize,
+    freq_base: f32,
+    device: &Device,
+) -> Result<(Tensor, Tensor)> {
     let theta: Vec<_> = (0..head_dim)
         .step_by(2)
         .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
         .collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
     let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
-    let idx_theta = Tensor::new(range.as_slice(), device)?.reshape((MAX_SEQ_LEN, 1))?.matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    let idx_theta = Tensor::new(range.as_slice(), device)?
+        .reshape((MAX_SEQ_LEN, 1))?
+        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     // TODO This change avoids allocating on Metal and then casting since allocating directly on
     // CPU as f32 seems just as fast
     // let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
```

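In math terms, the reformatted `precomput_freqs_cis` builds the usual RoPE frequency table: with d = `head_dim` and base = `freq_base`,

```latex
\theta_j = \text{base}^{-2j/d}, \qquad j = 0, 1, \dots, \tfrac{d}{2} - 1,
\qquad
\texttt{idx\_theta}[p, j] = p \cdot \theta_j, \qquad p = 0, \dots, \texttt{MAX\_SEQ\_LEN} - 1.
```

`theta` holds the per-pair frequencies and the reshape-plus-matmul forms the outer product of positions and frequencies; the rest of the function (outside this hunk) turns `idx_theta` into the cosine and sine tables returned as `(Tensor, Tensor)`.
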
```diff
@@ -260,7 +266,7 @@ impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
         ct: gguf_file::Content,
         reader: &mut R,
-        device: &Device
+        device: &Device,
     ) -> Result<Self> {
         let md_get = |s: &str| match ct.metadata.get(s) {
             None => candle::bail!("cannot find {s} in metadata"),
```

```diff
@@ -283,7 +289,10 @@ impl ModelWeights {
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight", device)?, rms_norm_eps)?;
+        let norm = RmsNorm::new(
+            ct.tensor(reader, "output_norm.weight", device)?,
+            rms_norm_eps,
+        )?;
         let output = ct.tensor(reader, "output.weight", device)?;
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
```

```diff
@@ -291,11 +300,15 @@ impl ModelWeights {
             let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
             let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
             let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
-            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
-            let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-            let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+            let attention_wo =
+                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+            let feed_forward_w1 =
+                ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
+            let feed_forward_w2 =
+                ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
             let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-            let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
+            let attention_norm =
+                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
             let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
```