diff --git a/Cargo.lock b/Cargo.lock index 3e7cf1f..58af6b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4347,7 +4347,6 @@ dependencies = [ "criterion", "futures 0.3.30", "indicatif", - "lazy_static", "pprof", "rand", "rstest", diff --git a/Cargo.toml b/Cargo.toml index b8508d5..a91c248 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,11 +9,10 @@ anyhow = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" surrealdb = { version = "2.0.1", features = ["protocol-http", "kv-mem"] } -tokio = { version = "1.39", features = ["fs", "time"] } +tokio = { version = "1.39", features = ["fs", "time", "sync"] } futures = "0.3" wikidata = "1.1" bzip2 = { version = "0.4", features = ["tokio"] } -lazy_static = "1.5" indicatif = "0.17" rand = "0.8" backon = { version = "1.2", features = ["tokio-sleep"] } diff --git a/README.md b/README.md index 8d208cb..ef97bd9 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ Env string CREATE_VERSION must be in the enum CREATE_VERSION pub enum CreateVersion { #[default] Bulk, - /// must create a filter.surql file in the data directory + /// must create a `filter.surql` file in the data directory BulkFilter, } ``` diff --git a/src/main.rs b/src/main.rs index e2eee8a..8bab9e0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,39 +1,59 @@ use anyhow::{Error, Ok, Result}; -use lazy_static::lazy_static; use std::env; use surrealdb::{engine::remote::http::Client, Surreal}; -use tokio::time::{sleep, Duration}; - +use tokio::{ + fs, + sync::OnceCell, + time::{sleep, Duration}, +}; mod utils; use init_reader::File_Format; use utils::*; -lazy_static! { - static ref WIKIDATA_FILE_FORMAT: String = - env::var("WIKIDATA_FILE_FORMAT").expect("FILE_FORMAT not set"); - static ref WIKIDATA_FILE_NAME: String = - env::var("WIKIDATA_FILE_NAME").expect("FILE_NAME not set"); - static ref CREATE_VERSION: CreateVersion = match env::var("CREATE_VERSION") - .expect("CREATE_VERSION not set") - .as_str() - { - "Bulk" => CreateVersion::Bulk, - "BulkFilter" => CreateVersion::BulkFilter, - _ => panic!("Unknown CREATE_VERSION"), - }; +static WIKIDATA_FILE_FORMAT: OnceCell = OnceCell::const_new(); +static WIKIDATA_FILE_NAME: OnceCell = OnceCell::const_new(); +static CREATE_VERSION: OnceCell = OnceCell::const_new(); + +async fn get_wikidata_file_format() -> &'static String { + WIKIDATA_FILE_FORMAT + .get_or_init(|| async { env::var("WIKIDATA_FILE_FORMAT").expect("FILE_FORMAT not set") }) + .await +} + +async fn get_wikidata_file_name() -> &'static String { + WIKIDATA_FILE_NAME + .get_or_init(|| async { env::var("WIKIDATA_FILE_NAME").expect("FILE_NAME not set") }) + .await +} + +async fn get_create_version() -> &'static CreateVersion { + CREATE_VERSION + .get_or_init(|| async { + match env::var("CREATE_VERSION") + .expect("CREATE_VERSION not set") + .as_str() + { + "Bulk" => CreateVersion::Bulk, + "BulkFilter" => CreateVersion::BulkFilter, + _ => panic!("Unknown CREATE_VERSION"), + } + }) + .await } #[tokio::main] async fn main() -> Result<(), Error> { sleep(Duration::from_secs(10)).await; let pb = init_progress_bar::create_pb().await; - let reader = File_Format::new(&WIKIDATA_FILE_FORMAT).reader(&WIKIDATA_FILE_NAME)?; + let reader = File_Format::new(get_wikidata_file_format().await) + .reader(get_wikidata_file_name().await)?; - tokio::fs::create_dir_all("data/temp").await?; - tokio::fs::remove_dir_all("data/temp").await?; - tokio::fs::create_dir_all("data/temp").await?; + fs::create_dir_all("data/temp").await?; + fs::remove_dir_all("data/temp").await?; + fs::create_dir_all("data/temp").await?; - CREATE_VERSION + get_create_version() + .await .run( None::>, reader, diff --git a/src/utils.rs b/src/utils.rs index 6132f02..8e2fad9 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,11 +3,11 @@ use backon::Retryable; use core::panic; use futures::future::join_all; use indicatif::ProgressBar; -use lazy_static::lazy_static; use rand::{distributions::Alphanumeric, Rng}; use serde_json::{from_str, Value}; use std::{env, io::BufRead}; use surrealdb::{Connection, Surreal}; +use tokio::sync::OnceCell; use wikidata::Entity; pub mod init_backoff; @@ -17,20 +17,33 @@ pub mod init_reader; mod tables; use tables::*; -lazy_static! { - static ref OVERWRITE_DB: bool = env::var("OVERWRITE_DB") - .expect("OVERWRITE_DB not set") - .parse() - .expect("Failed to parse OVERWRITE_DB"); - static ref FILTER_PATH: String = - env::var("FILTER_PATH").unwrap_or("data/filter.surql".to_string()); +static OVERWRITE_DB: OnceCell = OnceCell::const_new(); +static FILTER_PATH: OnceCell = OnceCell::const_new(); + +async fn get_overwrite_db() -> bool { + *OVERWRITE_DB + .get_or_init(|| async { + env::var("OVERWRITE_DB") + .expect("OVERWRITE_DB not set") + .parse::() + .expect("Failed to parse OVERWRITE_DB") + }) + .await +} + +async fn get_filter_path() -> &'static String { + FILTER_PATH + .get_or_init(|| async { + env::var("FILTER_PATH").unwrap_or("data/filter.surql".to_string()) + }) + .await } #[derive(Clone, Copy, Default)] pub enum CreateVersion { #[default] Bulk, - /// must create a filter.surql file in the root directory + /// must create a `filter.surql` file in the root directory BulkFilter, } @@ -77,7 +90,7 @@ impl CreateVersion { Some(db) => self.create_retry(&db, &chunk, &pb, batch_size).await, None => { let db = init_db::create_db_remote - .retry(*init_backoff::exponential) + .retry(*init_backoff::get_exponential().await) .await .expect("Failed to create remote db"); self.create_retry(&db, &chunk, &pb, batch_size).await @@ -96,7 +109,7 @@ impl CreateVersion { batch_size: usize, ) -> Result<(), Error> { (|| async { self.create(db, chunk, pb, batch_size).await }) - .retry(*init_backoff::exponential) + .retry(*init_backoff::get_exponential().await) .await } @@ -137,7 +150,7 @@ impl CreateVersion { Ok(data) => data, Err(_) => continue, }; - let (claims, data) = EntityMini::from_entity(data); + let (claims, data) = EntityMini::from_entity(data).await; match data.id.clone().expect("No ID").tb.as_str() { "Property" => property_vec.push(data), "Lexeme" => lexeme_vec.push(data), @@ -147,7 +160,7 @@ impl CreateVersion { claims_vec.push(claims); } - if *OVERWRITE_DB { + if get_overwrite_db().await { db.upsert::>("Entity") .content(entity_vec) .await?; @@ -191,7 +204,7 @@ impl CreateVersion { let db_mem = init_db::create_db_mem().await?; self.create_bulk(&db_mem, lines, &None, batch_size).await?; - let filter = tokio::fs::read_to_string(&*FILTER_PATH).await?; + let filter = tokio::fs::read_to_string(get_filter_path().await).await?; db_mem.query(filter).await?; let file_name: String = rand::thread_rng() diff --git a/src/utils/init_backoff.rs b/src/utils/init_backoff.rs index 7e26766..00f3472 100644 --- a/src/utils/init_backoff.rs +++ b/src/utils/init_backoff.rs @@ -1,9 +1,14 @@ use backon::ExponentialBuilder; -use lazy_static::lazy_static; -use tokio::time::Duration; +use tokio::{sync::OnceCell, time::Duration}; -lazy_static! { - pub static ref exponential: ExponentialBuilder = ExponentialBuilder::default() - .with_max_times(30) - .with_max_delay(Duration::from_secs(60)); +static BACKOFF_EXPONENTIAL: OnceCell = OnceCell::const_new(); + +pub async fn get_exponential() -> &'static ExponentialBuilder { + BACKOFF_EXPONENTIAL + .get_or_init(|| async { + ExponentialBuilder::default() + .with_max_times(30) + .with_max_delay(Duration::from_secs(60)) + }) + .await } diff --git a/src/utils/init_db.rs b/src/utils/init_db.rs index 4733de1..a2e217a 100644 --- a/src/utils/init_db.rs +++ b/src/utils/init_db.rs @@ -1,6 +1,5 @@ use anyhow::Error; use anyhow::Result; -use lazy_static::lazy_static; use std::env; use surrealdb::{ engine::{ @@ -10,20 +9,36 @@ use surrealdb::{ opt::auth::Root, Surreal, }; +use tokio::sync::OnceCell; -lazy_static! { - static ref DB_USER: String = env::var("DB_USER").expect("DB_USER not set"); - static ref DB_PASSWORD: String = env::var("DB_PASSWORD").expect("DB_PASSWORD not set"); - static ref WIKIDATA_DB_PORT: String = - env::var("WIKIDATA_DB_PORT").expect("WIKIDATA_DB_PORT not set"); +static DB_USER: OnceCell = OnceCell::const_new(); +static DB_PASSWORD: OnceCell = OnceCell::const_new(); +static WIKIDATA_DB_PORT: OnceCell = OnceCell::const_new(); + +pub async fn get_db_user() -> &'static String { + DB_USER + .get_or_init(|| async { env::var("DB_USER").expect("DB_USER not set") }) + .await +} + +pub async fn get_db_password() -> &'static String { + DB_PASSWORD + .get_or_init(|| async { env::var("DB_PASSWORD").expect("DB_PASSWORD not set") }) + .await +} + +pub async fn get_wikidata_db_port() -> &'static String { + WIKIDATA_DB_PORT + .get_or_init(|| async { env::var("WIKIDATA_DB_PORT").expect("WIKIDATA_DB_PORT not set") }) + .await } pub async fn create_db_remote() -> Result, Error> { - let db = Surreal::new::(WIKIDATA_DB_PORT.as_str()).await?; + let db = Surreal::new::(get_wikidata_db_port().await).await?; db.signin(Root { - username: &DB_USER, - password: &DB_PASSWORD, + username: get_db_user().await, + password: get_db_password().await, }) .await?; db.use_ns("wikidata").use_db("wikidata").await?; diff --git a/src/utils/tables.rs b/src/utils/tables.rs index f074357..f8d7fb2 100644 --- a/src/utils/tables.rs +++ b/src/utils/tables.rs @@ -1,13 +1,16 @@ -use lazy_static::lazy_static; +use futures::future::join_all; use serde::{Deserialize, Serialize}; use std::env; use surrealdb::sql::{Id, Thing}; +use tokio::sync::OnceCell; use wikidata::{ClaimValue, ClaimValueData, Entity, Lang, Pid, WikiId}; -lazy_static! { - static ref WIKIDATA_LANG: String = env::var("WIKIDATA_LANG") - .expect("WIKIDATA_LANG not set") - .to_string(); +static WIKIDATA_LANG: OnceCell = OnceCell::const_new(); + +async fn get_wikidata_lang() -> &'static String { + WIKIDATA_LANG + .get_or_init(|| async { env::var("WIKIDATA_LANG").expect("WIKIDATA_LANG not set") }) + .await } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -17,7 +20,7 @@ pub enum ClaimData { } impl ClaimData { - fn from_cvd(cvd: ClaimValueData) -> Self { + async fn from_cvd(cvd: ClaimValueData) -> Self { match cvd { ClaimValueData::Item(qid) => ClaimData::Thing(Thing::from(("Entity", Id::from(qid.0)))), ClaimValueData::Property(pid) => { @@ -54,48 +57,53 @@ pub struct EntityMini { } impl EntityMini { - pub fn from_entity(entity: Entity) -> (Claims, Self) { - let thing_claim = Thing::from(("Claims", get_id_entity(&entity).id)); + pub async fn from_entity(entity: Entity) -> (Claims, Self) { + let thing_claim = Thing::from(("Claims", get_id_entity(&entity).await.id)); ( Claims { id: Some(thing_claim.clone()), - ..Self::flatten_claims(entity.claims.clone()) + ..Self::flatten_claims(entity.claims.clone()).await }, Self { - id: Some(get_id_entity(&entity)), - label: get_name(&entity), + id: Some(get_id_entity(&entity).await), + label: get_name(&entity).await, claims: thing_claim, - description: get_description(&entity), + description: get_description(&entity).await, }, ) } - fn flatten_claims(claims: Vec<(Pid, ClaimValue)>) -> Claims { + async fn flatten_claims(claims: Vec<(Pid, ClaimValue)>) -> Claims { Claims { id: None, - claims: claims - .iter() - .flat_map(|(pid, claim_value)| { + claims: { + let futures = claims.iter().map(|(pid, claim_value)| async { let mut flattened = vec![Claim { id: Thing::from(("Property", Id::from(pid.0))), - value: ClaimData::from_cvd(claim_value.data.clone()), + value: ClaimData::from_cvd(claim_value.data.clone()).await, }]; - flattened.extend(claim_value.qualifiers.iter().map( - |(qualifier_pid, qualifier_value)| Claim { - id: Thing::from(("Claims", Id::from(qualifier_pid.0))), - value: ClaimData::from_cvd(qualifier_value.clone()), + let inner_futures = claim_value.qualifiers.iter().map( + |(qualifier_pid, qualifier_value)| async { + let qualifier_data = ClaimData::from_cvd(qualifier_value.clone()).await; + Claim { + id: Thing::from(("Claims", Id::from(qualifier_pid.0))), + value: qualifier_data, + } }, - )); + ); + flattened.extend(join_all(inner_futures).await); flattened - }) - .collect(), + }); + + join_all(futures).await.into_iter().flatten().collect() + }, } } } -fn get_id_entity(entity: &Entity) -> Thing { +async fn get_id_entity(entity: &Entity) -> Thing { let (id, tb) = match entity.id { WikiId::EntityId(qid) => (qid.0, "Entity".to_string()), WikiId::PropertyId(pid) => (pid.0, "Property".to_string()), @@ -106,18 +114,18 @@ fn get_id_entity(entity: &Entity) -> Thing { Thing::from((tb, Id::from(id))) } -fn get_name(entity: &Entity) -> String { +async fn get_name(entity: &Entity) -> String { entity .labels - .get(&Lang(WIKIDATA_LANG.to_string())) + .get(&Lang(get_wikidata_lang().await.to_string())) .map(|label| label.to_string()) .unwrap_or_default() } -fn get_description(entity: &Entity) -> String { +async fn get_description(entity: &Entity) -> String { entity .descriptions - .get(&Lang(WIKIDATA_LANG.to_string())) + .get(&Lang(get_wikidata_lang().await.to_string())) .cloned() .unwrap_or_default() }