refactor: get_expr

This commit is contained in:
Elijah McMorris 2024-10-02 16:15:45 -07:00
parent 126005905b
commit 43da542d10
Signed by: NexVeridian
SSH key fingerprint: SHA256:bsA1SKZxuEcEVHAy3gY1HUeM5ykRJl0U0kQHQn0hMg8

View file

@ -1,4 +1,3 @@
use anyhow::{Error, Result};
use polars::prelude::*;
use strum::IntoEnumIterator;
@ -30,32 +29,48 @@ impl Ticker {
}
}
fn get_expr(target_col: &str, current: &str, new: &str) -> Vec<Expr> {
match target_col {
"company" => vec![
when(col(target_col).eq(lit(current)))
.then(lit(new))
.otherwise(col("ticker"))
.alias("ticker"),
when(col(target_col).eq(lit(current)))
.then(lit(new))
.otherwise(col("company"))
.alias("company"),
],
"ticker" => vec![
when(col(target_col).eq(lit(current)))
.then(lit(new))
.otherwise(col("company"))
.alias("company"),
when(col(target_col).eq(lit(current)))
.then(lit(new))
.otherwise(col("ticker"))
.alias("ticker"),
],
_ => panic!("Invalid target column"),
}
}
fn arkw(df: DF) -> Result<DF, Error> {
let mut df = df.collect()?;
if let Ok(x) = df
.clone()
.lazy()
.with_columns(vec![
when(col("company").eq(lit("ARK BITCOIN ETF HOLDCO (ARKW)")))
.then(lit("ARKB"))
.otherwise(col("ticker"))
.alias("ticker"),
when(col("company").eq(lit("ARK BITCOIN ETF HOLDCO (ARKW)")))
.then(lit("ARKB"))
.otherwise(col("company"))
.alias("company"),
])
.with_columns(vec![
when(col("company").eq(lit("ARK BITCOIN ETF HOLDCO (ARKF)")))
.then(lit("ARKB"))
.otherwise(col("ticker"))
.alias("ticker"),
when(col("company").eq(lit("ARK BITCOIN ETF HOLDCO (ARKF)")))
.then(lit("ARKB"))
.otherwise(col("company"))
.alias("company"),
])
.with_columns(Self::get_expr(
"company",
"ARK BITCOIN ETF HOLDCO (ARKW)",
"ARKB",
))
.with_columns(Self::get_expr(
"company",
"ARK BITCOIN ETF HOLDCO (ARKF)",
"ARKB",
))
.collect()
{
df = x;
@ -85,24 +100,20 @@ impl Ticker {
fn cash_usd(df: DF) -> Result<DF, Error> {
let mut df = df.collect()?;
let exprs = |company: &str| -> Vec<Expr> {
vec![
when(col("company").eq(lit(company)))
.then(lit("CASH USD"))
.otherwise(col("ticker"))
.alias("ticker"),
when(col("company").eq(lit(company)))
.then(lit("CASH USD"))
.otherwise(col("company"))
.alias("company"),
]
};
if let Ok(x) = df
.clone()
.lazy()
.with_columns(exprs("Cash & Cash Equivalents"))
.with_columns(exprs("GOLDMAN FS TRSY OBLIG INST 468"))
.with_columns(Self::get_expr(
"company",
"Cash & Cash Equivalents",
"CASH USD",
))
.with_columns(Self::get_expr(
"company",
"GOLDMAN FS TRSY OBLIG INST 468",
"CASH USD",
))
.with_columns(Self::get_expr("company", "Cash & Other", "CASH USD"))
.collect()
{
df = x;
@ -150,12 +161,12 @@ mod tests {
#[case::cash_usd(
Ticker::CASH_USD,
defualt_df(
&[None::<&str>, None::<&str>],
&[Some("Cash & Cash Equivalents"), Some("GOLDMAN FS TRSY OBLIG INST 468")],
&[None::<&str>, None::<&str>, Some("CASH&Other")],
&[Some("Cash & Cash Equivalents"), Some("GOLDMAN FS TRSY OBLIG INST 468"), Some("Cash & Other")],
)?,
defualt_df(
&[Some("CASH USD"), Some("CASH USD")],
&[Some("CASH USD"), Some("CASH USD")],
&[Some("CASH USD"), Some("CASH USD"), Some("CASH USD")],
&[Some("CASH USD"), Some("CASH USD"), Some("CASH USD")],
)?,
)]
fn matrix(