refactor: get_expr

This commit is contained in:
Elijah McMorris 2024-10-02 16:15:45 -07:00
parent 126005905b
commit 43da542d10
Signed by: NexVeridian
SSH key fingerprint: SHA256:bsA1SKZxuEcEVHAy3gY1HUeM5ykRJl0U0kQHQn0hMg8

View file

@ -1,4 +1,3 @@
use anyhow::{Error, Result}; use anyhow::{Error, Result};
use polars::prelude::*; use polars::prelude::*;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
@ -30,32 +29,48 @@ impl Ticker {
} }
} }
fn get_expr(target_col: &str, current: &str, new: &str) -> Vec<Expr> {
match target_col {
"company" => vec![
when(col(target_col).eq(lit(current)))
.then(lit(new))
.otherwise(col("ticker"))
.alias("ticker"),
when(col(target_col).eq(lit(current)))
.then(lit(new))
.otherwise(col("company"))
.alias("company"),
],
"ticker" => vec![
when(col(target_col).eq(lit(current)))
.then(lit(new))
.otherwise(col("company"))
.alias("company"),
when(col(target_col).eq(lit(current)))
.then(lit(new))
.otherwise(col("ticker"))
.alias("ticker"),
],
_ => panic!("Invalid target column"),
}
}
fn arkw(df: DF) -> Result<DF, Error> { fn arkw(df: DF) -> Result<DF, Error> {
let mut df = df.collect()?; let mut df = df.collect()?;
if let Ok(x) = df if let Ok(x) = df
.clone() .clone()
.lazy() .lazy()
.with_columns(vec![ .with_columns(Self::get_expr(
when(col("company").eq(lit("ARK BITCOIN ETF HOLDCO (ARKW)"))) "company",
.then(lit("ARKB")) "ARK BITCOIN ETF HOLDCO (ARKW)",
.otherwise(col("ticker")) "ARKB",
.alias("ticker"), ))
when(col("company").eq(lit("ARK BITCOIN ETF HOLDCO (ARKW)"))) .with_columns(Self::get_expr(
.then(lit("ARKB")) "company",
.otherwise(col("company")) "ARK BITCOIN ETF HOLDCO (ARKF)",
.alias("company"), "ARKB",
]) ))
.with_columns(vec![
when(col("company").eq(lit("ARK BITCOIN ETF HOLDCO (ARKF)")))
.then(lit("ARKB"))
.otherwise(col("ticker"))
.alias("ticker"),
when(col("company").eq(lit("ARK BITCOIN ETF HOLDCO (ARKF)")))
.then(lit("ARKB"))
.otherwise(col("company"))
.alias("company"),
])
.collect() .collect()
{ {
df = x; df = x;
@ -85,24 +100,20 @@ impl Ticker {
fn cash_usd(df: DF) -> Result<DF, Error> { fn cash_usd(df: DF) -> Result<DF, Error> {
let mut df = df.collect()?; let mut df = df.collect()?;
let exprs = |company: &str| -> Vec<Expr> {
vec![
when(col("company").eq(lit(company)))
.then(lit("CASH USD"))
.otherwise(col("ticker"))
.alias("ticker"),
when(col("company").eq(lit(company)))
.then(lit("CASH USD"))
.otherwise(col("company"))
.alias("company"),
]
};
if let Ok(x) = df if let Ok(x) = df
.clone() .clone()
.lazy() .lazy()
.with_columns(exprs("Cash & Cash Equivalents")) .with_columns(Self::get_expr(
.with_columns(exprs("GOLDMAN FS TRSY OBLIG INST 468")) "company",
"Cash & Cash Equivalents",
"CASH USD",
))
.with_columns(Self::get_expr(
"company",
"GOLDMAN FS TRSY OBLIG INST 468",
"CASH USD",
))
.with_columns(Self::get_expr("company", "Cash & Other", "CASH USD"))
.collect() .collect()
{ {
df = x; df = x;
@ -150,12 +161,12 @@ mod tests {
#[case::cash_usd( #[case::cash_usd(
Ticker::CASH_USD, Ticker::CASH_USD,
defualt_df( defualt_df(
&[None::<&str>, None::<&str>], &[None::<&str>, None::<&str>, Some("CASH&Other")],
&[Some("Cash & Cash Equivalents"), Some("GOLDMAN FS TRSY OBLIG INST 468")], &[Some("Cash & Cash Equivalents"), Some("GOLDMAN FS TRSY OBLIG INST 468"), Some("Cash & Other")],
)?, )?,
defualt_df( defualt_df(
&[Some("CASH USD"), Some("CASH USD")], &[Some("CASH USD"), Some("CASH USD"), Some("CASH USD")],
&[Some("CASH USD"), Some("CASH USD")], &[Some("CASH USD"), Some("CASH USD"), Some("CASH USD")],
)?, )?,
)] )]
fn matrix( fn matrix(