This commit is contained in:
Elijah McMorris 2023-06-16 07:58:46 +00:00
parent 0aff65bd51
commit 0c2b30f348
Signed by: NexVeridian
SSH key fingerprint: SHA256:bsA1SKZxuEcEVHAy3gY1HUeM5ykRJl0U0kQHQn0hMg8
4 changed files with 111 additions and 109 deletions

View file

@ -61,7 +61,9 @@
"Gruntfuggly.todo-tree", "Gruntfuggly.todo-tree",
"ms-azuretools.vscode-docker", "ms-azuretools.vscode-docker",
"redhat.vscode-yaml", "redhat.vscode-yaml",
"GitHub.copilot" // "GitHub.copilot",
"GitHub.copilot-nightly",
"GitHub.copilot-chat"
] ]
} }
} }

2
src/lib.rs Normal file
View file

@ -0,0 +1,2 @@
// Declare the `util` module as a public part of this crate.
pub mod util;
// Re-export every public item of `util` from the crate root.
pub use util::*;

View file

@ -1,7 +1,6 @@
use chrono::NaiveDate; use chrono::NaiveDate;
use glob::glob; use glob::glob;
use polars::datatypes::DataType; use polars::datatypes::DataType;
use polars::prelude::*; use polars::prelude::*;
use polars::prelude::{DataFrame, StrptimeOptions, UniqueKeepStrategy}; use polars::prelude::{DataFrame, StrptimeOptions, UniqueKeepStrategy};
use reqwest::blocking::Client; use reqwest::blocking::Client;
@ -9,7 +8,6 @@ use serde_json::Value;
use std::error::Error; use std::error::Error;
use std::fs::File; use std::fs::File;
use std::io::Cursor; use std::io::Cursor;
use std::result::Result; use std::result::Result;
use strum_macros::EnumIter; use strum_macros::EnumIter;
@ -109,7 +107,7 @@ impl Ark {
let update = match (source, existing_file) { let update = match (source, existing_file) {
(Source::Read, false) => { (Source::Read, false) => {
panic!("Can not read from file, file is empty or does not exist") panic!("Can not read from file, file is empty, does not exist, or is locked")
} }
(Source::Read, true) => None, (Source::Read, true) => None,
(Source::Ark, _) => Some(ark.get_csv_ark()?), (Source::Ark, _) => Some(ark.get_csv_ark()?),
@ -174,12 +172,12 @@ impl Ark {
Ok(df.into()) Ok(df.into())
} }
fn sort(mut self) -> Result<Self, Box<dyn Error>> { pub fn sort(mut self) -> Result<Self, Box<dyn Error>> {
self.df = Self::df_sort(self.df.clone())?; self.df = Self::df_sort(self.df.clone())?;
Ok(self) Ok(self)
} }
fn df_sort(df: DF) -> Result<DF, Box<dyn Error>> { pub fn df_sort(df: DF) -> Result<DF, Box<dyn Error>> {
Ok(df Ok(df
.collect()? .collect()?
.sort(["date", "weight"], vec![false, true])? .sort(["date", "weight"], vec![false, true])?
@ -205,7 +203,7 @@ impl Ark {
Ok(self) Ok(self)
} }
fn df_format(df: DF) -> Result<DF, Box<dyn Error>> { pub fn df_format(df: DF) -> Result<DF, Box<dyn Error>> {
let mut df = df.collect()?; let mut df = df.collect()?;
if df.get_column_names().contains(&"market_value_($)") { if df.get_column_names().contains(&"market_value_($)") {
@ -425,11 +423,6 @@ mod tests {
use serial_test::serial; use serial_test::serial;
use std::fs; use std::fs;
// Test helper: forwards `df` to `Ark::write_df_parquet` using the fixed
// path "data/test/ARKK.parquet", propagating any error to the caller.
fn write_df_parquet(df: DF) -> Result<(), Box<dyn Error>> {
Ark::write_df_parquet("data/test/ARKK.parquet".into(), df)?;
Ok(())
}
#[test] #[test]
#[serial] #[serial]
fn read_write_parquet() -> Result<(), Box<dyn Error>> { fn read_write_parquet() -> Result<(), Box<dyn Error>> {
@ -444,107 +437,11 @@ mod tests {
"weight" => [10.00] "weight" => [10.00]
]?; ]?;
write_df_parquet(test_df.clone().into())?; Ark::write_df_parquet("data/test/ARKK.parquet".into(), test_df.clone().into())?;
let read = Ark::new(Source::Read, Ticker::ARKK, Some("data/test".to_owned()))?.collect()?; let read = Ark::new(Source::Read, Ticker::ARKK, Some("data/test".to_owned()))?.collect()?;
fs::remove_file("data/test/ARKK.parquet")?; fs::remove_file("data/test/ARKK.parquet")?;
assert_eq!(read, test_df); assert_eq!(read, test_df);
Ok(()) Ok(())
} }
/// Calls `Ark::get_api` for ARKK and checks the column names of the
/// collected frame.
#[test]
#[serial]
fn get_api_arkk() -> Result<(), Box<dyn Error>> {
    let since = NaiveDate::from_ymd_opt(2023, 5, 18);
    let frame = Ark::new(
        Source::ApiIncremental,
        Ticker::ARKK,
        Some(String::from("data/test")),
    )?
    .get_api(since)?
    .collect()?;

    // Column layout expected from the API response before formatting.
    let expected_columns = [
        "company",
        "cusip",
        "date",
        "market_value",
        "share_price",
        "shares",
        "ticker",
        "weight",
        "weight_rank",
    ];
    assert_eq!(frame.get_column_names(), expected_columns);
    Ok(())
}
/// Runs the ARKK result of `Ark::get_api` through `Ark::df_format` and
/// checks the resulting column names and dtypes.
#[test]
#[serial]
fn get_api_format_arkk() -> Result<(), Box<dyn Error>> {
    let since = NaiveDate::from_ymd_opt(2023, 5, 18);
    let raw = Ark::new(
        Source::ApiIncremental,
        Ticker::ARKK,
        Some(String::from("data/test")),
    )?
    .get_api(since)?;
    let df = Ark::df_format(raw.into())?.collect()?;

    let expected_names = vec![
        "date",
        "ticker",
        "cusip",
        "company",
        "market_value",
        "shares",
        "share_price",
        "weight",
    ];
    let expected_dtypes = vec![
        DataType::Date,
        DataType::Utf8,
        DataType::Utf8,
        DataType::Utf8,
        DataType::Int64,
        DataType::Int64,
        DataType::Float64,
        DataType::Float64,
    ];

    // NOTE(review): polars documents `shape()` as (height, width), so `.1`
    // is presumably the column count, which the name comparison already
    // implies. Confirm whether a row-count check was intended.
    assert_eq!(
        (df.get_column_names(), df.dtypes(), df.shape().1 > 1),
        (expected_names, expected_dtypes, true)
    );
    Ok(())
}
/// Runs the ARKVC result of `Ark::get_api` through `Ark::df_format` and
/// checks the resulting column names and dtypes.
#[test]
#[serial]
fn get_api_format_arkvc() -> Result<(), Box<dyn Error>> {
    let since = NaiveDate::from_ymd_opt(2023, 1, 1);
    let raw = Ark::new(
        Source::ApiIncremental,
        Ticker::ARKVC,
        Some(String::from("data/test")),
    )?
    .get_api(since)?;
    let df = Ark::df_format(raw.into())?.collect()?;

    let expected_names = vec!["date", "ticker", "cusip", "company", "weight"];
    let expected_dtypes = vec![
        DataType::Date,
        DataType::Utf8,
        DataType::Utf8,
        DataType::Utf8,
        DataType::Float64,
    ];

    // NOTE(review): polars documents `shape()` as (height, width), so `.1`
    // is presumably the column count, which the name comparison already
    // implies. Confirm whether a row-count check was intended.
    assert_eq!(
        (df.get_column_names(), df.dtypes(), df.shape().1 > 1),
        (expected_names, expected_dtypes, true)
    );
    Ok(())
}
} }

101
tests/integration.rs Normal file
View file

@ -0,0 +1,101 @@
use ark_invest_api_rust_data::util::*;
use chrono::NaiveDate;
use polars::datatypes::DataType;
use serial_test::serial;
use std::error::Error;
/// Calls `Ark::get_api` for ARKK and checks the column names of the
/// collected frame.
#[test]
#[serial]
fn get_api_arkk() -> Result<(), Box<dyn Error>> {
    let since = NaiveDate::from_ymd_opt(2023, 5, 18);
    let frame = Ark::new(
        Source::ApiIncremental,
        Ticker::ARKK,
        Some(String::from("data/test")),
    )?
    .get_api(since)?
    .collect()?;

    // Column layout expected from the API response before formatting.
    let expected_columns = [
        "company",
        "cusip",
        "date",
        "market_value",
        "share_price",
        "shares",
        "ticker",
        "weight",
        "weight_rank",
    ];
    assert_eq!(frame.get_column_names(), expected_columns);
    Ok(())
}
/// Runs the ARKK result of `Ark::get_api` through `Ark::df_format` and
/// checks the resulting column names and dtypes.
#[test]
#[serial]
fn get_api_format_arkk() -> Result<(), Box<dyn Error>> {
    let since = NaiveDate::from_ymd_opt(2023, 5, 18);
    let raw = Ark::new(
        Source::ApiIncremental,
        Ticker::ARKK,
        Some(String::from("data/test")),
    )?
    .get_api(since)?;
    let df = Ark::df_format(raw.into())?.collect()?;

    let expected_names = vec![
        "date",
        "ticker",
        "cusip",
        "company",
        "market_value",
        "shares",
        "share_price",
        "weight",
    ];
    let expected_dtypes = vec![
        DataType::Date,
        DataType::Utf8,
        DataType::Utf8,
        DataType::Utf8,
        DataType::Int64,
        DataType::Int64,
        DataType::Float64,
        DataType::Float64,
    ];

    // NOTE(review): polars documents `shape()` as (height, width), so `.1`
    // is presumably the column count, which the name comparison already
    // implies. Confirm whether a row-count check was intended.
    assert_eq!(
        (df.get_column_names(), df.dtypes(), df.shape().1 > 1),
        (expected_names, expected_dtypes, true)
    );
    Ok(())
}
/// Runs the ARKVC result of `Ark::get_api` through `Ark::df_format` and
/// checks the resulting column names and dtypes.
#[test]
#[serial]
fn get_api_format_arkvc() -> Result<(), Box<dyn Error>> {
    let since = NaiveDate::from_ymd_opt(2023, 1, 1);
    let raw = Ark::new(
        Source::ApiIncremental,
        Ticker::ARKVC,
        Some(String::from("data/test")),
    )?
    .get_api(since)?;
    let df = Ark::df_format(raw.into())?.collect()?;

    let expected_names = vec!["date", "ticker", "cusip", "company", "weight"];
    let expected_dtypes = vec![
        DataType::Date,
        DataType::Utf8,
        DataType::Utf8,
        DataType::Utf8,
        DataType::Float64,
    ];

    // NOTE(review): polars documents `shape()` as (height, width), so `.1`
    // is presumably the column count, which the name comparison already
    // implies. Confirm whether a row-count check was intended.
    assert_eq!(
        (df.get_column_names(), df.dtypes(), df.shape().1 > 1),
        (expected_names, expected_dtypes, true)
    );
    Ok(())
}