diff --git a/nostrdb/NdbNote.swift b/nostrdb/NdbNote.swift index e06ba913..d489ecbe 100644 --- a/nostrdb/NdbNote.swift +++ b/nostrdb/NdbNote.swift @@ -41,7 +41,7 @@ enum NdbData { class NdbNote: Encodable, Equatable, Hashable { // we can have owned notes, but we can also have lmdb virtual-memory mapped notes so its optional - private let owned: Bool + let owned: Bool let count: Int let key: NoteKey? let note: UnsafeMutablePointer diff --git a/nostrdb/NdbTxn.swift b/nostrdb/NdbTxn.swift index 876fa80a..06d7fdce 100644 --- a/nostrdb/NdbTxn.swift +++ b/nostrdb/NdbTxn.swift @@ -16,22 +16,32 @@ class NdbTxn { var txn: ndb_txn private var val: T! var moved: Bool + var inherited: Bool init(ndb: Ndb, with: (NdbTxn) -> T = { _ in () }) { - self.txn = ndb_txn() #if TXNDEBUG txn_count += 1 print("opening transaction \(txn_count)") #endif - let _ = ndb_begin_query(ndb.ndb.ndb, &self.txn) + if let active_txn = Thread.current.threadDictionary["ndb_txn"] as? ndb_txn { + // some parent thread is active, use that instead + self.txn = active_txn + self.inherited = true + } else { + self.txn = ndb_txn() + let _ = ndb_begin_query(ndb.ndb.ndb, &self.txn) + Thread.current.threadDictionary["ndb_txn"] = self.txn + self.inherited = false + } self.moved = false self.val = with(self) } - init(txn: ndb_txn, val: T) { + private init(txn: ndb_txn, val: T) { self.txn = txn self.val = val self.moved = false + self.inherited = false } /// Only access temporarily! Do not store database references for longterm use. If it's a primitive type you @@ -42,13 +52,16 @@ class NdbTxn { } deinit { - if !moved { - #if TXNDEBUG - txn_count -= 1; - print("closing transaction \(txn_count)") - #endif - ndb_end_query(&self.txn) + if moved || inherited { + return } + + #if TXNDEBUG + txn_count -= 1; + print("closing transaction \(txn_count)") + #endif + ndb_end_query(&self.txn) + Thread.current.threadDictionary.removeObject(forKey: "ndb_txn") } // functor diff --git a/nostrdb/Test/NdbTests.swift b/nostrdb/Test/NdbTests.swift index 855b605b..c1d8bd5a 100644 --- a/nostrdb/Test/NdbTests.swift +++ b/nostrdb/Test/NdbTests.swift @@ -155,6 +155,25 @@ final class NdbTests: XCTestCase { let testNote = NdbNote.owned_from_json(json: testJSONWithEscapedSlashes)! XCTAssertEqual(testNote.content, "https://cdn.nostr.build/i/5c1d3296f66c2630131bf123106486aeaf051ed8466031c0e0532d70b33cddb2.jpg") } + + func test_inherited_transactions() throws { + let ndb = Ndb(path: db_dir)! + do { + let txn1 = NdbTxn(ndb: ndb) + + let ntxn = (Thread.current.threadDictionary.value(forKey: "ndb_txn") as? ndb_txn)! + XCTAssertEqual(txn1.txn.lmdb, ntxn.lmdb) + XCTAssertEqual(txn1.txn.mdb_txn, ntxn.mdb_txn) + + let txn2 = NdbTxn(ndb: ndb) + + XCTAssertEqual(txn1.inherited, false) + XCTAssertEqual(txn2.inherited, true) + } + + let ndb_txn = Thread.current.threadDictionary.value(forKey: "ndb_txn") + XCTAssertNil(ndb_txn) + } func test_decode_perf() throws { // This is an example of a performance test case.