From acef3f8f3b8a22c465a35443597f927d028a05f5 Mon Sep 17 00:00:00 2001 From: NexVeridian Date: Wed, 28 Aug 2024 14:25:17 -0700 Subject: [PATCH] refactor: reader and rename run --- Cargo.lock | 16 ++++----- benches/bench.rs | 5 +-- flake.lock | 6 ++-- src/main.rs | 3 +- src/utils.rs | 74 +++++++++++++--------------------------- src/utils/init_reader.rs | 28 +++++++++++++++ tests/integration.rs | 7 ++-- 7 files changed, 72 insertions(+), 67 deletions(-) create mode 100644 src/utils/init_reader.rs diff --git a/Cargo.lock b/Cargo.lock index 7408b9f..6f462ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -683,9 +683,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpp_demangle" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8227005286ec39567949b33df9896bcadfa6051bccca2488129f108ca23119" +checksum = "96e58d342ad113c2b878f16d5d034c03be492ae460cdbc02b7f0f2284d310c7d" dependencies = [ "cfg-if", ] @@ -2985,9 +2985,9 @@ dependencies = [ [[package]] name = "rustc_version" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver", ] @@ -3545,9 +3545,9 @@ dependencies = [ [[package]] name = "symbolic-common" -version = "12.10.0" +version = "12.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16629323a4ec5268ad23a575110a724ad4544aae623451de600c747bf87b36cf" +checksum = "b1944ea8afd197111bca0c0edea1e1f56abb3edd030e240c1035cc0e3ff51fec" dependencies = [ "debugid", "memmap2", @@ -3557,9 +3557,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.10.0" +version = "12.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c043a45f08f41187414592b3ceb53fb0687da57209cc77401767fb69d5b596" +checksum = "ddaccaf1bf8e73c4f64f78dbb30aadd6965c71faa4ff3fba33f8d7296cf94a87" dependencies = [ "cpp_demangle", "rustc-demangle", diff --git a/benches/bench.rs b/benches/bench.rs index c54e1a5..c95f4cd 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -5,6 +5,7 @@ use std::{env, time::Duration}; use surrealdb::{engine::local::Db, Surreal}; use tokio::runtime::Runtime; +use init_reader::File_Format; use wikidata_to_surrealdb::utils::*; async fn inti_db() -> Result, Error> { @@ -29,7 +30,7 @@ fn bench(c: &mut Criterion) { .unwrap(); CreateVersion::Single - .run_threaded(Some(db.clone()), reader, None, 1000, 100) + .run(Some(db.clone()), reader, None, 1000, 100) .await .unwrap(); }) @@ -46,7 +47,7 @@ fn bench(c: &mut Criterion) { .unwrap(); CreateVersion::Bulk - .run_threaded(Some(db.clone()), reader, None, 1000, 100) + .run(Some(db.clone()), reader, None, 1000, 100) .await .unwrap(); }) diff --git a/flake.lock b/flake.lock index f265288..9cf1981 100644 --- a/flake.lock +++ b/flake.lock @@ -44,11 +44,11 @@ "rust-analyzer-src": [] }, "locked": { - "lastModified": 1724740262, - "narHash": "sha256-cpFasbzOTlwLi4fNas6hDznVUdCJn/lMLxi7MAMG6hg=", + "lastModified": 1724826636, + "narHash": "sha256-hz8Szf5J9oQg6EeMhHE/eKuexoHPiDbmOZTPvijYwyM=", "owner": "nix-community", "repo": "fenix", - "rev": "703efdd9b5c6a7d5824afa348a24fbbf8ff226be", + "rev": "3454a665ff4dd29cf618e6a2e53065370876297f", "type": "github" }, "original": { diff --git a/src/main.rs b/src/main.rs index 23cfb4a..30f8f1f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use surrealdb::{engine::remote::ws::Client, Surreal}; use tokio::time::{sleep, Duration}; mod utils; +use init_reader::File_Format; use utils::*; lazy_static! { @@ -30,7 +31,7 @@ async fn main() -> Result<(), Error> { let reader = File_Format::new(&WIKIDATA_FILE_FORMAT).reader(&WIKIDATA_FILE_NAME)?; CREATE_VERSION - .run_threaded( + .run( None::>, reader, Some(pb.clone()), diff --git a/src/utils.rs b/src/utils.rs index f1273bc..50ca0fd 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,22 +1,17 @@ use anyhow::{Error, Result}; -use bzip2::read::MultiBzDecoder; -use core::panic; use futures::future::join_all; use indicatif::ProgressBar; use lazy_static::lazy_static; use rand::{distributions::Alphanumeric, Rng}; use serde_json::{from_str, Value}; -use std::{ - env, - fs::File, - io::{BufRead, BufReader}, -}; +use std::{env, io::BufRead}; use surrealdb::{Connection, Surreal}; use tokio::time::{sleep, Duration}; use wikidata::Entity; pub mod init_db; pub mod init_progress_bar; +pub mod init_reader; mod tables; use tables::*; @@ -29,28 +24,6 @@ lazy_static! { env::var("FILTER_PATH").unwrap_or("../filter.surql".to_string()); } -#[allow(non_camel_case_types)] -pub enum File_Format { - json, - bz2, -} -impl File_Format { - pub fn new(file: &str) -> Self { - match file { - "json" => Self::json, - "bz2" => Self::bz2, - _ => panic!("Unknown file format"), - } - } - pub fn reader(self, file: &str) -> Result, Error> { - let file = File::open(file)?; - match self { - File_Format::json => Ok(Box::new(BufReader::new(file))), - File_Format::bz2 => Ok(Box::new(BufReader::new(MultiBzDecoder::new(file)))), - } - } -} - pub async fn create_entity(db: &Surreal, line: &str) -> Result<(), Error> { let line = line.trim().trim_end_matches(',').to_string(); if line == "[" || line == "]" { @@ -84,28 +57,12 @@ pub enum CreateVersion { /// must create a filter.surql file in the root directory BulkFilter, } + impl CreateVersion { pub async fn run( self, - db: &Surreal, - chunk: &[String], - pb: &Option, - batch_size: usize, - ) -> bool { - match self { - CreateVersion::Single => self.create_single(db, chunk, pb).await.is_ok(), - CreateVersion::Bulk => self.create_bulk(db, chunk, pb, batch_size).await.is_ok(), - CreateVersion::BulkFilter => self - .create_bulk_filter(db, chunk, pb, batch_size) - .await - .is_ok(), - } - } - - pub async fn run_threaded( - self, - dbo: Option>, - reader: Box, // None::> + dbo: Option>, // None::> + reader: Box, pb: Option, batch_size: usize, batch_num: usize, @@ -146,7 +103,7 @@ impl CreateVersion { loop { match dbo { Some(ref db) => { - if create_version.run(db, &chunk, &pb, batch_size).await { + if create_version.create(db, &chunk, &pb, batch_size).await { break; } } @@ -155,7 +112,7 @@ impl CreateVersion { Ok(db) => db, Err(_) => continue, }; - if create_version.run(&db, &chunk, &pb, batch_size).await { + if create_version.create(&db, &chunk, &pb, batch_size).await { break; } } @@ -170,6 +127,23 @@ impl CreateVersion { }) } + async fn create( + self, + db: &Surreal, + chunk: &[String], + pb: &Option, + batch_size: usize, + ) -> bool { + match self { + CreateVersion::Single => self.create_single(db, chunk, pb).await.is_ok(), + CreateVersion::Bulk => self.create_bulk(db, chunk, pb, batch_size).await.is_ok(), + CreateVersion::BulkFilter => self + .create_bulk_filter(db, chunk, pb, batch_size) + .await + .is_ok(), + } + } + async fn create_single( self, db: &Surreal, diff --git a/src/utils/init_reader.rs b/src/utils/init_reader.rs new file mode 100644 index 0000000..d20a102 --- /dev/null +++ b/src/utils/init_reader.rs @@ -0,0 +1,28 @@ +use anyhow::{Error, Result}; +use bzip2::read::MultiBzDecoder; +use std::{ + fs::File, + io::{BufRead, BufReader}, +}; + +#[allow(non_camel_case_types)] +pub enum File_Format { + json, + bz2, +} +impl File_Format { + pub fn new(file: &str) -> Self { + match file { + "json" => Self::json, + "bz2" => Self::bz2, + _ => panic!("Unknown file format"), + } + } + pub fn reader(self, file: &str) -> Result, Error> { + let file = File::open(file)?; + match self { + File_Format::json => Ok(Box::new(BufReader::new(file))), + File_Format::bz2 => Ok(Box::new(BufReader::new(MultiBzDecoder::new(file)))), + } + } +} diff --git a/tests/integration.rs b/tests/integration.rs index bd46d51..97b8240 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -3,6 +3,7 @@ use rstest::rstest; use std::{env, io::BufRead}; use surrealdb::{engine::local::Db, Surreal}; +use init_reader::File_Format; use wikidata_to_surrealdb::utils::*; async fn inti_db() -> Result, Error> { @@ -57,7 +58,7 @@ async fn entity_threaded(#[case] version: CreateVersion) -> Result<(), Error> { let reader = init_reader("json", "Entity"); version - .run_threaded(Some(db.clone()), reader, None, 1_000, 100) + .run(Some(db.clone()), reader, None, 1_000, 100) .await?; assert_eq!(51.0, entity_query(&db).await?.unwrap()); @@ -71,7 +72,7 @@ async fn entity_threaded_filter() -> Result<(), Error> { let reader = init_reader("json", "bench"); CreateVersion::BulkFilter - .run_threaded(Some(db.clone()), reader, None, 1_000, 100) + .run(Some(db.clone()), reader, None, 1_000, 100) .await?; let count: Option = db @@ -116,7 +117,7 @@ async fn property_threaded(#[case] version: CreateVersion) -> Result<(), Error> let reader = init_reader("json", "Property"); version - .run_threaded(Some(db.clone()), reader, None, 1_000, 100) + .run(Some(db.clone()), reader, None, 1_000, 100) .await?; assert_eq!(2.0, property_query(&db).await?.unwrap());