diff --git a/benches/bench.rs b/benches/bench.rs
index 817b5cf..c54e1a5 100644
--- a/benches/bench.rs
+++ b/benches/bench.rs
@@ -28,16 +28,10 @@ fn bench(c: &mut Criterion) {
                 .reader("tests/data/bench.json")
                 .unwrap();
 
-            create_db_entities_threaded(
-                Some(db.clone()),
-                reader,
-                None,
-                1000,
-                100,
-                CreateVersion::Single,
-            )
-            .await
-            .unwrap();
+            CreateVersion::Single
+                .run_threaded(Some(db.clone()), reader, None, 1000, 100)
+                .await
+                .unwrap();
         })
     })
 });
@@ -51,16 +45,10 @@ fn bench(c: &mut Criterion) {
                 .reader("tests/data/bench.json")
                 .unwrap();
 
-            create_db_entities_threaded(
-                Some(db.clone()),
-                reader,
-                None,
-                1000,
-                100,
-                CreateVersion::Bulk,
-            )
-            .await
-            .unwrap();
+            CreateVersion::Bulk
+                .run_threaded(Some(db.clone()), reader, None, 1000, 100)
+                .await
+                .unwrap();
         })
     })
 });
diff --git a/src/main.rs b/src/main.rs
index 848225d..de069bc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -49,7 +49,7 @@ async fn main() -> Result<(), Error> {
                 let line = line?;
 
                 loop {
-                    if create_db_entity(&db, &line).await.is_ok() {
+                    if create_entity(&db, &line).await.is_ok() {
                         break;
                     }
                     if retries >= 60 * 10 {
@@ -69,37 +69,37 @@ async fn main() -> Result<(), Error> {
             }
         }
         CreateMode::ThreadedSingle => {
-            create_db_entities_threaded(
-                None::<Surreal<Client>>,
-                reader,
-                Some(pb.clone()),
-                2_500,
-                100,
-                CreateVersion::Single,
-            )
-            .await?;
+            CreateVersion::Single
+                .run_threaded(
+                    None::<Surreal<Client>>,
+                    reader,
+                    Some(pb.clone()),
+                    2_500,
+                    100,
+                )
+                .await?;
         }
         CreateMode::ThreadedBulk => {
-            create_db_entities_threaded(
-                None::<Surreal<Client>>,
-                reader,
-                Some(pb.clone()),
-                500,
-                1_000,
-                CreateVersion::Bulk,
-            )
-            .await?;
+            CreateVersion::Bulk
+                .run_threaded(
+                    None::<Surreal<Client>>,
+                    reader,
+                    Some(pb.clone()),
+                    500,
+                    1_000,
+                )
+                .await?;
        }
         CreateMode::ThreadedBulkFilter => {
-            create_db_entities_threaded(
-                None::<Surreal<Client>>,
-                reader,
-                Some(pb.clone()),
-                500,
-                1_000,
-                CreateVersion::BulkFilter,
-            )
-            .await?;
+            CreateVersion::BulkFilter
+                .run_threaded(
+                    None::<Surreal<Client>>,
+                    reader,
+                    Some(pb.clone()),
+                    500,
+                    1_000,
+                )
+                .await?;
         }
     }
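Taken together with the `src/main.rs` hunks above, the call-site shape after this refactor is: pick a `CreateVersion` variant, then hand `run_threaded` an optional connection, a boxed reader, an optional progress bar, and the two batch knobs. A minimal sketch of such a caller, assuming this crate's `CreateVersion` and `Error` are in scope; the `load` wrapper and the dump path are illustrative, not part of the crate:

```rust
use std::{
    fs::File,
    io::{BufRead, BufReader},
};
use surrealdb::{engine::remote::ws::Client, Surreal};

// Illustrative caller of the refactored API. With `dbo = None`, each spawned
// chunk task opens its own WebSocket connection via `init_db::create_db_ws`;
// passing `Some(db)` instead shares one connection handle across all tasks.
async fn load() -> Result<(), Error> {
    let file = File::open("data/wikidata.json")?; // hypothetical dump path
    let reader: Box<dyn BufRead> = Box::new(BufReader::new(file));

    CreateVersion::Bulk
        .run_threaded(None::<Surreal<Client>>, reader, None, 500, 1_000)
        .await
}
```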
diff --git a/src/utils.rs b/src/utils.rs
index 366554a..4355d64 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -51,7 +51,7 @@ impl File_Format {
     }
 }
 
-pub async fn create_db_entity<C: Connection>(db: &Surreal<C>, line: &str) -> Result<(), Error> {
+pub async fn create_entity<C: Connection>(db: &Surreal<C>, line: &str) -> Result<(), Error> {
     let line = line.trim().trim_end_matches(',').to_string();
     if line == "[" || line == "]" {
         return Ok(());
@@ -76,105 +76,6 @@ pub async fn create_db_entity<C: Connection>(db: &Surreal<C>, line: &str) -> Resu
     Ok(())
 }
 
-pub async fn create_db_entities<C: Connection>(
-    db: &Surreal<C>,
-    lines: &[String],
-    pb: &Option<ProgressBar>,
-) -> Result<(), Error> {
-    let mut counter = 0;
-    for line in lines {
-        create_db_entity(db, line).await?;
-        counter += 1;
-        if counter % 100 == 0 {
-            if let Some(ref p) = pb {
-                p.inc(100)
-            }
-        }
-    }
-    Ok(())
-}
-
-pub async fn create_db_entities_bulk<C: Connection>(
-    db: &Surreal<C>,
-    lines: &[String],
-    pb: &Option<ProgressBar>,
-    batch_size: usize,
-) -> Result<(), Error> {
-    let lines = lines
-        .iter()
-        .map(|line| line.trim().trim_end_matches(',').to_string())
-        .filter(|line| line != "[" && line != "]")
-        .collect::<Vec<String>>();
-
-    let mut data_vec: Vec<EntityMini> = Vec::with_capacity(batch_size);
-    let mut claims_vec: Vec<Claims> = Vec::with_capacity(batch_size);
-    let mut property_vec: Vec<EntityMini> = Vec::with_capacity(batch_size);
-    let mut lexeme_vec: Vec<EntityMini> = Vec::with_capacity(batch_size);
-
-    for line in lines {
-        let json: Value = from_str(&line).expect("Failed to parse JSON");
-        let data = Entity::from_json(json).expect("Failed to parse JSON");
-        let (claims, data) = EntityMini::from_entity(data);
-        match data.id.clone().expect("No ID").tb.as_str() {
-            "Property" => property_vec.push(data),
-            "Lexeme" => lexeme_vec.push(data),
-            "Entity" => data_vec.push(data),
-            _ => panic!("Unknown table"),
-        }
-        claims_vec.push(claims);
-    }
-
-    db.insert::<Vec<EntityMini>>("Entity")
-        .content(data_vec)
-        .await?;
-    db.insert::<Vec<Claims>>("Claims")
-        .content(claims_vec)
-        .await?;
-    db.insert::<Vec<EntityMini>>("Property")
-        .content(property_vec)
-        .await?;
-    db.insert::<Vec<EntityMini>>("Lexeme")
-        .content(lexeme_vec)
-        .await?;
-
-    if let Some(ref p) = pb {
-        p.inc(batch_size as u64)
-    }
-    Ok(())
-}
-
-pub async fn create_db_entities_bulk_filter<C: Connection>(
-    db: &Surreal<C>,
-    lines: &[String],
-    pb: &Option<ProgressBar>,
-    batch_size: usize,
-) -> Result<(), Error> {
-    let db_mem = init_db::create_db_mem().await?;
-    create_db_entities_bulk(&db_mem, lines, &None, batch_size).await?;
-
-    let filter = tokio::fs::read_to_string(&*FILTER_PATH).await?;
-    db_mem.query(filter).await?;
-
-    let file_name: String = rand::thread_rng()
-        .sample_iter(&Alphanumeric)
-        .take(30)
-        .map(char::from)
-        .collect();
-
-    let file_path = format!("data/temp/{}.surql", file_name);
-    tokio::fs::create_dir_all("data/temp").await?;
-
-    db_mem.export(&file_path).await?;
-    db.import(&file_path).await?;
-
-    tokio::fs::remove_file(&file_path).await?;
-
-    if let Some(ref p) = pb {
-        p.inc(batch_size as u64)
-    }
-    Ok(())
-}
-
 #[derive(Clone, Copy)]
 pub enum CreateVersion {
     Single,
@@ -191,95 +92,182 @@ impl CreateVersion {
         batch_size: usize,
     ) -> bool {
         match self {
-            CreateVersion::Single => create_db_entities(db, chunk, pb).await.is_ok(),
-            CreateVersion::Bulk => create_db_entities_bulk(db, chunk, pb, batch_size)
-                .await
-                .is_ok(),
-            CreateVersion::BulkFilter => create_db_entities_bulk_filter(db, chunk, pb, batch_size)
-                .await
-                .is_ok(),
+            CreateVersion::Single => self.create_single(db, chunk, pb).await.is_ok(),
+            CreateVersion::Bulk => self.create_bulk(db, chunk, pb, batch_size).await.is_ok(),
+            CreateVersion::BulkFilter => self
+                .create_bulk_filter(db, chunk, pb, batch_size)
+                .await
+                .is_ok(),
         }
     }
-}
 
-pub async fn create_db_entities_threaded<C: Connection>(
-    dbo: Option<Surreal<C>>, // None::<Surreal<Client>>
-    reader: Box<dyn BufRead>,
-    pb: Option<ProgressBar>,
-    batch_size: usize,
-    batch_num: usize,
-    create_version: CreateVersion,
-) -> Result<(), Error> {
-    let mut futures = Vec::new();
-    let mut chunk = Vec::with_capacity(batch_size);
-    let mut chunk_counter = 0;
-    let mut lines = reader.lines();
-    let mut last_loop = false;
-
-    loop {
-        let line = lines.next();
-        match line {
-            Some(line) => chunk.push(line?),
-            None => last_loop = true,
-        };
-
-        if chunk.len() >= batch_size || last_loop {
-            let dbo = dbo.clone();
-            let pb = pb.clone();
-
-            futures.push(tokio::spawn(async move {
-                let mut retries = 0;
-                loop {
-                    match dbo {
-                        Some(ref db) => {
-                            if create_version.run(db, &chunk, &pb, batch_size).await {
-                                break;
-                            }
-                            if db.use_ns("wikidata").use_db("wikidata").await.is_err() {
-                                continue;
-                            };
-                        }
-                        None => {
-                            let db = if let Ok(db) = init_db::create_db_ws().await {
-                                db
-                            } else {
-                                continue;
-                            };
-                            if create_version.run(&db, &chunk, &pb, batch_size).await {
-                                break;
-                            }
-                        }
-                    }
-
-                    if retries >= 60 * 10 {
-                        panic!("Failed to create entities, too many retries");
-                    }
-                    retries += 1;
-                    sleep(Duration::from_millis(250)).await;
-                }
-            }));
-            chunk_counter += 1;
-            chunk = Vec::with_capacity(batch_size);
-        }
-        if chunk_counter >= batch_num || last_loop {
-            join_all(futures).await;
-            futures = Vec::new();
-            chunk_counter = 0;
-        }
-        if last_loop {
-            break;
-        }
-    }
-
-    match dbo {
-        Some(db) => {
-            create_db_entities(&db, &chunk, &pb).await?;
-        }
-        None => {
-            create_db_entities(&init_db::create_db_ws().await?, &chunk, &pb).await?;
-        }
-    }
-
-    join_all(futures).await;
-    Ok(())
+    pub async fn run_threaded<C: Connection>(
+        self,
+        dbo: Option<Surreal<C>>, // None::<Surreal<Client>>
+        reader: Box<dyn BufRead>,
+        pb: Option<ProgressBar>,
+        batch_size: usize,
+        batch_num: usize,
+    ) -> Result<(), Error> {
+        let mut lines = reader.lines().peekable();
+        let mut futures = Vec::new();
+
+        while lines.peek().is_some() {
+            let chunk: Vec<String> = lines
+                .by_ref()
+                .take(batch_size)
+                .filter_map(Result::ok)
+                .collect();
+
+            futures.push(self.spawn_chunk(dbo.clone(), chunk, pb.clone(), batch_size));
+
+            if futures.len() >= batch_num {
+                join_all(futures).await;
+                futures = Vec::new();
+            }
+        }
+
+        join_all(futures).await;
+        Ok(())
+    }
+
+    fn spawn_chunk<C: Connection>(
+        &self,
+        dbo: Option<Surreal<C>>,
+        chunk: Vec<String>,
+        pb: Option<ProgressBar>,
+        batch_size: usize,
+    ) -> tokio::task::JoinHandle<()> {
+        let create_version = *self;
+
+        tokio::spawn(async move {
+            let mut retries = 0;
+            loop {
+                match dbo {
+                    Some(ref db) => {
+                        if create_version.run(db, &chunk, &pb, batch_size).await {
+                            break;
+                        }
+                    }
+                    None => {
+                        let db = match init_db::create_db_ws().await {
+                            Ok(db) => db,
+                            Err(_) => continue,
+                        };
+                        if create_version.run(&db, &chunk, &pb, batch_size).await {
+                            break;
+                        }
+                    }
+                }
+                if retries >= 60 * 10 {
+                    panic!("Failed to create entities, too many retries");
+                }
+                retries += 1;
+                sleep(Duration::from_millis(250)).await;
+            }
+        })
+    }
+
+    async fn create_single<C: Connection>(
+        self,
+        db: &Surreal<C>,
+        lines: &[String],
+        pb: &Option<ProgressBar>,
+    ) -> Result<(), Error> {
+        let mut counter = 0;
+        for line in lines {
+            create_entity(db, line).await?;
+            counter += 1;
+            if counter % 100 == 0 {
+                if let Some(ref p) = pb {
+                    p.inc(100)
+                }
+            }
+        }
+        Ok(())
+    }
+
+    async fn create_bulk<C: Connection>(
+        self,
+        db: &Surreal<C>,
+        lines: &[String],
+        pb: &Option<ProgressBar>,
+        batch_size: usize,
+    ) -> Result<(), Error> {
+        let lines = lines
+            .iter()
+            .map(|line| line.trim().trim_end_matches(',').to_string())
+            .filter(|line| line != "[" && line != "]")
+            .collect::<Vec<String>>();
+
+        let mut data_vec: Vec<EntityMini> = Vec::with_capacity(batch_size);
+        let mut claims_vec: Vec<Claims> = Vec::with_capacity(batch_size);
+        let mut property_vec: Vec<EntityMini> = Vec::with_capacity(batch_size);
+        let mut lexeme_vec: Vec<EntityMini> = Vec::with_capacity(batch_size);
+
+        for line in lines {
+            let json: Value = from_str(&line).expect("Failed to parse JSON");
+            let data = Entity::from_json(json).expect("Failed to parse JSON");
+            let (claims, data) = EntityMini::from_entity(data);
+            match data.id.clone().expect("No ID").tb.as_str() {
+                "Property" => property_vec.push(data),
+                "Lexeme" => lexeme_vec.push(data),
+                "Entity" => data_vec.push(data),
+                _ => panic!("Unknown table"),
+            }
+            claims_vec.push(claims);
+        }
+
+        db.insert::<Vec<EntityMini>>("Entity")
+            .content(data_vec)
+            .await?;
+        db.insert::<Vec<Claims>>("Claims")
+            .content(claims_vec)
+            .await?;
+        db.insert::<Vec<EntityMini>>("Property")
+            .content(property_vec)
+            .await?;
+        db.insert::<Vec<EntityMini>>("Lexeme")
+            .content(lexeme_vec)
+            .await?;
+
+        if let Some(ref p) = pb {
+            p.inc(batch_size as u64)
+        }
+        Ok(())
+    }
+
+    async fn create_bulk_filter<C: Connection>(
+        self,
+        db: &Surreal<C>,
+        lines: &[String],
+        pb: &Option<ProgressBar>,
+        batch_size: usize,
+    ) -> Result<(), Error> {
+        let db_mem = init_db::create_db_mem().await?;
+        self.create_bulk(&db_mem, lines, &None, batch_size).await?;
+
+        let filter = tokio::fs::read_to_string(&*FILTER_PATH).await?;
+        db_mem.query(filter).await?;
+
+        let file_name: String = rand::thread_rng()
+            .sample_iter(&Alphanumeric)
+            .take(30)
+            .map(char::from)
+            .collect();
+
+        let file_path = format!("data/temp/{}.surql", file_name);
+        tokio::fs::create_dir_all("data/temp").await?;
+
+        db_mem.export(&file_path).await?;
+        db.import(&file_path).await?;
+
+        tokio::fs::remove_file(&file_path).await?;
+
+        if let Some(ref p) = pb {
+            p.inc(batch_size as u64)
+        }
+        Ok(())
+    }
 }
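Two details of the new `run_threaded` are easy to miss. First, `filter_map(Result::ok)` silently skips lines that fail to read, whereas the old loop propagated the I/O error through `chunk.push(line?)`. Second, `Peekable` combined with `by_ref().take(n)` is the standard idiom for draining an iterator in fixed-size batches. A freestanding sketch of that idiom (the names here are illustrative, not the crate's):

```rust
// Drain an iterator in batches of at most `n` items (`n` must be non-zero).
// `peekable()` answers "is anything left?", and `by_ref().take(n)` consumes
// a bounded prefix without taking ownership of the underlying iterator.
fn chunks_of(lines: impl Iterator<Item = String>, n: usize) -> Vec<Vec<String>> {
    let mut lines = lines.peekable();
    let mut out = Vec::new();
    while lines.peek().is_some() {
        out.push(lines.by_ref().take(n).collect());
    }
    out
}

#[test]
fn drains_in_fixed_batches() {
    let batches = chunks_of((0..5).map(|i| i.to_string()), 2);
    assert_eq!(batches, vec![vec!["0", "1"], vec!["2", "3"], vec!["4"]]);
}
```

The `by_ref()` is what lets the `while` loop keep pulling from the same source across iterations.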
diff --git a/src/utils/init_db.rs b/src/utils/init_db.rs
index da1075f..9ffa1bb 100644
--- a/src/utils/init_db.rs
+++ b/src/utils/init_db.rs
@@ -34,5 +34,6 @@ pub async fn create_db_ws() -> Result<Surreal<Client>, Error> {
 
 pub async fn create_db_mem() -> Result<Surreal<Db>, Error> {
     let db = Surreal::new::<Mem>(()).await?;
     db.use_ns("wikidata").use_db("wikidata").await?;
+    Ok(db)
 }
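`create_db_mem` is what backs the `BulkFilter` path: entities are bulk-inserted into an in-memory instance, the filter query prunes them, and the survivors round-trip to the target database through `export`/`import`. The shape of that round trip in isolation, assuming SurrealDB's `kv-mem` feature; the filter statement and file path are illustrative:

```rust
use surrealdb::{engine::local::Mem, Surreal};

// Illustrative staging round trip: load into an in-memory SurrealDB, prune,
// then export a .surql snapshot that another instance can import.
async fn stage_and_export(path: &str) -> surrealdb::Result<()> {
    let db_mem = Surreal::new::<Mem>(()).await?;
    db_mem.use_ns("wikidata").use_db("wikidata").await?;

    // ... bulk inserts happen here, as in `create_bulk` ...

    db_mem.query("DELETE Entity WHERE claims = NONE;").await?; // illustrative filter
    db_mem.export(path).await?; // the caller would then run `db.import(path)`
    Ok(())
}
```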
diff --git a/tests/integration.rs b/tests/integration.rs
index e9a62b0..bd46d51 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -42,7 +42,7 @@ async fn entity() {
     let reader = init_reader("json", "Entity");
 
     for line in reader.lines() {
-        create_db_entity(&db, &line.unwrap()).await.unwrap();
+        create_entity(&db, &line.unwrap()).await.unwrap();
     }
 
     assert_eq!(51.0, entity_query(&db).await.unwrap().unwrap())
@@ -56,7 +56,9 @@ async fn entity_threaded(#[case] version: CreateVersion) -> Result<(), Error> {
     let db = inti_db().await?;
     let reader = init_reader("json", "Entity");
 
-    create_db_entities_threaded(Some(db.clone()), reader, None, 1_000, 100, version).await?;
+    version
+        .run_threaded(Some(db.clone()), reader, None, 1_000, 100)
+        .await?;
 
     assert_eq!(51.0, entity_query(&db).await?.unwrap());
     Ok(())
@@ -68,15 +70,9 @@ async fn entity_threaded_filter() -> Result<(), Error> {
     let db = inti_db().await?;
     let reader = init_reader("json", "bench");
 
-    create_db_entities_threaded(
-        Some(db.clone()),
-        reader,
-        None,
-        1_000,
-        100,
-        CreateVersion::BulkFilter,
-    )
-    .await?;
+    CreateVersion::BulkFilter
+        .run_threaded(Some(db.clone()), reader, None, 1_000, 100)
+        .await?;
 
     let count: Option<f64> = db
         .query("return count(select * from Entity);")
@@ -105,7 +101,7 @@ async fn property() {
     let reader = init_reader("json", "Property");
 
     for line in reader.lines() {
-        create_db_entity(&db, &line.unwrap()).await.unwrap();
+        create_entity(&db, &line.unwrap()).await.unwrap();
     }
 
     assert_eq!(2.0, property_query(&db).await.unwrap().unwrap())
@@ -119,7 +115,9 @@ async fn property_threaded(#[case] version: CreateVersion) -> Result<(), Error>
     let db = inti_db().await?;
     let reader = init_reader("json", "Property");
 
-    create_db_entities_threaded(Some(db.clone()), reader, None, 1_000, 100, version).await?;
+    version
+        .run_threaded(Some(db.clone()), reader, None, 1_000, 100)
+        .await?;
 
     assert_eq!(2.0, property_query(&db).await?.unwrap());
     Ok(())
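On the assertion side, the count check in `entity_threaded_filter` uses SurrealDB's statement-indexed extraction: `query(...)` returns a response, and `take(N)` deserializes the result of the Nth statement. A freestanding sketch against an in-memory instance (requires the `kv-mem` feature; the `Option<f64>` binding mirrors the test above):

```rust
use surrealdb::{engine::local::Mem, Surreal};

#[tokio::main]
async fn main() -> surrealdb::Result<()> {
    let db = Surreal::new::<Mem>(()).await?;
    db.use_ns("wikidata").use_db("wikidata").await?;

    // `take(0)` pulls out the result of the first (index 0) statement.
    let count: Option<f64> = db
        .query("return count(select * from Entity);")
        .await?
        .take(0)?;
    println!("entities: {count:?}");
    Ok(())
}
```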