mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add the RWKV model (v5). (#1707)
* Start adding the RWKV model. * More of the forward step. * Handle rescaling. * FeedForward. * More work on RWKV. * Better state tracking. * Finish a first pass on forward. * Fix the shape mismatches. * Do not rescale in f32. * Rename to rwkv-v5. * Add the new models to the readme.
This commit is contained in:
@ -1,13 +1,12 @@
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct LlamaConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
|
Reference in New Issue
Block a user