use core::ops::Bound; use std::borrow::BorrowMut; use std::marker::PhantomPinned; use std::pin::Pin; use std::ptr::NonNull; use std::sync::{Arc, Mutex, MutexGuard}; use log::trace; use rusqlite::{params, Connection, Rows, Statement, Transaction}; use crate::{Db, Error, IDb, ITx, ITxFn, Result, TxError, TxFnResult, TxResult, Value, ValueIter}; pub use rusqlite; // --- err impl From for Error { fn from(e: rusqlite::Error) -> Error { Error(format!("Sqlite: {}", e).into()) } } impl From for TxError { fn from(e: rusqlite::Error) -> TxError { TxError::Db(e.into()) } } // -- db pub struct SqliteDb(Mutex); struct SqliteDbInner { db: Connection, trees: Vec, } impl SqliteDb { pub fn init(db: rusqlite::Connection) -> Db { let s = Self(Mutex::new(SqliteDbInner { db, trees: Vec::new(), })); Db(Arc::new(s)) } } impl SqliteDbInner { fn get_tree(&self, i: usize) -> Result<&'_ str> { self.trees .get(i) .map(String::as_str) .ok_or_else(|| Error("invalid tree id".into())) } fn internal_get(&self, tree: &str, key: &[u8]) -> Result> { let mut stmt = self .db .prepare(&format!("SELECT v FROM {} WHERE k = ?1", tree))?; let mut res_iter = stmt.query([key])?; match res_iter.next()? { None => Ok(None), Some(v) => Ok(Some(v.get::<_, Vec>(0)?)), } } } impl IDb for SqliteDb { fn open_tree(&self, name: &str) -> Result { let name = format!("tree_{}", name.replace(':', "_COLON_")); let mut this = self.0.lock().unwrap(); if let Some(i) = this.trees.iter().position(|x| x == &name) { Ok(i) } else { trace!("create table {}", name); this.db.execute( &format!( "CREATE TABLE IF NOT EXISTS {} ( k BLOB PRIMARY KEY, v BLOB )", name ), [], )?; trace!("table created: {}, unlocking", name); let i = this.trees.len(); this.trees.push(name.to_string()); Ok(i) } } fn list_trees(&self) -> Result> { let mut trees = vec![]; trace!("list_trees: lock db"); let this = self.0.lock().unwrap(); trace!("list_trees: lock acquired"); let mut stmt = this.db.prepare( "SELECT name FROM sqlite_schema WHERE type = 'table' AND name LIKE 'tree_%'", )?; let mut rows = stmt.query([])?; while let Some(row) = rows.next()? { let name = row.get::<_, String>(0)?; let name = name.replace("_COLON_", ":"); let name = name.strip_prefix("tree_").unwrap().to_string(); trees.push(name); } Ok(trees) } // ---- fn get(&self, tree: usize, key: &[u8]) -> Result> { trace!("get {}: lock db", tree); let this = self.0.lock().unwrap(); trace!("get {}: lock acquired", tree); let tree = this.get_tree(tree)?; this.internal_get(tree, key) } fn remove(&self, tree: usize, key: &[u8]) -> Result { trace!("remove {}: lock db", tree); let this = self.0.lock().unwrap(); trace!("remove {}: lock acquired", tree); let tree = this.get_tree(tree)?; let res = this .db .execute(&format!("DELETE FROM {} WHERE k = ?1", tree), params![key])?; Ok(res > 0) } fn len(&self, tree: usize) -> Result { trace!("len {}: lock db", tree); let this = self.0.lock().unwrap(); trace!("len {}: lock acquired", tree); let tree = this.get_tree(tree)?; let mut stmt = this.db.prepare(&format!("SELECT COUNT(*) FROM {}", tree))?; let mut res_iter = stmt.query([])?; match res_iter.next()? { None => Ok(0), Some(v) => Ok(v.get::<_, usize>(0)?), } } fn insert(&self, tree: usize, key: &[u8], value: &[u8]) -> Result { trace!("insert {}: lock db", tree); let this = self.0.lock().unwrap(); trace!("insert {}: lock acquired", tree); let tree = this.get_tree(tree)?; let old_val = this.internal_get(tree, key)?; this.db.execute( &format!("INSERT OR REPLACE INTO {} (k, v) VALUES (?1, ?2)", tree), params![key, value], )?; Ok(old_val.is_none()) } fn iter(&self, tree: usize) -> Result> { trace!("iter {}: lock db", tree); let this = self.0.lock().unwrap(); trace!("iter {}: lock acquired", tree); let tree = this.get_tree(tree)?; let sql = format!("SELECT k, v FROM {} ORDER BY k ASC", tree); DbValueIterator::make(this, &sql, []) } fn iter_rev(&self, tree: usize) -> Result> { trace!("iter_rev {}: lock db", tree); let this = self.0.lock().unwrap(); trace!("iter_rev {}: lock acquired", tree); let tree = this.get_tree(tree)?; let sql = format!("SELECT k, v FROM {} ORDER BY k DESC", tree); DbValueIterator::make(this, &sql, []) } fn range<'r>( &self, tree: usize, low: Bound<&'r [u8]>, high: Bound<&'r [u8]>, ) -> Result> { trace!("range {}: lock db", tree); let this = self.0.lock().unwrap(); trace!("range {}: lock acquired", tree); let tree = this.get_tree(tree)?; let (bounds_sql, params) = bounds_sql(low, high); let sql = format!("SELECT k, v FROM {} {} ORDER BY k ASC", tree, bounds_sql); let params = params .iter() .map(|x| x as &dyn rusqlite::ToSql) .collect::>(); DbValueIterator::make::<&[&dyn rusqlite::ToSql]>(this, &sql, params.as_ref()) } fn range_rev<'r>( &self, tree: usize, low: Bound<&'r [u8]>, high: Bound<&'r [u8]>, ) -> Result> { trace!("range_rev {}: lock db", tree); let this = self.0.lock().unwrap(); trace!("range_rev {}: lock acquired", tree); let tree = this.get_tree(tree)?; let (bounds_sql, params) = bounds_sql(low, high); let sql = format!("SELECT k, v FROM {} {} ORDER BY k DESC", tree, bounds_sql); let params = params .iter() .map(|x| x as &dyn rusqlite::ToSql) .collect::>(); DbValueIterator::make::<&[&dyn rusqlite::ToSql]>(this, &sql, params.as_ref()) } // ---- fn transaction(&self, f: &dyn ITxFn) -> TxResult<(), ()> { trace!("transaction: lock db"); let mut this = self.0.lock().unwrap(); trace!("transaction: lock acquired"); let this_mut_ref: &mut SqliteDbInner = this.borrow_mut(); let mut tx = SqliteTx { tx: this_mut_ref.db.transaction()?, trees: &this_mut_ref.trees, }; let res = match f.try_on(&mut tx) { TxFnResult::Ok => { tx.tx.commit()?; Ok(()) } TxFnResult::Abort => { tx.tx.rollback()?; Err(TxError::Abort(())) } TxFnResult::DbErr => { tx.tx.rollback()?; Err(TxError::Db(Error( "(this message will be discarded)".into(), ))) } }; trace!("transaction done"); res } } // ---- struct SqliteTx<'a> { tx: Transaction<'a>, trees: &'a [String], } impl<'a> SqliteTx<'a> { fn get_tree(&self, i: usize) -> Result<&'_ str> { self.trees.get(i).map(String::as_ref).ok_or_else(|| { Error( "invalid tree id (it might have been openned after the transaction started)".into(), ) }) } fn internal_get(&self, tree: &str, key: &[u8]) -> Result> { let mut stmt = self .tx .prepare(&format!("SELECT v FROM {} WHERE k = ?1", tree))?; let mut res_iter = stmt.query([key])?; match res_iter.next()? { None => Ok(None), Some(v) => Ok(Some(v.get::<_, Vec>(0)?)), } } } impl<'a> ITx for SqliteTx<'a> { fn get(&self, tree: usize, key: &[u8]) -> Result> { let tree = self.get_tree(tree)?; self.internal_get(tree, key) } fn len(&self, tree: usize) -> Result { let tree = self.get_tree(tree)?; let mut stmt = self.tx.prepare(&format!("SELECT COUNT(*) FROM {}", tree))?; let mut res_iter = stmt.query([])?; match res_iter.next()? { None => Ok(0), Some(v) => Ok(v.get::<_, usize>(0)?), } } fn insert(&mut self, tree: usize, key: &[u8], value: &[u8]) -> Result { let tree = self.get_tree(tree)?; let old_val = self.internal_get(tree, key)?; self.tx.execute( &format!("INSERT OR REPLACE INTO {} (k, v) VALUES (?1, ?2)", tree), params![key, value], )?; Ok(old_val.is_none()) } fn remove(&mut self, tree: usize, key: &[u8]) -> Result { let tree = self.get_tree(tree)?; let res = self .tx .execute(&format!("DELETE FROM {} WHERE k = ?1", tree), params![key])?; Ok(res > 0) } fn iter(&self, _tree: usize) -> Result> { unimplemented!(); } fn iter_rev(&self, _tree: usize) -> Result> { unimplemented!(); } fn range<'r>( &self, _tree: usize, _low: Bound<&'r [u8]>, _high: Bound<&'r [u8]>, ) -> Result> { unimplemented!(); } fn range_rev<'r>( &self, _tree: usize, _low: Bound<&'r [u8]>, _high: Bound<&'r [u8]>, ) -> Result> { unimplemented!(); } } // ---- struct DbValueIterator<'a> { db: MutexGuard<'a, SqliteDbInner>, stmt: Option>, iter: Option>, _pin: PhantomPinned, } impl<'a> DbValueIterator<'a> { fn make( db: MutexGuard<'a, SqliteDbInner>, sql: &str, args: P, ) -> Result> { let res = DbValueIterator { db, stmt: None, iter: None, _pin: PhantomPinned, }; let mut boxed = Box::pin(res); trace!("make iterator with sql: {}", sql); unsafe { let db = NonNull::from(&boxed.db); let stmt = db.as_ref().db.prepare(sql)?; let mut_ref: Pin<&mut DbValueIterator<'a>> = Pin::as_mut(&mut boxed); Pin::get_unchecked_mut(mut_ref).stmt = Some(stmt); let mut stmt = NonNull::from(&boxed.stmt); let iter = stmt.as_mut().as_mut().unwrap().query(args)?; let mut_ref: Pin<&mut DbValueIterator<'a>> = Pin::as_mut(&mut boxed); Pin::get_unchecked_mut(mut_ref).iter = Some(iter); } Ok(Box::new(DbValueIteratorPin(boxed))) } } impl<'a> Drop for DbValueIterator<'a> { fn drop(&mut self) { trace!("drop iter"); drop(self.iter.take()); drop(self.stmt.take()); } } struct DbValueIteratorPin<'a>(Pin>>); impl<'a> Iterator for DbValueIteratorPin<'a> { type Item = Result<(Value, Value)>; fn next(&mut self) -> Option { let next = unsafe { let mut_ref: Pin<&mut DbValueIterator<'a>> = Pin::as_mut(&mut self.0); Pin::get_unchecked_mut(mut_ref).iter.as_mut()?.next() }; let row = match next { Err(e) => return Some(Err(e.into())), Ok(None) => return None, Ok(Some(r)) => r, }; let k = match row.get::<_, Vec>(0) { Err(e) => return Some(Err(e.into())), Ok(x) => x, }; let v = match row.get::<_, Vec>(1) { Err(e) => return Some(Err(e.into())), Ok(y) => y, }; Some(Ok((k, v))) } } // ---- fn bounds_sql<'r>(low: Bound<&'r [u8]>, high: Bound<&'r [u8]>) -> (String, Vec>) { let mut sql = String::new(); let mut params: Vec> = vec![]; match low { Bound::Included(b) => { sql.push_str(" WHERE k >= ?1"); params.push(b.to_vec()); } Bound::Excluded(b) => { sql.push_str(" WHERE k > ?1"); params.push(b.to_vec()); } Bound::Unbounded => (), }; match high { Bound::Included(b) => { if !params.is_empty() { sql.push_str(" AND k <= ?2"); } else { sql.push_str(" WHERE k <= ?1"); } params.push(b.to_vec()); } Bound::Excluded(b) => { if !params.is_empty() { sql.push_str(" AND k < ?2"); } else { sql.push_str(" WHERE k < ?1"); } params.push(b.to_vec()); } Bound::Unbounded => (), } (sql, params) }