1
0
mirror of git://jb55.com/damus synced 2024-10-04 19:00:42 +00:00

nostrdb/Add fulltext search index

Signed-off-by: William Casarin <jb55@jb55.com>
This commit is contained in:
William Casarin 2023-11-27 16:08:42 -08:00
parent 53fc1b6945
commit d541153e4c
2 changed files with 555 additions and 19 deletions

View File

@ -44,6 +44,8 @@ static const int DEFAULT_QUEUE_SIZE = 1000000;
#define NDB_PARSED_ALL (NDB_PARSED_ID|NDB_PARSED_PUBKEY|NDB_PARSED_SIG|NDB_PARSED_CREATED_AT|NDB_PARSED_KIND|NDB_PARSED_CONTENT|NDB_PARSED_TAGS)
typedef int (*ndb_migrate_fn)(struct ndb *);
typedef int (*ndb_word_parser_fn)(void *, const char *word, int word_len,
int word_index);
struct ndb_migration {
ndb_migrate_fn fn;
@ -133,6 +135,156 @@ struct ndb_u64_tsid {
uint64_t timestamp;
};
// uncompressed form of the actual lmdb key
struct ndb_text_search_key
{
int str_len;
const char *str;
int word_index;
uint64_t timestamp;
};
// ndb_text_search_key
//
// This is compressed when in lmdb:
//
// strlen: varint
// str: cstr
// timestamp: varint
// word_index: varint
static int ndb_make_text_search_key(unsigned char *buf, int bufsize,
int word_index, int word_len, const char *str,
uint64_t timestamp, int *keysize)
{
struct cursor cur;
int size, pad;
make_cursor(buf, buf + bufsize, &cur);
// string length
if (!push_varint(&cur, word_len))
return 0;
// non-null terminated string
if (!cursor_push(&cur, (unsigned char*)str, word_len))
return 0;
// the index of the word in the content so that we can do more accurate
// phrase searches
if (!push_varint(&cur, word_index))
return 0;
// TODO: need update this to uint64_t
if (!push_varint(&cur, (int)timestamp))
return 0;
size = cur.p - cur.start;
// pad to 8-byte alignment
pad = ((size + 7) & ~7) - size;
if (pad > 0) {
if (!cursor_memset(&cur, 0, pad)) {
return 0;
}
}
*keysize = cur.p - cur.start;
assert((*keysize % 8) == 0);
return 1;
}
static int ndb_make_text_search_key_low(unsigned char *buf, int bufsize,
int wordlen, const char *word,
int *keysize)
{
return ndb_make_text_search_key(buf, bufsize, 0, wordlen, word, 0, keysize);
}
/** From LMDB: Compare two items lexically */
static int mdb_cmp_memn(const MDB_val *a, const MDB_val *b) {
int diff;
ssize_t len_diff;
unsigned int len;
len = a->mv_size;
len_diff = (ssize_t) a->mv_size - (ssize_t) b->mv_size;
if (len_diff > 0) {
len = b->mv_size;
len_diff = 1;
}
diff = memcmp(a->mv_data, b->mv_data, len);
return diff ? diff : len_diff<0 ? -1 : len_diff;
}
static int ndb_text_search_key_compare(const MDB_val *a, const MDB_val *b)
{
struct cursor ca, cb;
int sa, sb;
MDB_val a2, b2;
make_cursor(a->mv_data, a->mv_data + a->mv_size, &ca);
make_cursor(b->mv_data, b->mv_data + b->mv_size, &cb);
// string size
if (unlikely(!pull_varint(&ca, &sa) || !pull_varint(&cb, &sb)))
return 0;
a2.mv_data = ca.p;
a2.mv_size = sa;
b2.mv_data = cb.p;
b2.mv_size = sb;
int cmp = mdb_cmp_memn(&a2, &b2);
if (cmp) return cmp;
// skip over string
ca.p += sa;
cb.p += sb;
// timestamp
if (unlikely(!pull_varint(&ca, &sa) || !pull_varint(&cb, &sb)))
return 0;
if (sa < sb) return -1;
else if (sa > sb) return 1;
// word index
if (unlikely(!pull_varint(&ca, &sa) || !pull_varint(&cb, &sb)))
return 0;
if (sa < sb) return -1;
else if (sa > sb) return 1;
return 0;
}
/*
static int ndb_decompress_text_search_key(unsigned char *p, int len,
struct ndb_text_search_key *key)
{
struct cursor c;
make_cursor(p, p + len, &c);
if (!pull_varint(&c, &key->str_len))
return 0;
key->str = cur->p;
if (!cursor_skip(&c, key->str_len))
return 0;
if (!pull_varint(&c, &key->word_index))
return 0;
if (!pull_varint(&c, &key->timestamp))
return 0;
}
*/
// Copies only lowercase characters to the destination string and fills the rest with null bytes.
// `dst` and `src` are pointers to the destination and source strings, respectively.
// `n` is the maximum number of characters to copy.
@ -742,23 +894,6 @@ int ndb_db_version(struct ndb *ndb)
return version;
}
/** From LMDB: Compare two items lexically */
static int mdb_cmp_memn(const MDB_val *a, const MDB_val *b) {
int diff;
ssize_t len_diff;
unsigned int len;
len = a->mv_size;
len_diff = (ssize_t) a->mv_size - (ssize_t) b->mv_size;
if (len_diff > 0) {
len = b->mv_size;
len_diff = 1;
}
diff = memcmp(a->mv_data, b->mv_data, len);
return diff ? diff : len_diff<0 ? -1 : len_diff;
}
// custom kind+timestamp comparison function. This is used by lmdb to perform
// b+ tree searches over the kind+timestamp index
static int ndb_u64_tsid_compare(const MDB_val *a, const MDB_val *b)
@ -814,10 +949,10 @@ static inline void ndb_tsid_init(struct ndb_tsid *key, unsigned char *id,
key->timestamp = timestamp;
}
static inline void ndb_u64_tsid_init(struct ndb_tsid *key, uint64_t integer,
static inline void ndb_u64_tsid_init(struct ndb_u64_tsid *key, uint64_t integer,
uint64_t timestamp)
{
key->integer = integer;
key->u64 = integer;
key->timestamp = timestamp;
}
@ -1877,6 +2012,388 @@ static int ndb_write_note_kind_index(struct ndb_txn *txn, struct ndb_note *note,
return 1;
}
/**
* Checks if a given Unicode code point is a punctuation character
*
* @param codepoint The Unicode code point to check. @return true if the
* code point is a punctuation character, false otherwise.
*/
static inline int is_punctuation(unsigned int codepoint) {
// Check for underscore (underscore is not treated as punctuation)
if (codepoint == '_')
return 0;
// Check for ASCII punctuation
if (ispunct(codepoint))
return 1;
// Check for Unicode punctuation exceptions (punctuation allowed in hashtags)
if (codepoint == 0x301C || codepoint == 0xFF5E) // Japanese Wave Dash / Tilde
return 0;
// Check for Unicode punctuation
// NOTE: We may need to adjust the codepoint ranges in the future,
// to include/exclude certain types of Unicode characters in hashtags.
// Unicode Blocks Reference: https://www.compart.com/en/unicode/block
return (
// Latin-1 Supplement No-Break Space (NBSP): U+00A0
(codepoint == 0x00A0) ||
// Latin-1 Supplement Punctuation: U+00A1 to U+00BF
(codepoint >= 0x00A1 && codepoint <= 0x00BF) ||
// General Punctuation: U+2000 to U+206F
(codepoint >= 0x2000 && codepoint <= 0x206F) ||
// Currency Symbols: U+20A0 to U+20CF
(codepoint >= 0x20A0 && codepoint <= 0x20CF) ||
// Supplemental Punctuation: U+2E00 to U+2E7F
(codepoint >= 0x2E00 && codepoint <= 0x2E7F) ||
// CJK Symbols and Punctuation: U+3000 to U+303F
(codepoint >= 0x3000 && codepoint <= 0x303F) ||
// Ideographic Description Characters: U+2FF0 to U+2FFF
(codepoint >= 0x2FF0 && codepoint <= 0x2FFF)
);
}
static inline int is_whitespace(char c) {
return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r';
}
static inline int is_right_boundary(int c) {
return is_whitespace(c) || is_punctuation(c);
}
static inline int parse_byte(struct cursor *cursor, unsigned char *c)
{
if (unlikely(cursor->p >= cursor->end))
return 0;
*c = *cursor->p;
return 1;
}
static inline int peek_char(struct cursor *cur, int ind) {
if ((cur->p + ind < cur->start) || (cur->p + ind >= cur->end))
return -1;
return *(cur->p + ind);
}
static int parse_utf8_char(struct cursor *cursor, unsigned int *code_point,
unsigned int *utf8_length)
{
unsigned char first_byte;
if (!parse_byte(cursor, &first_byte))
return 0; // Not enough data
// Determine the number of bytes in this UTF-8 character
int remaining_bytes = 0;
if (first_byte < 0x80) {
*code_point = first_byte;
return 1;
} else if ((first_byte & 0xE0) == 0xC0) {
remaining_bytes = 1;
*utf8_length = remaining_bytes + 1;
*code_point = first_byte & 0x1F;
} else if ((first_byte & 0xF0) == 0xE0) {
remaining_bytes = 2;
*utf8_length = remaining_bytes + 1;
*code_point = first_byte & 0x0F;
} else if ((first_byte & 0xF8) == 0xF0) {
remaining_bytes = 3;
*utf8_length = remaining_bytes + 1;
*code_point = first_byte & 0x07;
} else {
remaining_bytes = 0;
*utf8_length = 1; // Assume 1 byte length for unrecognized UTF-8 characters
// TODO: We need to gracefully handle unrecognized UTF-8 characters
//printf("Invalid UTF-8 byte: %x\n", *code_point);
*code_point = ((first_byte & 0xF0) << 6); // Prevent testing as punctuation
return 0; // Invalid first byte
}
// Peek at remaining bytes
for (int i = 0; i < remaining_bytes; ++i) {
signed char next_byte;
if ((next_byte = peek_char(cursor, i+1)) == -1) {
*utf8_length = 1;
return 0; // Not enough data
}
if ((next_byte & 0xC0) != 0x80) {
*utf8_length = 1;
return 0; // Invalid byte in sequence
}
*code_point = (*code_point << 6) | (next_byte & 0x3F);
}
return 1;
}
static inline int is_utf8_byte(unsigned char c) {
return c & 0x80;
}
static inline int consume_until_boundary(struct cursor *cur) {
unsigned int c;
unsigned int char_length = 1;
unsigned int *utf8_char_length = &char_length;
while (cur->p < cur->end) {
c = *cur->p;
*utf8_char_length = 1;
if (is_whitespace(c))
return 1;
// Need to check for UTF-8 characters, which can be multiple
// bytes long
if (is_utf8_byte(c)) {
if (!parse_utf8_char(cur, &c, utf8_char_length)) {
if (!is_right_boundary(c)){
// TODO: We should work towards
// handling all UTF-8 characters.
//printf("Invalid UTF-8 code point: %x\n", c);
}
}
}
if (is_right_boundary(c))
return 1;
// Need to use a variable character byte length for UTF-8 (2-4 bytes)
if (cur->p + *utf8_char_length <= cur->end)
cur->p += *utf8_char_length;
else
cur->p++;
}
return 1;
}
static void consume_whitespace_or_punctuation(struct cursor *cur)
{
while (cur->p < cur->end) {
if (!is_right_boundary(*cur->p))
return;
cur->p++;
}
}
static int ndb_write_word_to_index(struct ndb_txn *txn, const char *word,
int word_len, int word_index,
uint64_t timestamp, uint64_t note_id)
{
// cap to some reasonable key size
unsigned char buffer[1024];
int keysize, rc;
MDB_val k, v;
MDB_dbi text_db;
// build our compressed text index key
if (!ndb_make_text_search_key(buffer, sizeof(buffer), word_index,
word_len, word, timestamp, &keysize)) {
// probably too big
return 0;
}
k.mv_data = buffer;
k.mv_size = keysize;
v.mv_data = &note_id;
v.mv_size = sizeof(note_id);
text_db = txn->lmdb->dbs[NDB_DB_NOTE_TEXT];
if ((rc = mdb_put(txn->mdb_txn, text_db, &k, &v, 0))) {
ndb_debug("write note text index to db failed: %s\n",
mdb_strerror(rc));
return 0;
}
return 1;
}
static int ndb_parse_words(struct cursor *cur, void *ctx, ndb_word_parser_fn fn)
{
int word_len, words;
const char *word;
words = 0;
while (cur->p < cur->end) {
consume_whitespace_or_punctuation(cur);
if (cur->p >= cur->end)
break;
word = (const char *)cur->p;
if (!consume_until_boundary(cur))
break;
// start of word or end
word_len = cur->p - (unsigned char *)word;
if (word_len == 0 && cur->p >= cur->end)
break;
if (!fn(ctx, word, word_len, words))
continue;
words++;
}
return 1;
}
struct ndb_word_writer_ctx
{
struct ndb_txn *txn;
struct ndb_note *note;
uint64_t note_id;
};
static int ndb_fulltext_word_writer(void *ctx,
const char *word, int word_len, int words)
{
struct ndb_word_writer_ctx *wctx = ctx;
if (!ndb_write_word_to_index(wctx->txn, word, word_len, words,
wctx->note->created_at, wctx->note_id)) {
// too big to write this one, just skip it
ndb_debug(stderr, "failed to write word '%.*s' to index\n", word_len, word);
return 0;
}
//fprintf(stderr, "wrote '%.*s' to note text index\n", word_len, word);
return 1;
}
static int ndb_write_note_fulltext_index(struct ndb_txn *txn,
struct ndb_note *note,
uint64_t note_id)
{
struct cursor cur;
unsigned char *content;
struct ndb_str str;
struct ndb_word_writer_ctx ctx;
str = ndb_note_str(note, &note->content);
// I don't think this should happen?
if (unlikely(str.flag == NDB_PACKED_ID))
return 0;
content = (unsigned char *)str.str;
make_cursor(content, content + note->content_length, &cur);
ctx.txn = txn;
ctx.note = note;
ctx.note_id = note_id;
ndb_parse_words(&cur, &ctx, ndb_fulltext_word_writer);
return 1;
}
struct ndb_word
{
const char *word;
int word_len;
};
#define MAX_SEARCH_WORDS 16
struct ndb_search_words
{
struct ndb_word words[MAX_SEARCH_WORDS];
int num_words;
};
static int ndb_parse_search_words(void *ctx, const char *word_str, int word_len, int word_index)
{
struct ndb_search_words *words = ctx;
struct ndb_word *word;
if (words->num_words + 1 > MAX_SEARCH_WORDS)
return 0;
word = &words->words[words->num_words++];
word->word = word_str;
word->word_len = word_len;
return 1;
}
int ndb_text_search(struct ndb_txn *txn, const char *query)
{
unsigned char buffer[1024];
struct ndb_search_words words;
struct ndb_word *word;
struct cursor cur;
MDB_dbi text_db;
MDB_cursor *cursor;
MDB_val k, v;
int i, rc, keysize;
size_t len;
//uint64_t note_ids[32], note_id;
uint64_t note_id;
struct ndb_note *note;
//int num_note_ids;
//num_note_ids = 0;
text_db = txn->lmdb->dbs[NDB_DB_NOTE_TEXT];
make_cursor((unsigned char *)query, (unsigned char *)query + strlen(query), &cur);
words.num_words = 0;
ndb_parse_words(&cur, &words, ndb_parse_search_words);
if ((rc = mdb_cursor_open(txn->mdb_txn, text_db, &cursor))) {
fprintf(stderr, "nd_text_search: mdb_cursor_open failed, error %d\n", rc);
return 0;
}
for (i = 0; i < words.num_words; i++) {
word = &words.words[i];
fprintf(stderr, "search word %.*s\n", word->word_len, word->word);
if (!ndb_make_text_search_key_low(buffer, sizeof(buffer),
word->word_len, word->word,
&keysize)) {
// word is too big to fit in 1024-sized key
continue;
}
k.mv_data = buffer;
k.mv_size = keysize;
// Position cursor at the next key greater than or equal to the specified key
if (mdb_cursor_get(cursor, &k, &v, MDB_SET_RANGE)) {
continue;
} else {
//note_ids[num_note_ids++] = *((uint64_t*)v.mv_data);
note_id = *((uint64_t*)v.mv_data);
if ((note = ndb_get_note_by_key(txn, note_id, &len))) {
fprintf(stderr, "found note: '%s' for query word '%.*s'\n",
ndb_note_str(note, &note->content).str,
word->word_len, word->word);
}
return 1;
}
}
return 1;
}
static uint64_t ndb_write_note(struct ndb_txn *txn,
struct ndb_writer_note *note)
{
@ -1910,6 +2427,12 @@ static uint64_t ndb_write_note(struct ndb_txn *txn,
if (!ndb_write_note_kind_index(txn, note->note, note_key))
return 0;
// only do fulltext index on kind1 notes
if (note->note->kind == 1) {
if (!ndb_write_note_fulltext_index(txn, note->note, note_key))
return 0;
}
if (note->note->kind == 7) {
ndb_write_reaction_stats(txn, note->note);
}
@ -2282,6 +2805,12 @@ static int ndb_init_lmdb(const char *filename, struct ndb_lmdb *lmdb, size_t map
}
mdb_set_compare(txn, lmdb->dbs[NDB_DB_NOTE_KIND], ndb_u64_tsid_compare);
if ((rc = mdb_dbi_open(txn, "note_text", tsid_flags, &lmdb->dbs[NDB_DB_NOTE_TEXT]))) {
fprintf(stderr, "mdb_dbi_open id failed: %s\n", mdb_strerror(rc));
return 0;
}
mdb_set_compare(txn, lmdb->dbs[NDB_DB_NOTE_TEXT], ndb_text_search_key_compare);
// Commit the transaction
if ((rc = mdb_txn_commit(txn))) {
fprintf(stderr, "mdb_txn_commit failed, error %d\n", rc);

View File

@ -42,6 +42,7 @@ enum ndb_dbs {
NDB_DB_PROFILE_SEARCH,
NDB_DB_PROFILE_LAST_FETCH,
NDB_DB_NOTE_KIND, // note kind index
NDB_DB_NOTE_TEXT, // note fulltext index
NDB_DBS,
};
@ -327,6 +328,10 @@ void ndb_filter_reset(struct ndb_filter *);
void ndb_filter_end_field(struct ndb_filter *);
void ndb_filter_free(struct ndb_filter *filter);
// FULLTEXT SEARCH
int ndb_text_search(struct ndb_txn *, const char *query);
// stats
int ndb_stat(struct ndb *ndb, struct ndb_stat *stat);
void ndb_stat_counts_init(struct ndb_stat_counts *counts);
@ -528,6 +533,8 @@ ndb_db_name(enum ndb_dbs db)
return "profile_last_fetch";
case NDB_DB_NOTE_KIND:
return "note_kind_index";
case NDB_DB_NOTE_TEXT:
return "note_fulltext";
case NDB_DBS:
return "count";
}