I have two dataframes:
df_rates                      df_trades
(rate from currency -> USD)
+--------+----------+----+    +---+--------+------+----------+
|currency| rate_date|rate|    | id|currency|amount|trade_date|
+--------+----------+----+    +---+--------+------+----------+
|     EUR|2025-01-09|1.19|    |  1|     EUR|  1000|2025-01-09|  # exact rate available
|     EUR|2025-01-08|1.18|    |  2|     CAD|  1000|2025-01-09|  # 1 day prior rate available
|     CAD|2025-01-08|0.78|    |  3|     AUD|  1000|2025-01-09|  # no applicable rate available
|     CAD|2025-01-07|0.77|    |  4|     HKD|  1000|2025-01-09|  # no rate available at all
|     AUD|2025-02-09|1.39|    +---+--------+------+----------+
|     AUD|2025-02-08|1.38|
+--------+----------+----+
For each trade I need to apply the appropriate rate to calculate usd_amount. The rate is selected as follows:
- look up the rate for the trade_date
- if no rate is available for the trade_date, go back up to 7 days
- if no rate is found that way, usd_amount = null
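For example, trade 2 (CAD, 2025-01-09) has no rate dated 2025-01-09, so the 2025-01-08 rate applies: usd_amount = 1000 * 0.78 = 780.0. Trade 3 gets null because the only AUD rates (2025-02-08, 2025-02-09) are dated after the trade.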
I have the following code, which works. But I am not sure it scales, particularly for the trade_id = 3 case (rates exist for the currency, but none in the right date range), because the real rates table has thousands of rates going back 5-7 years. The relevant code is marked This PART below.

Is there other logic that would achieve this more efficiently?
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, row_number, date_diff, when  # date_diff needs PySpark 3.5+; use datediff on older versions

def log_dataframe(df, msg):
    print(msg)
    df.show()

def calc_usd_amount(df_trades, df_rates):
    # Join every trade to every rate for its currency, then measure how many
    # days each rate precedes the trade date.
    df = (
        df_trades.join(df_rates, how='left_outer', on='currency')
        .withColumn('date_diff', date_diff('trade_date', 'rate_date'))
    )
    # A rate is unusable if it is dated after the trade or more than 7 days before it.
    date_diff_no_good = (col('date_diff') < 0) | (col('date_diff') > 7)
    # This PART
    df = (
        df.withColumns({
            'rate_date': when(date_diff_no_good, None).otherwise(col('rate_date')),
            'rate': when(date_diff_no_good, None).otherwise(col('rate')),
        })
        .drop_duplicates(['id', 'rate_date', 'rate'])
    )
    # Keep the most recent usable rate per trade; NULL rate_dates sort last
    # under desc, so a trade with no usable rate keeps a single NULL row.
    w_spec = row_number().over(Window.partitionBy(col('id'), col('currency')).orderBy(col('rate_date').desc()))
    df = (
        # >= keeps the full 7-day lookback, consistent with date_diff_no_good above
        df.filter('rate_date IS NULL OR (rate_date <= trade_date AND rate_date >= (trade_date - 7))')
        .withColumn('rate_row_num', w_spec)
        .filter('rate_row_num == 1')
        .withColumn('usd_amount', col('rate') * col('amount'))
    )
    return df.drop('date_diff', 'rate_row_num')
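
For reference, a minimal sketch of one alternative: push the 7-day window into the join condition itself, so out-of-range rates are discarded by the join instead of being nulled out and deduplicated afterwards. The name calc_usd_amount_range_join is illustrative only, not part of the code above:

from pyspark.sql import Window, functions as F

def calc_usd_amount_range_join(df_trades, df_rates):
    t, r = df_trades.alias('t'), df_rates.alias('r')
    # Non-equi join: only rates inside [trade_date - 7d, trade_date] survive it.
    joined = t.join(
        r,
        on=(
            (F.col('t.currency') == F.col('r.currency'))
            & (F.col('r.rate_date') <= F.col('t.trade_date'))
            & (F.col('r.rate_date') >= F.date_sub(F.col('t.trade_date'), 7))
        ),
        how='left_outer',
    )
    # Most recent surviving rate per trade; a trade with no match keeps its
    # single all-NULL row, so usd_amount stays NULL.
    w = Window.partitionBy(F.col('t.id')).orderBy(F.col('r.rate_date').desc())
    return (
        joined.withColumn('rn', F.row_number().over(w))
        .filter('rn = 1')
        .withColumn('usd_amount', F.col('r.rate') * F.col('t.amount'))
        .select('t.id', 't.currency', 't.amount', 't.trade_date',
                'r.rate_date', 'r.rate', 'usd_amount')
    )

Because the join still has an equality on currency, Spark can use it as the join key and apply the date bounds as extra join filters, so far fewer rows should reach the window step.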
from datetime import date

spark = SparkSession.builder.getOrCreate()
dt = date.fromisoformat

df_trades = spark.createDataFrame(
    data=[
        (1, 'EUR', 1000, dt('2025-01-09')),  # trade-date rate available
        (2, 'CAD', 1000, dt('2025-01-09')),  # trade date -1d, rate available
        (3, 'AUD', 1000, dt('2025-01-09')),  # no applicable rate available
        (4, 'HKD', 1000, dt('2025-01-09')),  # no rate available at all
    ],
    schema=['id', 'currency', 'amount', 'trade_date'],
)

df_rates = spark.createDataFrame(
    data=[
        ('EUR', dt('2025-01-09'), 1.19),  # trade-date rate available
        ('EUR', dt('2025-01-08'), 1.18),
        ('CAD', dt('2025-01-08'), 0.78),  # trade date -1d, rate available
        ('CAD', dt('2025-01-07'), 0.77),
        ('AUD', dt('2025-02-09'), 1.39),  # no applicable rate available
        ('AUD', dt('2025-02-08'), 1.38),
    ],
    schema=['currency', 'rate_date', 'rate'],
)
df_out = calc_usd_amount(df_trades, df_rates)
log_dataframe(df_out, 'df_out')
Prints:
df_out
+--------+---+------+----------+----------+----+----------+
|currency| id|amount|trade_date| rate_date|rate|usd_amount|
+--------+---+------+----------+----------+----+----------+
|     EUR|  1|  1000|2025-01-09|2025-01-09|1.19|    1190.0|
|     CAD|  2|  1000|2025-01-09|2025-01-08|0.78|     780.0|
|     AUD|  3|  1000|2025-01-09|      NULL|NULL|      NULL|
|     HKD|  4|  1000|2025-01-09|      NULL|NULL|      NULL|
+--------+---+------+----------+----------+----+----------+
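
A second direction, sketched under the assumption of PySpark 3.3+ (for max_by): replace the drop_duplicates + window step with a plain aggregation that picks the latest in-window rate per trade. pick_latest_rate is an illustrative name and takes the joined frame from calc_usd_amount after out-of-window rates have been nulled:

from pyspark.sql import functions as F

def pick_latest_rate(df):
    # One output row per trade: the rate on the latest in-window rate_date,
    # or NULL when every rate_date in the group is NULL.
    return (
        df.groupBy('id', 'currency', 'amount', 'trade_date')
        .agg(
            F.max('rate_date').alias('rate_date'),
            F.max_by('rate', 'rate_date').alias('rate'),
        )
        .withColumn('usd_amount', F.col('rate') * F.col('amount'))
    )

Also, since even 5-7 years of daily rates per currency is only a few thousand rows, wrapping the rates side in F.broadcast(df_rates) at the join should avoid shuffling the trades at all; df.explain() shows whether Spark already picks a broadcast join on its own.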