Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00

Compare commits: cuda-conv-... -> moondream

1 commit: 101a4c8389
@@ -608,34 +608,6 @@ impl Map1 for Elu {
     }
 }
 
-struct Col2Im1D {
-    stride: usize,
-}
-
-impl Map1 for Col2Im1D {
-    fn f<T: DeviceRepr + WithDType>(
-        &self,
-        src: &CudaSlice<T>,
-        dev: &CudaDevice,
-        layout: &Layout,
-    ) -> Result<CudaSlice<T>> {
-        let (b_size, l_in, c_out, k_size) = layout.shape().dims4()?;
-        let stride = self.stride;
-        let l_out = (l_in - 1) * stride + k_size;
-
-        let dst_el = b_size * c_out * l_out;
-        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
-        // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let params = (l_in, l_out, c_out, k_size, b_size, stride, src, &dst);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(dst)
-    }
-}
-
 struct Im2Col1D {
     l_k: usize,
     stride: usize,
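Aside, not part of the diff: as a reference for what the removed Col2Im1D computes, here is a minimal CPU sketch in safe Rust, assuming the (b_size, l_in, c_out, k_size) source layout used above. Overlapping windows accumulate, and the output length follows l_out = (l_in - 1) * stride + k_size (for example l_in = 4, stride = 2, k_size = 3 gives l_out = 9).

// CPU reference sketch for col2im1d (scatter-add form; the GPU kernel gathers instead).
fn col2im1d_ref(src: &[f32], b_size: usize, l_in: usize, c_out: usize, k_size: usize, stride: usize) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    let src_i = ((b * l_in + l) * c_out + c) * k_size + k;
                    let dst_i = (b * c_out + c) * l_out + l * stride + k;
                    dst[dst_i] += src[src_i]; // overlapping windows accumulate
                }
            }
        }
    }
    dst
}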
@@ -1893,54 +1865,8 @@ impl BackendStorage for CudaStorage {
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
         let device = self.device().clone();
-        const USE_COL2IM_CONV1D_TR: bool = true;
-
-        let can_use_col2im = kernel_l.is_contiguous()
-            && params.dilation == 1
-            && params.padding == 0
-            && params.output_padding == 0;
-        if !can_use_col2im || !USE_COL2IM_CONV1D_TR {
-            let slice =
-                ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-            return Ok(Self { slice, device });
-        }
-
-        let (b_size, c_in, l_in) = l.shape().dims3()?;
-        let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
-        if !kernel_l.is_contiguous() {
-            crate::bail!("convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}")
-        }
-        if c_in != c_in2 {
-            crate::bail!(
-                "convtr1d: shape mismatch on c_in {:?} {:?}",
-                l.shape(),
-                kernel_l.shape()
-            )
-        }
-        let col = {
-            // This merges the last two dimensions of the kernel together.
-            let kernel_l_mm = Layout::new(
-                (b_size, c_in, k_size * c_out).into(),
-                vec![0, k_size * c_out, 1],
-                kernel_l.start_offset(),
-            );
-            self.matmul(
-                kernel,
-                (
-                    b_size,
-                    /* m */ l_in,
-                    /* n */ c_out * k_size,
-                    /* k */ c_in,
-                ),
-                &l.transpose(1, 2)?,
-                &kernel_l_mm,
-            )?
-        };
-        let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
-        let slice = Col2Im1D {
-            stride: params.stride,
-        }
-        .map(&col.slice, &device, &col_l)?;
+        let slice =
+            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
         Ok(Self { slice, device })
     }
 
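For reference (not part of the diff), the shapes flowing through the removed fast path, under the parameter names above:

// input  l:  (b_size, c_in, l_in), transposed to (b_size, l_in, c_in)
// kernel:    (c_in, c_out, k_size), viewed as (c_in, c_out * k_size) and broadcast
//            over the batch via the stride-0 first dimension of `kernel_l_mm`
// matmul:    (b_size, l_in, c_in) x (c_in, c_out * k_size) -> (b_size, l_in, c_out * k_size)
// col2im1d:  read as (b_size, l_in, c_out, k_size), accumulated into (b_size, c_out, l_out)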
@@ -609,41 +609,28 @@ impl BackendStorage for MetalStorage {
         let command_buffer = device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let kernel_name = match (self.dtype, dtype) {
-                (DType::U32, DType::BF16) => "cast_u32_bf16",
-                (DType::U32, DType::F16) => "cast_u32_f16",
                 (DType::U32, DType::F32) => "cast_u32_f32",
-                (DType::U32, DType::I64) => "cast_u32_i64",
                 (DType::U32, DType::U8) => "cast_u32_u8",
+                (DType::U32, DType::I64) => "cast_u32_i64",
+                (DType::U32, DType::BF16) => "cast_u32_bf16",
 
-                (DType::U8, DType::BF16) => "cast_u8_bf16",
-                (DType::U8, DType::F16) => "cast_u8_f16",
+                (DType::U8, DType::U32) => "cast_u8_u32",
                 (DType::U8, DType::F32) => "cast_u8_f32",
                 (DType::U8, DType::I64) => "cast_u8_i64",
-                (DType::U8, DType::U32) => "cast_u8_u32",
+                (DType::U8, DType::BF16) => "cast_u8_bf16",
 
-                (DType::F32, DType::BF16) => "cast_f32_bf16",
                 (DType::F32, DType::F16) => "cast_f32_f16",
-                (DType::F32, DType::I64) => "cast_f32_i64",
                 (DType::F32, DType::U32) => "cast_f32_u32",
                 (DType::F32, DType::U8) => "cast_f32_u8",
+                (DType::F32, DType::BF16) => "cast_f32_bf16",
 
-                (DType::I64, DType::BF16) => "cast_i64_bf16",
-                (DType::I64, DType::F16) => "cast_i64_f16",
                 (DType::I64, DType::F32) => "cast_i64_f32",
-                (DType::I64, DType::U32) => "cast_i64_u32",
-                (DType::I64, DType::U8) => "cast_i64_u8",
 
-                (DType::F16, DType::BF16) => "cast_f16_bf16",
                 (DType::F16, DType::F32) => "cast_f16_f32",
-                (DType::F16, DType::I64) => "cast_f16_i64",
-                (DType::F16, DType::U32) => "cast_f16_u32",
-                (DType::F16, DType::U8) => "cast_f16_u8",
 
-                (DType::BF16, DType::U8) => "cast_bf16_u8",
-                (DType::BF16, DType::U32) => "cast_bf16_u32",
-                (DType::BF16, DType::F16) => "cast_bf16_f16",
                 (DType::BF16, DType::F32) => "cast_bf16_f32",
-                (DType::BF16, DType::I64) => "cast_bf16_i64",
+                (DType::BF16, DType::U32) => "cast_bf16_u32",
+                (DType::BF16, DType::U8) => "cast_bf16_u8",
 
                 (left, right) => {
                     crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
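Aside (not in the diff): both sides of this match follow a cast_{src}_{dst} naming scheme, so a hypothetical helper like the sketch below could build the names; the explicit match is kept so that unsupported pairs fail loudly at exactly this point.

fn cast_kernel_name(src: DType, dst: DType) -> String {
    // e.g. (DType::U32, DType::F32) -> "cast_u32_f32"
    format!("cast_{}_{}", src.as_str(), dst.as_str())
}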
@@ -109,7 +109,8 @@ fn main() -> Result<()> {
     let codes = match args.action {
         Action::CodeToAudio => {
            let codes = candle::safetensors::load(args.in_file, &device)?;
-            codes.get("codes").expect("no codes in input file").clone()
+            let codes = codes.get("codes").expect("no codes in input file").i(0)?;
+            codes
         }
         Action::AudioToCode | Action::AudioToAudio => {
             let (pcm, sample_rate) = pcm_decode(args.in_file)?;
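Aside (not in the diff; shapes assumed for illustration): `.clone()` keeps the stored tensor as-is, while `.i(0)?` from candle's IndexOp selects the first entry along the leading dimension, e.g. turning a (1, n_codebooks, seq_len) codes tensor into (n_codebooks, seq_len):

use candle::{IndexOp, Result, Tensor};

fn first_batch(codes: &Tensor) -> Result<Tensor> {
    codes.i(0) // (1, n, l) -> (n, l)
}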
@@ -51,48 +51,6 @@ __device__ void conv1d(
     dst[dst_i] = static_cast<T>(d);
 }
 
-template <typename T>
-__device__ void col2im1d(
-    const size_t l_in,
-    const size_t l_out,
-    const size_t c_out,
-    const size_t k_size,
-    const size_t b_size,
-    const size_t stride,
-    const T *src,
-    T *dst
-) {
-    const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
-    // src: (b_size, l_in, c_out, k_size)
-    // dst: (b_size, c_out, l_out)
-    if (dst_i >= b_size * c_out * l_out) {
-        return;
-    }
-    const size_t dst_s0 = c_out * l_out;
-    const size_t dst_s1 = l_out;
-
-    // dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_in_i * stride + k_i
-    const size_t b_i = dst_i / dst_s0;
-    const size_t dst_i2 = dst_i - b_i * dst_s0;
-    const size_t c_i = dst_i2 / dst_s1;
-    const size_t dst_i3 = dst_i2 - c_i * dst_s1; // l_in_i * stride + k_i
-
-    const size_t src_s0 = c_out * k_size * l_in;
-    const size_t src_s1 = c_out * k_size;
-    const size_t src_s2 = k_size;
-
-    T d = 0;
-    for (size_t k_i = 0; k_i < min(dst_i3 + 1, k_size); ++k_i) {
-        const size_t l_in_i_times_stride = dst_i3 - k_i;
-        const size_t l_in_i = l_in_i_times_stride / stride;
-        const size_t src_i = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
-        if (l_in_i * stride == l_in_i_times_stride && l_in_i < l_in) {
-            d += src[src_i];
-        }
-    }
-    dst[dst_i] = d;
-}
-
 template <typename T>
 __device__ void im2col1d(
     const size_t dst_numel,
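Note on the removed kernel: each thread owns one output element and gathers rather than scatters. For output offset dst_i3 = l_in_i * stride + k_i, a tap k_i can only have come from a real input position when dst_i3 - k_i is an exact multiple of stride, which is what the l_in_i * stride == l_in_i_times_stride guard checks before accumulating; this is the transpose of im2col1d's gather.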
@@ -569,7 +527,7 @@ extern "C" __global__ void FN_NAME( \
     conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
 } \
 
-#define IM2COL1D_OP(TYPENAME, FN_NAME, FN_NAME2) \
+#define IM2COL1D_OP(TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t dst_numel, \
     const size_t l_out, \
@@ -583,18 +541,6 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
 } \
-extern "C" __global__ void FN_NAME2( \
-    const size_t l_in, \
-    const size_t l_out, \
-    const size_t c_out, \
-    const size_t k_size, \
-    const size_t b_size, \
-    const size_t stride, \
-    const TYPENAME *src, \
-    TYPENAME *dst \
-) { \
-    col2im1d<TYPENAME>(l_in, l_out, c_out, k_size, b_size, stride, src, dst); \
-} \
 
 #define IM2COL_OP(TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
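Aside (not in the diff): on the Rust side these entry points are looked up by dtype-suffixed name, so FN_NAME/FN_NAME2 pairs such as im2col1d_f32/col2im1d_f32 line up with a helper along these lines (a sketch; candle's real kernel_name helper may differ):

use candle_core::{DType, WithDType};

fn kernel_name_for<T: WithDType>(root: &str) -> String {
    // e.g. kernel_name_for::<f32>("col2im1d") == "col2im1d_f32"
    format!("{root}_{}", T::DTYPE.as_str())
}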
@@ -696,7 +642,7 @@ AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
 MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
 UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
 IM2COL_OP(__nv_bfloat16, im2col_bf16)
-IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16, col2im1d_bf16)
+IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
@@ -708,7 +654,7 @@ AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
 MAX_POOL2D_OP(__half, max_pool2d_f16)
 UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
 IM2COL_OP(__half, im2col_f16)
-IM2COL1D_OP(__half, im2col1d_f16, col2im1d_f16)
+IM2COL1D_OP(__half, im2col1d_f16)
 #endif
 
 CONV1D_OP(float, float, conv1d_f32)
@@ -751,7 +697,7 @@ IM2COL_OP(double, im2col_f64)
 IM2COL_OP(uint8_t, im2col_u8)
 IM2COL_OP(uint32_t, im2col_u32)
 
-IM2COL1D_OP(float, im2col1d_f32, col2im1d_f32)
-IM2COL1D_OP(double, im2col1d_f64, col2im1d_f64)
-IM2COL1D_OP(uint8_t, im2col1d_u8, col2im1d_u8)
-IM2COL1D_OP(uint32_t, im2col1d_u32, col2im1d_u32)
+IM2COL1D_OP(float, im2col1d_f32)
+IM2COL1D_OP(double, im2col1d_f64)
+IM2COL1D_OP(uint8_t, im2col1d_u8)
+IM2COL1D_OP(uint32_t, im2col1d_u32)
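Dropping the second parameter everywhere matches the kernel removal above: with the col2im1d device function and its FN_NAME2 wrapper gone, each dtype instantiates only the im2col1d entry point.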
@@ -72,60 +72,27 @@ kernel void FN_NAME_STRIDED( \
     output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
 } \
 
-// u32
 CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
 CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
 CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
-#if __METAL_VERSION__ >= 220
-CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
-#endif
-#if defined(__HAVE_BFLOAT__)
-CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
-#endif
-
-// u8
 CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
 CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
 CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
+CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
+CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
 
 #if __METAL_VERSION__ >= 220
 CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
 #endif
 #if defined(__HAVE_BFLOAT__)
 CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
 #endif
 
-// f16
-CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
-CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
-CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
-CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
-#if defined(__HAVE_BFLOAT__)
-CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
-#endif
-
-// i64
+CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
 CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
-CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
-CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
-CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
-#if defined(__HAVE_BFLOAT__)
-CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
-#endif
-
-// f32
-CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
-CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
-CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
-CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
-#if defined(__HAVE_BFLOAT__)
-CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
-#endif
-
-// bf16
 #if defined(__HAVE_BFLOAT__)
 CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
-CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
 CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
+CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
+CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
 
 CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
 CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
+CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
 #endif
@@ -292,7 +292,7 @@ fn binary_ops_bf16() {
     binary_op!(max, |x: bf16, y| x.max(y));
 }
 
-fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
+fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
     let device = device();
     let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
@@ -319,189 +319,107 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
 }
 
 #[test]
-fn cast_f32() {
-    let v_f64 = vec![1.0f64, 2.0, 3.0];
-    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
-    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
-    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
-    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
-    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
-    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+fn cast_u32_f32() {
+    let v = vec![1u32, 2, 3];
+    let results = cast(&v, "cast_u32_f32");
+    let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
+    assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
+    assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
 
-    // f32 -> f16
-    let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16");
-    assert_eq!(results, v_f16);
+    let v = vec![1.0f32, 2.0, 3.0];
+    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+    let results: Vec<f32> = cast(&input, "cast_f16_f32");
+    assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
 
-    // f32 -> bf16
-    let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
-    assert_eq!(results, v_bf16);
-
-    // f32 -> u32
-    let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
-    assert_eq!(results, v_u32);
-
-    // f32 -> u8
-    let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
-    assert_eq!(results, v_u8);
-
-    // f32 -> i64
-    let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
-    assert_eq!(results, v_i64);
+    let v = vec![1.0f32; 10_000];
+    let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+    let results: Vec<f32> = cast(&input, "cast_f16_f32");
+    assert_eq!(results.len(), 10_000);
+    assert_eq!(&results[..10], vec![1.0f32; 10]);
+    assert_eq!(results, vec![1.0f32; 10_000]);
 }
 
 #[test]
-fn cast_f16() {
-    let v_f64 = vec![1.0f64, 2.0, 3.0];
-    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
-    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
-    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
-    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
-    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
-    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+fn it_cast_bf16_u32() {
+    let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
 
-    // f16 -> f32
-    let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
-    assert_eq!(results, v_f32);
+    let output: Vec<u32> = cast(&input, "cast_bf16_u32");
+    let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
 
-    // f16 -> bf16
-    let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
-    assert_eq!(results, v_bf16);
-
-    // f16 -> u32
-    let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
-    assert_eq!(results, v_u32);
-
-    // f16 -> u8
-    let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
-    assert_eq!(results, v_u8);
-
-    // f16 -> i64
-    let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
-    assert_eq!(results, v_i64);
+    assert_eq!(output, expected);
 }
 
 #[test]
-fn cast_bf16() {
-    let v_f64 = vec![1.0f64, 2.0, 3.0];
-    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
-    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
-    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
-    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
-    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
-    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+fn it_cast_bf16_f32() {
+    let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
 
-    // bf16 -> f32
-    let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
-    assert_eq!(results, v_f32);
+    let output: Vec<f32> = cast(&input, "cast_bf16_f32");
+    let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
 
-    // bf16 -> f16
-    let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
-    assert_eq!(results, v_f16);
-
-    // bf16 -> u32
-    let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
-    assert_eq!(results, v_u32);
-
-    // bf16 -> u8
-    let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
-    assert_eq!(results, v_u8);
-
-    // bf16 -> i64
-    let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
-    assert_eq!(results, v_i64);
+    assert_eq!(output, expected);
 }
 
 #[test]
-fn cast_u32() {
-    let v_f64 = vec![1.0f64, 2.0, 3.0];
-    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
-    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
-    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
-    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
-    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
-    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+fn it_cast_u8_bf16() {
+    let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
 
-    // u32 -> f32
-    let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
-    assert_eq!(results, v_f32);
+    let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
+    let expected: Vec<bf16> = input
+        .iter()
+        .map(|v| bf16::from_f32(*v as f32))
+        .collect::<Vec<_>>();
 
-    // u32 -> f16
-    let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
-    assert_eq!(results, v_f16);
-
-    // u32 -> bf16
-    let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
-    assert_eq!(results, v_bf16);
-
-    // u32 -> u8
-    let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
-    assert_eq!(results, v_u8);
-
-    // u32 -> i64
-    let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
-    assert_eq!(results, v_i64);
+    assert_eq!(output, expected);
 }
 
 #[test]
-fn cast_u8() {
-    let v_f64 = vec![1.0f64, 2.0, 3.0];
-    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
-    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
-    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
-    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
-    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
-    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+fn it_cast_u32_bf16() {
+    let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
 
-    // u8 -> f32
-    let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
-    assert_eq!(results, v_f32);
+    let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
+    let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
 
-    // u8 -> f16
-    let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
-    assert_eq!(results, v_f16);
-
-    // u8 -> bf16
-    let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
-    assert_eq!(results, v_bf16);
-
-    // u8 -> u32
-    let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
-    assert_eq!(results, v_u32);
-
-    // u8 -> i64
-    let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
-    assert_eq!(results, v_i64);
+    assert_eq!(output, expected);
 }
 
 #[test]
-fn cast_i64() {
-    let v_f64 = vec![1.0f64, 2.0, 3.0];
-    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
-    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
-    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
-    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
-    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
-    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+fn it_cast_f32_bf16() {
+    let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
 
-    // i64 -> f32
-    let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
-    assert_eq!(results, v_f32);
+    let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
+    let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
 
-    // i64 -> f16
-    let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
-    assert_eq!(results, v_f16);
+    assert_eq!(output, expected);
+}
 
-    // i64 -> bf16
-    let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
-    assert_eq!(results, v_bf16);
+#[test]
+fn it_cast_bf16_u8() {
+    let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
 
-    // i64 -> u32
-    let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
-    assert_eq!(results, v_u32);
+    let output: Vec<u8> = cast(&input, "cast_bf16_u8");
+    let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
 
-    // i64 -> u8
-    let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
-    assert_eq!(results, v_u8);
+    assert_eq!(output, expected);
 }
 
+#[test]
+fn it_cast_bf16_f16() {
+    let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+    let output: Vec<f16> = cast(&input, "cast_bf16_f16");
+    let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
+
+    assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f16_bf16() {
+    let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
+
+    let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
+    let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
+
+    assert_eq!(output, expected);
+}
+
 fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
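Both helper spellings (run_cast on one branch, cast on the other) do the same job: copy the input slice into a Metal buffer, dispatch the named contiguous cast kernel through candle-metal-kernels, wait for the command buffer, and read the result back, so a call looks like:

let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");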
@@ -23,6 +23,7 @@ pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
 pub mod mobileone;
+pub mod moondream;
 pub mod mpt;
 pub mod persimmon;
 pub mod phi;
candle-transformers/src/models/moondream.rs (new file, 174 lines)
@@ -0,0 +1,174 @@
+#![allow(unused)]
+use crate::models::phi;
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear_b, Linear, VarBuilder};
+
+// https://github.com/vikhyat/moondream/blob/main/moondream/configuration_moondream.py
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+pub struct Config {
+    phi_config: phi::Config,
+    vision_config: VisionConfig,
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+pub struct VisionConfig {
+    image_embedding_dim: usize,
+    model_dim: usize,
+    hidden_dim: usize,
+    act: candle_nn::Activation,
+}
+
+impl VisionConfig {
+    pub fn v2() -> Self {
+        Self {
+            image_embedding_dim: 1152,
+            model_dim: 2048,
+            hidden_dim: 2048 * 4,
+            act: candle_nn::Activation::Silu,
+        }
+    }
+}
+
+impl Config {
+    pub fn v2() -> Self {
+        let phi_config = phi::Config {
+            vocab_size: 51200,
+            hidden_size: 2048,
+            intermediate_size: 8192,
+            num_hidden_layers: 24,
+            num_attention_heads: 32,
+            num_key_value_heads: None,
+            hidden_act: candle_nn::Activation::NewGelu,
+            max_position_embeddings: 2048,
+            tie_word_embeddings: false,
+            layer_norm_eps: 1e-5,
+            rope_theta: 10_000.,
+            partial_rotary_factor: 0.5,
+            qk_layernorm: false,
+        };
+        let vision_config = VisionConfig::v2();
+        Self {
+            phi_config,
+            vision_config,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct LinearPatchEmbedding {
+    linear: Linear,
+}
+
+#[derive(Debug, Clone)]
+struct Encoder {}
+
+impl Encoder {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        todo!()
+    }
+}
+
+impl Module for Encoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Mlp {
+    fc1: Linear,
+    act: candle_nn::Activation,
+    fc2: Linear,
+}
+
+impl Mlp {
+    fn new(
+        in_f: usize,
+        hidden_f: usize,
+        out_f: usize,
+        act: candle_nn::Activation,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let fc1 = linear_b(in_f, hidden_f, true, vb.pp("fc1"))?;
+        let fc2 = linear_b(hidden_f, out_f, true, vb.pp("fc2"))?;
+        Ok(Self { fc1, act, fc2 })
+    }
+}
+
+impl Module for Mlp {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VisionProjection {
+    mlp: Mlp,
+}
+
+impl VisionProjection {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let mlp = Mlp::new(
+            cfg.image_embedding_dim,
+            cfg.hidden_dim,
+            cfg.model_dim,
+            cfg.act,
+            vb.pp("mlp"),
+        )?;
+        Ok(Self { mlp })
+    }
+}
+
+impl Module for VisionProjection {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.mlp)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VisionEncoder {
+    encoder: Encoder,
+    projection: VisionProjection,
+}
+
+impl VisionEncoder {
+    pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let encoder = Encoder::new(cfg, vb.pp("vision.trunk"))?;
+        let projection = VisionProjection::new(cfg, vb.pp("projection"))?;
+        Ok(Self {
+            encoder,
+            projection,
+        })
+    }
+}
+
+impl Module for VisionEncoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        // Patchify: (b, c, h*p1, w*p2) -> (b, h*w, c*p1*p2) with 14x14 patches.
+        let (b, c, hp1, wp2) = xs.dims4()?;
+        let (p1, p2) = (14, 14);
+        let h = hp1 / p1;
+        let w = wp2 / p2;
+        let xs = xs
+            .reshape((b, c, h, p1, w, p2))?
+            .permute((0, 2, 4, 1, 3, 5))?
+            .reshape((b, h * w, c * p1 * p2))?;
+        xs.apply(&self.encoder)?.apply(&self.projection)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    text_model: phi::Model,
+    vision_encoder: VisionEncoder,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let text_model = phi::Model::new(&cfg.phi_config, vb.pp("text_model"))?;
+        let vision_encoder = VisionEncoder::new(&cfg.vision_config, vb.pp("vision_encoder"))?;
+        Ok(Self {
+            text_model,
+            vision_encoder,
+        })
+    }
+}
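Not part of the new file: a shape walk-through of the patchification in VisionEncoder::forward, assuming a 378x378 RGB input (a resolution moondream's SigLIP-style trunk is commonly run at; the numbers are illustrative):

// (b, c, hp1, wp2) = (1, 3, 378, 378), patch size p1 = p2 = 14
// h = 378 / 14 = 27, w = 27
// reshape -> (1, 3, 27, 14, 27, 14)
// permute -> (1, 27, 27, 3, 14, 14)
// reshape -> (1, 729, 588)   // 729 = 27 * 27 patches, 588 = 3 * 14 * 14 values per patch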
@@ -106,7 +106,7 @@ impl Module for MLP {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 struct Attention {
     q_proj: Linear,
     k_proj: Linear,
@@ -265,7 +265,7 @@ impl Attention {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 struct DecoderLayer {
     self_attn: Attention,
     mlp: MLP,
@@ -304,7 +304,7 @@ impl DecoderLayer {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct Model {
     embed_tokens: Embedding,
     layers: Vec<DecoderLayer>,
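These derives support the new model: moondream::Model derives Debug and embeds phi::Model, and a derived impl requires every field's type to implement the trait as well. A minimal illustration with hypothetical types:

#[derive(Clone, Debug)]
struct Inner;

#[derive(Clone, Debug)]
struct Outer {
    inner: Inner, // would not compile if Inner lacked Debug
}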
Reference in New Issue
Block a user