mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Custom tokenizer for rwkv. (#1711)
* Custom tokenizer for rwkv. * Custom tokenizer. * Getting the tokenizer to work.
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
fn default_num_attention_heads() -> usize {
|
||||
64
|
||||
@ -315,3 +316,94 @@ impl Model {
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
// Raw byte representation of a single token.
type Bytes = Vec<u8>;

// Byte-level greedy tokenizer for RWKV, ported from:
// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14
pub struct Tokenizer {
    // table[b0][b1] lists every multi-byte token whose first two bytes are
    // (b0, b1), in the order candidates are tried during greedy matching.
    table: Vec<Vec<Vec<Bytes>>>,
    // good[b0] holds each second byte b1 for which table[b0][b1] is non-empty,
    // acting as a fast pre-filter before the table lookup.
    good: Vec<HashSet<u8>>,
    // Token id -> raw token bytes (used for decoding).
    idx2token: HashMap<u32, Vec<u8>>,
    // Raw token bytes -> token id (used for encoding).
    token2idx: HashMap<Vec<u8>, u32>,
}
|
||||
|
||||
impl Tokenizer {
|
||||
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let token2idx: HashMap<String, u32> =
|
||||
serde_json::from_reader(file).map_err(candle::Error::wrap)?;
|
||||
let token2idx = token2idx
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key.into_bytes(), value))
|
||||
.collect::<HashMap<_, _>>();
|
||||
let idx2token = token2idx
|
||||
.iter()
|
||||
.map(|(key, value)| (*value, key.to_vec()))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let max_idx = token2idx.values().copied().max().unwrap_or(0);
|
||||
|
||||
let mut table = vec![vec![vec![]; 256]; 256];
|
||||
let mut good = vec![HashSet::new(); 256];
|
||||
for idx in (0..(1 + max_idx)).rev() {
|
||||
let s = match idx2token.get(&idx) {
|
||||
None => continue,
|
||||
Some(s) => s,
|
||||
};
|
||||
if s.len() >= 2 {
|
||||
let (s0, s1) = (s[0], s[1]);
|
||||
table[s0 as usize][s1 as usize].push(s.to_vec());
|
||||
good[s0 as usize].insert(s1);
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
table,
|
||||
good,
|
||||
idx2token,
|
||||
token2idx,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
|
||||
let mut v = Vec::new();
|
||||
for token_id in tokens.iter() {
|
||||
if let Some(token) = self.idx2token.get(token_id) {
|
||||
v.extend_from_slice(token.as_slice())
|
||||
}
|
||||
}
|
||||
v
|
||||
}
|
||||
|
||||
pub fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||
let bytes = self.decode_bytes(tokens);
|
||||
String::from_utf8(bytes).map_err(candle::Error::wrap)
|
||||
}
|
||||
|
||||
pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
|
||||
let mut tokens = Vec::new();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let mut s = vec![bytes[i]];
|
||||
if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
|
||||
let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
|
||||
for table_elem in table.iter() {
|
||||
if bytes[i..].starts_with(table_elem) {
|
||||
s = table_elem.to_vec();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
i += s.len();
|
||||
let token = match self.token2idx.get(&s) {
|
||||
None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
|
||||
Some(token) => *token,
|
||||
};
|
||||
tokens.push(token)
|
||||
}
|
||||
Ok(tokens)
|
||||
}
|
||||
|
||||
pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
|
||||
self.encode_bytes(str.as_bytes())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user