diff --git a/.cargo/config.toml b/.cargo/config.toml
index 08f4cbb..3c952ce 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -1,2 +1,5 @@
 # [target.x86_64-unknown-linux-gnu]
 # rustflags = ["-C", "link-arg=-fuse-ld=/usr/local/bin/mold"]
+
+[alias]
+t = "nextest run"
diff --git a/Cargo.toml b/Cargo.toml
index 5e67de5..b0d56c9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,3 +23,6 @@ tokio = { version = "1.26", features = ["full"] }
 openssl = { version = "0.10", features = ["vendored"] }
 chrono = { version = "0.4", features = ["serde"] }
 serde_json = "1.0"
+
+[dev-dependencies]
+serial_test = "*"
diff --git a/src/main.rs b/src/main.rs
index be6b870..d75c6e0 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -37,25 +37,26 @@ use util::*;
 // }

 fn main() {
-    let read = Ark::new(Source::Read, Ticker::ARKK)
+    let read = Ark::new(Source::Read, Ticker::ARKK, None)
         .unwrap()
         .collect()
         .unwrap();
+    println!("{:#?}", read.dtypes());
     println!("{:#?}", read);

-    let api = Ark::new(Source::ApiFull, Ticker::ARKK)
+    let api = Ark::new(Source::ApiFull, Ticker::ARKK, None)
         .unwrap()
         .collect()
         .unwrap();
     println!("{:#?}", api);

-    // let ark = Ark::new(Source::Ark, Ticker::ARKK)
+    // let ark = Ark::new(Source::Ark, Ticker::ARKK, None)
     //     .unwrap()
     //     .collect()
     //     .unwrap();
     // println!("{:#?}", ark);

-    // let ark = Ark::new(Source::Ark, Ticker::ARKVC)
+    // let ark = Ark::new(Source::Ark, Ticker::ARKVC, None)
     //     .unwrap()
     //     .collect()
     //     .unwrap();
diff --git a/src/util.rs b/src/util.rs
index 243881e..a7eccca 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,6 +1,7 @@
 use chrono::NaiveDate;
 use glob::glob;
 use polars::datatypes::DataType;
+
 use polars::prelude::*;
 use polars::prelude::{DataFrame, StrptimeOptions, UniqueKeepStrategy};
 use reqwest::blocking::Client;
@@ -8,6 +9,7 @@ use serde_json::Value;
 use std::error::Error;
 use std::fs::File;
 use std::io::Cursor;
+
 use std::result::Result;
 use strum_macros::EnumIter;

@@ -84,20 +86,34 @@ pub enum Source {
     ApiFull,
 }

 pub struct Ark {
-    df: DF,
+    pub df: DF,
     ticker: Ticker,
+    path: Option<String>,
 }

 impl Ark {
-    pub fn new(source: Source, ticker: Ticker) -> Result<Self, Box<dyn Error>> {
+    pub fn new(
+        source: Source,
+        ticker: Ticker,
+        path: Option<String>,
+    ) -> Result<Self, Box<dyn Error>> {
+        let existing_file = Self::read_parquet(ticker, path.clone()).is_ok();
+
         let mut ark = Self {
-            df: Self::read_parquet(ticker)?,
+            df: match existing_file {
+                true => Self::read_parquet(ticker, path.clone())?,
+                false => DF::DataFrame(df!["date" => [""],]?),
+            },
             ticker,
+            path,
         };

-        let update = match source {
-            Source::Read => None,
-            Source::Ark => Some(ark.get_csv_ark()?),
-            Source::ApiIncremental => {
+        let update = match (source, existing_file) {
+            (Source::Read, false) => {
+                panic!("Cannot read from file: file is empty or does not exist")
+            }
+            (Source::Read, true) => None,
+            (Source::Ark, _) => Some(ark.get_csv_ark()?),
+            (Source::ApiIncremental, true) => {
                 let last_day = ark
                     .df
                     .clone()
@@ -108,15 +124,20 @@ impl Ark {
                     .and_then(NaiveDate::from_num_days_from_ce_opt);
                 Some(ark.get_api(last_day)?)
             }
-            Source::ApiFull => Some(ark.get_api(None)?),
+            (Source::ApiIncremental, false) | (Source::ApiFull, _) => Some(ark.get_api(None)?),
         };

         if let Some(update) = update {
-            ark.df = Self::concat_df(vec![
-                Self::df_format(ark.df)?,
-                Self::df_format(update.into())?,
-            ])?;
+            if existing_file {
+                ark.df = Self::concat_df(vec![
+                    Self::df_format(ark.df)?,
+                    Self::df_format(update.into())?,
+                ])?;
+            } else {
+                ark.df = Self::df_format(update.into())?;
+            }
         }
+
         Ok(ark)
     }

@@ -124,17 +145,35 @@ impl Ark {
         self.df.collect()
     }

-    pub fn write_parquet(&mut self) -> Result<&Self, Box<dyn Error>> {
-        // with format
+    pub fn write_parquet(self) -> Result<Self, Box<dyn Error>> {
+        // with format df
         let ark = self.format()?;
-        ParquetWriter::new(File::create(format!(
-            "data/parquet/{}.parquet",
-            ark.ticker
-        ))?)
-        .finish(&mut ark.df.clone().collect()?)?;
+        Self::write_df_parquet(
+            match &ark.path {
+                Some(path) => format!("{}/{}.parquet", path, ark.ticker),
+                None => format!("data/parquet/{}.parquet", ark.ticker),
+            },
+            ark.df.clone(),
+        )?;
         Ok(ark)
     }

+    fn write_df_parquet(path: String, df: DF) -> Result<(), Box<dyn Error>> {
+        ParquetWriter::new(File::create(path)?).finish(&mut df.collect()?)?;
+        Ok(())
+    }
+
+    fn read_parquet(ticker: Ticker, path: Option<String>) -> Result<DF, Box<dyn Error>> {
+        let df = LazyFrame::scan_parquet(
+            match path {
+                Some(p) => format!("{}/{}.parquet", p, ticker),
+                None => format!("data/parquet/{}.parquet", ticker),
+            },
+            ScanArgsParquet::default(),
+        )?;
+        Ok(df.into())
+    }
+
     fn sort(mut self) -> Result<Self, Box<dyn Error>> {
         self.df = Self::df_sort(self.df.clone())?;
         Ok(self)
@@ -147,14 +186,6 @@ impl Ark {
             .into())
     }

-    fn read_parquet(ticker: Ticker) -> Result<DF, Box<dyn Error>> {
-        let df = LazyFrame::scan_parquet(
-            format!("data/parquet/{}.parquet", ticker),
-            ScanArgsParquet::default(),
-        )?;
-        Ok(df.into())
-    }
-
     fn concat_df(dfs: Vec<DF>) -> Result<DF, Box<dyn Error>> {
         // with dedupe
         let df = concat(dfs.lazy(), false, true)?;
@@ -169,7 +200,7 @@ impl Ark {
         Ok(df)
     }

-    pub fn format(&mut self) -> Result<&Self, Box<dyn Error>> {
+    pub fn format(mut self) -> Result<Self, Box<dyn Error>> {
         self.df = Self::df_format(self.df.clone())?;
         Ok(self)
     }
@@ -307,11 +338,11 @@ impl Ark {
     pub fn get_api(&self, last_day: Option<NaiveDate>) -> Result<DF, Box<dyn Error>> {
         let url = match (self.ticker, last_day) {
             (self::Ticker::ARKVC, Some(last_day)) => format!(
-                "https://api.nexveridian.com/arkvc_holdings?end={}",
+                "https://api.nexveridian.com/arkvc_holdings?start={}",
                 last_day
             ),
             (tic, Some(last_day)) => format!(
-                "https://api.nexveridian.com/ark_holdings?ticker={}&end={}",
+                "https://api.nexveridian.com/ark_holdings?ticker={}&start={}",
                 tic, last_day
             ),
             (self::Ticker::ARKVC, None) => "https://api.nexveridian.com/arkvc_holdings".to_owned(),
@@ -330,19 +361,22 @@ impl Ark {
         Reader::Csv.get_data_url(url)
     }

-    pub fn merge_old_csv_to_parquet(ticker: Ticker) -> Result<Self, Box<dyn Error>> {
+    pub fn merge_old_csv_to_parquet(
+        ticker: Ticker,
+        path: Option<String>,
+    ) -> Result<Self, Box<dyn Error>> {
         let mut dfs = vec![];
         for x in glob(&format!("data/csv/{}/*", ticker))?.filter_map(Result::ok) {
             dfs.push(LazyCsvReader::new(x).finish()?);
         }

         let mut df = concat(dfs, false, true)?.into();
-        if Self::read_parquet(ticker).is_ok() {
-            let df_old = Self::read_parquet(ticker)?;
+        if Self::read_parquet(ticker, path.clone()).is_ok() {
+            let df_old = Self::read_parquet(ticker, path.clone())?;
             df = Self::concat_df(vec![Self::df_format(df_old)?, Self::df_format(df)?])?
         }

-        Ok(Self { df, ticker })
+        Ok(Self { df, ticker, path })
     }
 }
@@ -384,3 +418,133 @@ impl Reader {
         Ok(df)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serial_test::serial;
+    use std::fs;
+
+    fn write_df_parquet(df: DF) -> Result<(), Box<dyn Error>> {
+        Ark::write_df_parquet("data/test/ARKK.parquet".into(), df)?;
+        Ok(())
+    }
+
+    #[test]
+    #[serial]
+    fn read_write_parquet() -> Result<(), Box<dyn Error>> {
+        let test_df = df![
+            "date" => ["2023-01-01"],
+            "ticker" => ["TSLA"],
+            "cusip" => ["123abc"],
+            "company" => ["Tesla"],
+            "market_value" => [100],
+            "shares" => [10],
+            "share_price" => [10],
+            "weight" => [10.00]
+        ]?;
+
+        write_df_parquet(test_df.clone().into())?;
+        let read = Ark::new(Source::Read, Ticker::ARKK, Some("data/test".to_owned()))?.collect()?;
+        fs::remove_file("data/test/ARKK.parquet")?;
+
+        assert_eq!(read, test_df);
+        Ok(())
+    }
+
+    #[test]
+    #[serial]
+    fn get_api_arkk() -> Result<(), Box<dyn Error>> {
+        let df = Ark::new(
+            Source::ApiIncremental,
+            Ticker::ARKK,
+            Some("data/test".to_owned()),
+        )?
+        .get_api(NaiveDate::from_ymd_opt(2023, 5, 18))?
+        .collect()?;
+
+        assert_eq!(
+            df.get_column_names(),
+            [
+                "company",
+                "cusip",
+                "date",
+                "market_value",
+                "share_price",
+                "shares",
+                "ticker",
+                "weight",
+                "weight_rank"
+            ]
+        );
+        Ok(())
+    }
+
+    #[test]
+    #[serial]
+    fn get_api_format_arkk() -> Result<(), Box<dyn Error>> {
+        let dfl = Ark::new(
+            Source::ApiIncremental,
+            Ticker::ARKK,
+            Some("data/test".to_owned()),
+        )?
+        .get_api(NaiveDate::from_ymd_opt(2023, 5, 18))?;
+        let df = Ark::df_format(dfl.into())?.collect()?;
+
+        assert_eq!(
+            (df.get_column_names(), df.dtypes(), df.shape().0 > 1),
+            (
+                vec![
+                    "date",
+                    "ticker",
+                    "cusip",
+                    "company",
+                    "market_value",
+                    "shares",
+                    "share_price",
+                    "weight",
+                ],
+                vec![
+                    DataType::Date,
+                    DataType::Utf8,
+                    DataType::Utf8,
+                    DataType::Utf8,
+                    DataType::Int64,
+                    DataType::Int64,
+                    DataType::Float64,
+                    DataType::Float64,
+                ],
+                true
+            )
+        );
+        Ok(())
+    }
+
+    #[test]
+    #[serial]
+    fn get_api_format_arkvc() -> Result<(), Box<dyn Error>> {
+        let dfl = Ark::new(
+            Source::ApiIncremental,
+            Ticker::ARKVC,
+            Some("data/test".to_owned()),
+        )?
+        .get_api(NaiveDate::from_ymd_opt(2023, 1, 1))?;
+        let df = Ark::df_format(dfl.into())?.collect()?;
+
+        assert_eq!(
+            (df.get_column_names(), df.dtypes(), df.shape().0 > 1),
+            (
+                vec!["date", "ticker", "cusip", "company", "weight"],
+                vec![
+                    DataType::Date,
+                    DataType::Utf8,
+                    DataType::Utf8,
+                    DataType::Utf8,
+                    DataType::Float64
+                ],
+                true
+            )
+        );
+        Ok(())
+    }
+}
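
Note for reviewers: this patch leans on the crate-local `DF` wrapper (`df: DF`, `DF::DataFrame(...)`, `Ok(df.into())` in read_parquet, `update.into()` in new, `dfs.lazy()` in concat_df), whose definition sits outside the changed hunks. Below is a minimal sketch of the shape the diff appears to assume; the variant names, the consuming `collect`/`lazy` signatures, and the `LazyVec` helper trait are inferred for illustration, not confirmed by this diff.

use polars::prelude::*;
use std::error::Error;

// Sketch only: inferred from how the patch uses `DF`; not part of the diff.
// `DF` must wrap both lazy and eager polars frames, convert from either,
// and collect to an eager DataFrame.
#[derive(Clone)]
pub enum DF {
    LazyFrame(LazyFrame),
    DataFrame(DataFrame),
}

impl DF {
    // Collect to an eager DataFrame regardless of variant.
    pub fn collect(self) -> Result<DataFrame, Box<dyn Error>> {
        match self {
            DF::LazyFrame(lf) => Ok(lf.collect()?),
            DF::DataFrame(df) => Ok(df),
        }
    }

    // Convert to a LazyFrame regardless of variant.
    pub fn lazy(self) -> LazyFrame {
        match self {
            DF::LazyFrame(lf) => lf,
            DF::DataFrame(df) => df.lazy(),
        }
    }
}

impl From<LazyFrame> for DF {
    fn from(lf: LazyFrame) -> Self {
        DF::LazyFrame(lf)
    }
}

impl From<DataFrame> for DF {
    fn from(df: DataFrame) -> Self {
        DF::DataFrame(df)
    }
}

// `concat_df` calls `.lazy()` on a `Vec<DF>`; a small extension trait like
// this (name assumed) would provide it:
trait LazyVec {
    fn lazy(self) -> Vec<LazyFrame>;
}

impl LazyVec for Vec<DF> {
    fn lazy(self) -> Vec<LazyFrame> {
        self.into_iter().map(DF::lazy).collect()
    }
}

Keeping both lazy and eager frames behind one enum lets `read_parquet` stay lazy via `scan_parquet` while `df!`-built test frames slot into the same pipeline unchanged.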