Add the RWKV model (v5). (#1707)

* Start adding the RWKV model.

* More of the forward step.

* Handle rescaling.

* FeedForward.

* More work on RWKV.

* Better state tracking.

* Finish a first pass on forward.

* Fix the shape mismatches.

* Do not rescale in f32.

* Rename to rwkv-v5.

* Add the new models to the readme.
This commit is contained in:
Laurent Mazare
2024-02-14 10:58:32 +01:00
committed by GitHub
parent 68f7655895
commit 2d5f2a728d
5 changed files with 616 additions and 3 deletions

View File

@ -1,13 +1,12 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096;
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
pub hidden_size: usize,
pub intermediate_size: usize,