From 4ea36f78b7f8f80ed819cf269b78ef6d80c509b0 Mon Sep 17 00:00:00 2001
From: NexVeridian
Date: Sat, 16 Dec 2023 03:56:40 -0800
Subject: [PATCH] multi thread

---
 Cargo.toml  |   1 +
 README.md   |   1 +
 src/main.rs | 100 ++++++++++++++++++++++++++++++++++++++--------------
 3 files changed, 75 insertions(+), 27 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 49536bb..dc5c8cf 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,6 +10,7 @@ serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 surrealdb = "1.0"
 tokio = "1.35"
+futures = "0.3"
 wikidata = "0.3.1"
 bzip2 = { version = "0.4", features = ["tokio"] }
 lazy_static = "1.4"
diff --git a/README.md b/README.md
index d55973c..81ef76c 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ FILE_FORMAT=bz2
 FILE_NAME=data/latest-all.json.bz2
 # If not using docker file for Wikidata to SurrealDB, use 0.0.0.0:8000
 WIKIDATA_DB_PORT=surrealdb:8000
+THREADED_REQUESTS=true
 ```
 
 ## View Progress
diff --git a/src/main.rs b/src/main.rs
index 6d17533..920d489 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,5 +1,6 @@
 use anyhow::{Error, Ok, Result};
 use bzip2::read::MultiBzDecoder;
+use futures::future::join_all;
 use indicatif::{ProgressBar, ProgressState, ProgressStyle};
 use lazy_static::lazy_static;
 use serde_json::{from_str, Value};
@@ -11,7 +12,11 @@ use std::{
     thread,
     time::Duration,
 };
-use surrealdb::{engine::remote::ws::Ws, opt::auth::Root, Surreal};
+use surrealdb::{
+    engine::remote::ws::{Client, Ws},
+    opt::auth::Root,
+    Surreal,
+};
 use wikidata::Entity;
 
 mod utils;
@@ -24,6 +29,7 @@ lazy_static! {
     static ref WIKIDATA_FILE_FORMAT: String = env::var("WIKIDATA_FILE_FORMAT").expect("FILE_FORMAT not set");
     static ref WIKIDATA_FILE_NAME: String = env::var("WIKIDATA_FILE_NAME").expect("FILE_NAME not set");
     static ref WIKIDATA_DB_PORT: String = env::var("WIKIDATA_DB_PORT").expect("WIKIDATA_DB_PORT not set");
+    static ref THREADED_REQUESTS: bool = env::var("THREADED_REQUESTS").expect("THREADED_REQUESTS not set").parse().expect("Failed to parse THREADED_REQUESTS");
 }
 
 #[allow(non_camel_case_types)]
@@ -48,17 +54,45 @@ impl File_Format {
     }
 }
 
+async fn create_db_entity(db: &Surreal<Client>, line: String) -> Result<(), Error> {
+    let line = line.trim().trim_end_matches(',').to_string();
+    if line == "[" || line == "]" {
+        return Ok(());
+    }
+
+    let json: Value = from_str(&line)?;
+    let data = Entity::from_json(json).expect("Failed to parse JSON");
+
+    let (mut claims, mut data) = EntityMini::from_entity(data);
+
+    let id = data.id.clone().expect("No ID");
+    data.id = None;
+    let _: Option<EntityMini> = db.delete(&id).await?;
+    let _: Option<EntityMini> = db.create(&id).content(data.clone()).await?;
+
+    let id = claims.id.clone().expect("No ID");
+    claims.id = None;
+    let _: Option<EntityMini> = db.delete(&id).await?;
+    let _: Option<EntityMini> = db.create(&id).content(claims).await?;
+    Ok(())
+}
+
+async fn create_db_entities(db: &Surreal<Client>, lines: Vec<String>) -> Result<(), Error> {
+    for line in lines {
+        create_db_entity(db, line).await?;
+    }
+    Ok(())
+}
+
 #[tokio::main]
 async fn main() -> Result<(), Error> {
     thread::sleep(Duration::from_secs(10));
-
-    let mut compleated = 0;
     let total_size = 113_000_000;
     let pb = ProgressBar::new(total_size);
     pb.set_style(
         ProgressStyle::with_template(
-            "[{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} {percent} ETA:{eta}",
+            "[{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ETA:[{eta}]",
         )?
.with_key("eta", |state: &ProgressState, w: &mut dyn Write| { let sec = state.eta().as_secs(); @@ -79,31 +113,43 @@ async fn main() -> Result<(), Error> { let reader = File_Format::new(&WIKIDATA_FILE_FORMAT).reader(&WIKIDATA_FILE_NAME)?; - for line in reader.lines() { - let line = line?.trim().trim_end_matches(',').to_string(); - if line == "[" || line == "]" { - continue; + if !*THREADED_REQUESTS { + let counter = 0; + for line in reader.lines() { + create_db_entity(&db, line?).await?; + if counter % 100 == 0 { + pb.inc(100); + } + } + } else { + let mut futures = Vec::new(); + let mut chunk = Vec::new(); + let mut chunk_counter: i32 = 0; + const BATCH_AMMOUNT: u16 = 50; + + for line in reader.lines() { + chunk.push(line.unwrap()); + + if chunk.len() >= BATCH_AMMOUNT.try_into().unwrap() { + let db = db.clone(); + let lines = chunk.clone(); + let pb = pb.clone(); + + futures.push(tokio::spawn(async move { + create_db_entities(&db, lines).await.unwrap(); + pb.inc(BATCH_AMMOUNT.try_into().unwrap()); + })); + chunk_counter += 1; + chunk.clear(); + } + + if chunk_counter >= 50 { + join_all(futures).await; + futures = Vec::new(); + } } - let json: Value = from_str(&line)?; - let data = Entity::from_json(json).expect("Failed to parse JSON"); - - let (mut claims, mut data) = EntityMini::from_entity(data); - - let id = data.id.clone().expect("No ID"); - data.id = None; - let _: Option = db.delete(&id).await?; - let _: Option = db.create(&id).content(data.clone()).await?; - - let id = claims.id.clone().expect("No ID"); - claims.id = None; - let _: Option = db.delete(&id).await?; - let _: Option = db.create(&id).content(claims).await?; - - compleated += 1; - if compleated % 1000 == 0 { - pb.set_position(compleated); - } + join_all(futures).await; } pb.finish_with_message("Done parsing Wikidata");