refactor: reader and rename run

This commit is contained in:
Elijah McMorris 2024-08-28 14:25:17 -07:00
parent bb9967ced6
commit acef3f8f3b
Signed by: NexVeridian
SSH key fingerprint: SHA256:bsA1SKZxuEcEVHAy3gY1HUeM5ykRJl0U0kQHQn0hMg8
7 changed files with 72 additions and 67 deletions

16
Cargo.lock generated
View file

@ -683,9 +683,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]] [[package]]
name = "cpp_demangle" name = "cpp_demangle"
version = "0.4.3" version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e8227005286ec39567949b33df9896bcadfa6051bccca2488129f108ca23119" checksum = "96e58d342ad113c2b878f16d5d034c03be492ae460cdbc02b7f0f2284d310c7d"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
] ]
@ -2985,9 +2985,9 @@ dependencies = [
[[package]] [[package]]
name = "rustc_version" name = "rustc_version"
version = "0.4.0" version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [ dependencies = [
"semver", "semver",
] ]
@ -3545,9 +3545,9 @@ dependencies = [
[[package]] [[package]]
name = "symbolic-common" name = "symbolic-common"
version = "12.10.0" version = "12.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16629323a4ec5268ad23a575110a724ad4544aae623451de600c747bf87b36cf" checksum = "b1944ea8afd197111bca0c0edea1e1f56abb3edd030e240c1035cc0e3ff51fec"
dependencies = [ dependencies = [
"debugid", "debugid",
"memmap2", "memmap2",
@ -3557,9 +3557,9 @@ dependencies = [
[[package]] [[package]]
name = "symbolic-demangle" name = "symbolic-demangle"
version = "12.10.0" version = "12.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c043a45f08f41187414592b3ceb53fb0687da57209cc77401767fb69d5b596" checksum = "ddaccaf1bf8e73c4f64f78dbb30aadd6965c71faa4ff3fba33f8d7296cf94a87"
dependencies = [ dependencies = [
"cpp_demangle", "cpp_demangle",
"rustc-demangle", "rustc-demangle",

View file

@ -5,6 +5,7 @@ use std::{env, time::Duration};
use surrealdb::{engine::local::Db, Surreal}; use surrealdb::{engine::local::Db, Surreal};
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use init_reader::File_Format;
use wikidata_to_surrealdb::utils::*; use wikidata_to_surrealdb::utils::*;
async fn inti_db() -> Result<Surreal<Db>, Error> { async fn inti_db() -> Result<Surreal<Db>, Error> {
@ -29,7 +30,7 @@ fn bench(c: &mut Criterion) {
.unwrap(); .unwrap();
CreateVersion::Single CreateVersion::Single
.run_threaded(Some(db.clone()), reader, None, 1000, 100) .run(Some(db.clone()), reader, None, 1000, 100)
.await .await
.unwrap(); .unwrap();
}) })
@ -46,7 +47,7 @@ fn bench(c: &mut Criterion) {
.unwrap(); .unwrap();
CreateVersion::Bulk CreateVersion::Bulk
.run_threaded(Some(db.clone()), reader, None, 1000, 100) .run(Some(db.clone()), reader, None, 1000, 100)
.await .await
.unwrap(); .unwrap();
}) })

6
flake.lock generated
View file

@ -44,11 +44,11 @@
"rust-analyzer-src": [] "rust-analyzer-src": []
}, },
"locked": { "locked": {
"lastModified": 1724740262, "lastModified": 1724826636,
"narHash": "sha256-cpFasbzOTlwLi4fNas6hDznVUdCJn/lMLxi7MAMG6hg=", "narHash": "sha256-hz8Szf5J9oQg6EeMhHE/eKuexoHPiDbmOZTPvijYwyM=",
"owner": "nix-community", "owner": "nix-community",
"repo": "fenix", "repo": "fenix",
"rev": "703efdd9b5c6a7d5824afa348a24fbbf8ff226be", "rev": "3454a665ff4dd29cf618e6a2e53065370876297f",
"type": "github" "type": "github"
}, },
"original": { "original": {

View file

@ -5,6 +5,7 @@ use surrealdb::{engine::remote::ws::Client, Surreal};
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
mod utils; mod utils;
use init_reader::File_Format;
use utils::*; use utils::*;
lazy_static! { lazy_static! {
@ -30,7 +31,7 @@ async fn main() -> Result<(), Error> {
let reader = File_Format::new(&WIKIDATA_FILE_FORMAT).reader(&WIKIDATA_FILE_NAME)?; let reader = File_Format::new(&WIKIDATA_FILE_FORMAT).reader(&WIKIDATA_FILE_NAME)?;
CREATE_VERSION CREATE_VERSION
.run_threaded( .run(
None::<Surreal<Client>>, None::<Surreal<Client>>,
reader, reader,
Some(pb.clone()), Some(pb.clone()),

View file

@ -1,22 +1,17 @@
use anyhow::{Error, Result}; use anyhow::{Error, Result};
use bzip2::read::MultiBzDecoder;
use core::panic;
use futures::future::join_all; use futures::future::join_all;
use indicatif::ProgressBar; use indicatif::ProgressBar;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rand::{distributions::Alphanumeric, Rng}; use rand::{distributions::Alphanumeric, Rng};
use serde_json::{from_str, Value}; use serde_json::{from_str, Value};
use std::{ use std::{env, io::BufRead};
env,
fs::File,
io::{BufRead, BufReader},
};
use surrealdb::{Connection, Surreal}; use surrealdb::{Connection, Surreal};
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
use wikidata::Entity; use wikidata::Entity;
pub mod init_db; pub mod init_db;
pub mod init_progress_bar; pub mod init_progress_bar;
pub mod init_reader;
mod tables; mod tables;
use tables::*; use tables::*;
@ -29,28 +24,6 @@ lazy_static! {
env::var("FILTER_PATH").unwrap_or("../filter.surql".to_string()); env::var("FILTER_PATH").unwrap_or("../filter.surql".to_string());
} }
#[allow(non_camel_case_types)]
pub enum File_Format {
json,
bz2,
}
impl File_Format {
pub fn new(file: &str) -> Self {
match file {
"json" => Self::json,
"bz2" => Self::bz2,
_ => panic!("Unknown file format"),
}
}
pub fn reader(self, file: &str) -> Result<Box<dyn BufRead>, Error> {
let file = File::open(file)?;
match self {
File_Format::json => Ok(Box::new(BufReader::new(file))),
File_Format::bz2 => Ok(Box::new(BufReader::new(MultiBzDecoder::new(file)))),
}
}
}
pub async fn create_entity(db: &Surreal<impl Connection>, line: &str) -> Result<(), Error> { pub async fn create_entity(db: &Surreal<impl Connection>, line: &str) -> Result<(), Error> {
let line = line.trim().trim_end_matches(',').to_string(); let line = line.trim().trim_end_matches(',').to_string();
if line == "[" || line == "]" { if line == "[" || line == "]" {
@ -84,28 +57,12 @@ pub enum CreateVersion {
/// must create a filter.surql file in the root directory /// must create a filter.surql file in the root directory
BulkFilter, BulkFilter,
} }
impl CreateVersion { impl CreateVersion {
pub async fn run( pub async fn run(
self, self,
db: &Surreal<impl Connection>, dbo: Option<Surreal<impl Connection>>, // None::<Surreal<Client>>
chunk: &[String], reader: Box<dyn BufRead>,
pb: &Option<ProgressBar>,
batch_size: usize,
) -> bool {
match self {
CreateVersion::Single => self.create_single(db, chunk, pb).await.is_ok(),
CreateVersion::Bulk => self.create_bulk(db, chunk, pb, batch_size).await.is_ok(),
CreateVersion::BulkFilter => self
.create_bulk_filter(db, chunk, pb, batch_size)
.await
.is_ok(),
}
}
pub async fn run_threaded(
self,
dbo: Option<Surreal<impl Connection>>,
reader: Box<dyn BufRead>, // None::<Surreal<Client>>
pb: Option<ProgressBar>, pb: Option<ProgressBar>,
batch_size: usize, batch_size: usize,
batch_num: usize, batch_num: usize,
@ -146,7 +103,7 @@ impl CreateVersion {
loop { loop {
match dbo { match dbo {
Some(ref db) => { Some(ref db) => {
if create_version.run(db, &chunk, &pb, batch_size).await { if create_version.create(db, &chunk, &pb, batch_size).await {
break; break;
} }
} }
@ -155,7 +112,7 @@ impl CreateVersion {
Ok(db) => db, Ok(db) => db,
Err(_) => continue, Err(_) => continue,
}; };
if create_version.run(&db, &chunk, &pb, batch_size).await { if create_version.create(&db, &chunk, &pb, batch_size).await {
break; break;
} }
} }
@ -170,6 +127,23 @@ impl CreateVersion {
}) })
} }
async fn create(
self,
db: &Surreal<impl Connection>,
chunk: &[String],
pb: &Option<ProgressBar>,
batch_size: usize,
) -> bool {
match self {
CreateVersion::Single => self.create_single(db, chunk, pb).await.is_ok(),
CreateVersion::Bulk => self.create_bulk(db, chunk, pb, batch_size).await.is_ok(),
CreateVersion::BulkFilter => self
.create_bulk_filter(db, chunk, pb, batch_size)
.await
.is_ok(),
}
}
async fn create_single( async fn create_single(
self, self,
db: &Surreal<impl Connection>, db: &Surreal<impl Connection>,

28
src/utils/init_reader.rs Normal file
View file

@ -0,0 +1,28 @@
use anyhow::{Error, Result};
use bzip2::read::MultiBzDecoder;
use std::{
fs::File,
io::{BufRead, BufReader},
};
#[allow(non_camel_case_types)]
pub enum File_Format {
json,
bz2,
}
impl File_Format {
pub fn new(file: &str) -> Self {
match file {
"json" => Self::json,
"bz2" => Self::bz2,
_ => panic!("Unknown file format"),
}
}
pub fn reader(self, file: &str) -> Result<Box<dyn BufRead>, Error> {
let file = File::open(file)?;
match self {
File_Format::json => Ok(Box::new(BufReader::new(file))),
File_Format::bz2 => Ok(Box::new(BufReader::new(MultiBzDecoder::new(file)))),
}
}
}

View file

@ -3,6 +3,7 @@ use rstest::rstest;
use std::{env, io::BufRead}; use std::{env, io::BufRead};
use surrealdb::{engine::local::Db, Surreal}; use surrealdb::{engine::local::Db, Surreal};
use init_reader::File_Format;
use wikidata_to_surrealdb::utils::*; use wikidata_to_surrealdb::utils::*;
async fn inti_db() -> Result<Surreal<Db>, Error> { async fn inti_db() -> Result<Surreal<Db>, Error> {
@ -57,7 +58,7 @@ async fn entity_threaded(#[case] version: CreateVersion) -> Result<(), Error> {
let reader = init_reader("json", "Entity"); let reader = init_reader("json", "Entity");
version version
.run_threaded(Some(db.clone()), reader, None, 1_000, 100) .run(Some(db.clone()), reader, None, 1_000, 100)
.await?; .await?;
assert_eq!(51.0, entity_query(&db).await?.unwrap()); assert_eq!(51.0, entity_query(&db).await?.unwrap());
@ -71,7 +72,7 @@ async fn entity_threaded_filter() -> Result<(), Error> {
let reader = init_reader("json", "bench"); let reader = init_reader("json", "bench");
CreateVersion::BulkFilter CreateVersion::BulkFilter
.run_threaded(Some(db.clone()), reader, None, 1_000, 100) .run(Some(db.clone()), reader, None, 1_000, 100)
.await?; .await?;
let count: Option<f32> = db let count: Option<f32> = db
@ -116,7 +117,7 @@ async fn property_threaded(#[case] version: CreateVersion) -> Result<(), Error>
let reader = init_reader("json", "Property"); let reader = init_reader("json", "Property");
version version
.run_threaded(Some(db.clone()), reader, None, 1_000, 100) .run(Some(db.clone()), reader, None, 1_000, 100)
.await?; .await?;
assert_eq!(2.0, property_query(&db).await?.unwrap()); assert_eq!(2.0, property_query(&db).await?.unwrap());