diff --git a/docs/source/c_fx_smile.rst b/docs/source/c_fx_smile.rst index af68ebc2..d01d19d1 100644 --- a/docs/source/c_fx_smile.rst +++ b/docs/source/c_fx_smile.rst @@ -404,8 +404,8 @@ this produces minor deviations from his calculated values. for op in ops: op.rate(fx=fxf) - strikes = [float(_._pricing["k"]) for _ in ops] - vols = [float(_._pricing["vol"]) for _ in ops] + strikes = [float(_._pricing.k) for _ in ops] + vols = [float(_._pricing.vol) for _ in ops] data2 = DataFrame( data=[strikes[0:3], vols[0:3], strikes[3:6], vols[3:6], strikes[6:9], vols[6:9]], index=[("1y", "k"), ("1y", "vol"), ("18m", "k"), ("18m", "vol"), ("2y", "k"), ("2y", "vol")], diff --git a/docs/source/i_whatsnew.rst b/docs/source/i_whatsnew.rst index 6099af8a..fc0a594f 100644 --- a/docs/source/i_whatsnew.rst +++ b/docs/source/i_whatsnew.rst @@ -101,6 +101,10 @@ email contact, see `rateslib `_. *FXRates* and *FXForwards* to allow auto-mutate detection of associated objects and ensure consistent method results. (`570 `_) + * - Refactor + - The internal data objects for *FXOption* pricing are restructured to conform to more + strict data typing. + (`642 `_) 1.6.0 (30th November 2024) **************************** diff --git a/python/rateslib/instruments/fx_volatility/strategies.py b/python/rateslib/instruments/fx_volatility/strategies.py index 175da5bb..992a167f 100644 --- a/python/rateslib/instruments/fx_volatility/strategies.py +++ b/python/rateslib/instruments/fx_volatility/strategies.py @@ -40,7 +40,7 @@ class FXOptionStrat: The multiplier for the *'vol'* metric that sums the options to a final *rate*. """ - _pricing: dict[str, Any] + _greeks: dict[str, Any] = {} _strat_elements: tuple[FXOption | FXOptionStrat, ...] periods: list[FXOption] @@ -733,7 +733,7 @@ def d_wrt_sigma1(period_index, greeks, smile_greeks, vol, eta1): iters += 1 if record_greeks: # this needs to be explicitly called since it degrades performance - self._pricing["strangle_greeks"] = { + self._greeks["strangle"] = { "single_vol": { "FXPut": self.periods[0].analytic_greeks(curves, solver, fx, base, vol=tgt_vol), "FXCall": self.periods[1].analytic_greeks( @@ -900,16 +900,16 @@ def _maybe_set_vega_neutral_notional(self, curves, solver, fx, base, vol, metric metric="single_vol", record_greeks=True, ) - self._pricing["straddle_greeks"] = self.periods[1].analytic_greeks( + self._greeks["straddle"] = self.periods[1].analytic_greeks( curves, solver, fx, base, vol=vol[1], ) - strangle_vega = self._pricing["strangle_greeks"]["market_vol"]["FXPut"]["vega"] - strangle_vega += self._pricing["strangle_greeks"]["market_vol"]["FXCall"]["vega"] - straddle_vega = self._pricing["straddle_greeks"]["vega"] + strangle_vega = self._greeks["strangle"]["market_vol"]["FXPut"]["vega"] + strangle_vega += self._greeks["strangle"]["market_vol"]["FXCall"]["vega"] + straddle_vega = self._greeks["straddle"]["vega"] scalar = strangle_vega / straddle_vega self.periods[1].kwargs["notional"] = float( self.periods[0].periods[0].periods[0].notional * -scalar, diff --git a/python/rateslib/instruments/fx_volatility/vanilla.py b/python/rateslib/instruments/fx_volatility/vanilla.py index 91ae2f84..edcf4c9f 100644 --- a/python/rateslib/instruments/fx_volatility/vanilla.py +++ b/python/rateslib/instruments/fx_volatility/vanilla.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABCMeta +from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING from pandas import DataFrame @@ -10,7 +11,7 @@ from rateslib.calendars import _get_fx_expiry_and_delivery, get_calendar from rateslib.curves import Curve from rateslib.curves._parsers import _validate_obj_not_no_input -from rateslib.default import NoInput, _drb, plot +from rateslib.default import NoInput, PlotOutput, _drb, plot from rateslib.dual.utils import _dual_float from rateslib.fx_volatility import FXVolObj from rateslib.instruments.sensitivities import Sensitivities @@ -23,8 +24,11 @@ from rateslib.periods import Cashflow, FXCallPeriod, FXPutPeriod if TYPE_CHECKING: + import numpy as np + from rateslib.typing import ( FX_, + NPV, Any, CalInput, Curves_, @@ -32,19 +36,20 @@ DualTypes, DualTypes_, FXVol_, - FXVolOption, FXVolOption_, Solver_, bool_, datetime_, + float_, int_, str_, ) -class PricingMetrics(NamedTuple): - vol: FXVolOption - k: DualTypes_ +@dataclass +class _PricingMetrics: + vol: FXVolOption_ + k: DualTypes delta_index: DualTypes | None spot: datetime t_e: DualTypes @@ -141,7 +146,7 @@ class FXOption(Sensitivities, metaclass=ABCMeta): style: str = "european" _rate_scalar: float = 1.0 - _pricing: dict[str, Any] = {} + _pricing: _PricingMetrics _option_periods: tuple[FXPutPeriod | FXCallPeriod] _cashflow_periods: tuple[Cashflow] @@ -279,45 +284,47 @@ def _set_strike_and_vol( # and some of the pricing elements associated with this strike definition must # be captured for use in subsequent formulae. fx_ = _validate_fx_as_forwards(fx) - vol_: FXVolOption = _validate_obj_not_no_input(vol, "vol") # type: ignore[assignment] + # vol_: FXVolOption = _validate_obj_not_no_input(vol, "vol") # type: ignore[assignment] + vol_ = vol curves_3: Curve = _validate_obj_not_no_input(curves[3], "curves[3]") curves_1: Curve = _validate_obj_not_no_input(curves[1], "curves[1]") - self._pricing = { - "vol": vol_, - "k": self.kwargs["strike"], - "delta_index": None, - "spot": fx_.pairs_settlement[self.kwargs["pair"]], - "t_e": self._option_periods[0]._t_to_expiry(curves_3.node_dates[0]), - "f_d": fx_.rate(self.kwargs["pair"], self.kwargs["delivery"]), - } + self._pricing = _PricingMetrics( + vol=vol_, + k=self.kwargs["strike"], + delta_index=None, + spot=fx_.pairs_settlement[self.kwargs["pair"]], + t_e=self._option_periods[0]._t_to_expiry(curves_3.node_dates[0]), + f_d=fx_.rate(self.kwargs["pair"], self.kwargs["delivery"]), + ) + w_deli = curves_1[self.kwargs["delivery"]] - w_spot = curves_1[self._pricing["spot"]] + w_spot = curves_1[self._pricing.spot] if isinstance(self.kwargs["strike"], str): method = self.kwargs["strike"].lower() if method == "atm_forward": - self._pricing["k"] = fx_.rate(self.kwargs["pair"], self.kwargs["delivery"]) + self._pricing.k = fx_.rate(self.kwargs["pair"], self.kwargs["delivery"]) elif method == "atm_spot": - self._pricing["k"] = fx_.rate(self.kwargs["pair"], self._pricing["spot"]) + self._pricing.k = fx_.rate(self.kwargs["pair"], self._pricing.spot) elif method == "atm_delta": - self._pricing["k"], self._pricing["delta_index"] = self._option_periods[ + self._pricing.k, self._pricing.delta_index = self._option_periods[ 0 ]._strike_and_index_from_atm( delta_type=self._option_periods[0].delta_type, - vol=vol_, + vol=_validate_obj_not_no_input(vol_, "vol"), # type: ignore[arg-type] w_deli=w_deli, w_spot=w_spot, - f=self._pricing["f_d"], - t_e=self._pricing["t_e"], + f=self._pricing.f_d, + t_e=self._pricing.t_e, ) elif method[-1] == "d": # representing delta # then strike is commanded by delta - self._pricing["k"], self._pricing["delta_index"] = self._option_periods[ + self._pricing.k, self._pricing.delta_index = self._option_periods[ 0 ]._strike_and_index_from_delta( delta=float(self.kwargs["strike"][:-1]) / 100.0, @@ -325,8 +332,8 @@ def _set_strike_and_vol( vol=vol_, w_deli=w_deli, w_spot=w_spot, - f=self._pricing["f_d"], - t_e=self._pricing["t_e"], + f=self._pricing.f_d, + t_e=self._pricing.t_e, ) # TODO: this may affect solvers dependent upon sensitivity to vol for changing strikes. @@ -335,20 +342,20 @@ def _set_strike_and_vol( # IRS for mid-market. # self.periods[0].strike = self._pricing["k"] - self._option_periods[0].strike = _dual_float(self._pricing["k"]) + self._option_periods[0].strike = _dual_float(self._pricing.k) if isinstance(vol_, FXVolObj): - if self._pricing["delta_index"] is None: - self._pricing["delta_index"], self._pricing["vol"], _ = vol_.get_from_strike( - k=self._pricing["k"], - f=self._pricing["f_d"], + if self._pricing.delta_index is None: + self._pricing.delta_index, self._pricing.vol, _ = vol_.get_from_strike( + k=self._pricing.k, + f=self._pricing.f_d, w_deli=w_deli, w_spot=w_spot, expiry=self.kwargs["expiry"], ) else: - self._pricing["vol"] = vol_._get_index( - self._pricing["delta_index"], + self._pricing.vol = vol_._get_index( + self._pricing.delta_index, self.kwargs["expiry"], ) @@ -361,7 +368,7 @@ def _set_premium(self, curves: Curves_DiscTuple, fx: FX_ = NoInput(0)) -> None: _validate_obj_not_no_input(curves[1], "curves[1]"), curves_3, fx, - vol=self._pricing["vol"], + vol=self._pricing.vol, local=False, ) except AttributeError: @@ -451,10 +458,15 @@ def rate( metric = _drb(self.kwargs["metric"], metric) if metric in ["vol", "single_vol"]: - return self._pricing["vol"] - - _: DualTypes = self.periods[0].rate( - curves_[1], curves_[3], fx_, NoInput(0), False, self._pricing["vol"] + return _validate_obj_not_no_input(self._pricing.vol, "vol") # type: ignore[return-value] + + _: DualTypes = self._option_periods[0].rate( + disc_curve=_validate_obj_not_no_input(curves_[1], "curve"), + disc_curve_ccy2=_validate_obj_not_no_input(curves_[3], "curve"), + fx=fx_, + base=NoInput(0), + local=False, + vol=self._pricing.vol, ) if metric == "premium": if self.periods[0].metric == "pips": @@ -470,8 +482,8 @@ def npv( fx: FX_ = NoInput(0), base: str_ = NoInput(0), local: bool = False, - vol: float = NoInput(0), - ): + vol: FXVol_ = NoInput(0), + ) -> NPV: """ Return the NPV of the *Option*. @@ -510,16 +522,23 @@ def npv( self._set_strike_and_vol(curves_, fx_, vol_) self._set_premium(curves_, fx_) - opt_npv = self._option_periods[0].npv(curves_[1], curves_[3], fx_, base_, local, vol_) + opt_npv = self._option_periods[0].npv( + disc_curve=_validate_obj_not_no_input(curves_[1], "curve_[1]"), + disc_curve_ccy2=_validate_obj_not_no_input(curves_[3], "curve_[3]"), + fx=fx_, + base=base_, + local=local, + vol=vol_, + ) if self.kwargs["premium_ccy"] == self.kwargs["pair"][:3]: disc_curve = curves_[1] else: disc_curve = curves_[3] prem_npv = self._cashflow_periods[0].npv(NoInput(0), disc_curve, fx, base, local) if local: - return {k: opt_npv.get(k, 0) + prem_npv.get(k, 0) for k in set(opt_npv) | set(prem_npv)} + return {k: opt_npv.get(k, 0) + prem_npv.get(k, 0) for k in set(opt_npv) | set(prem_npv)} # type:ignore[union-attr, arg-type] else: - return opt_npv + prem_npv + return opt_npv + prem_npv # type: ignore[operator] def cashflows( self, @@ -527,8 +546,8 @@ def cashflows( solver: Solver_ = NoInput(0), fx: FX_ = NoInput(0), base: str_ = NoInput(0), - vol: float = NoInput(0), - ): + vol: FXVol_ = NoInput(0), + ) -> DataFrame: """ Return the properties of all periods used in calculating cashflows. @@ -567,8 +586,14 @@ def cashflows( self._set_premium(curves_, fx_) seq = [ - self.periods[0].cashflows(curves_[1], curves_[3], fx_, base_, vol=vol_), - self.periods[1].cashflows(curves_[1], curves_[3], fx_, base_), + self._option_periods[0].cashflows( + disc_curve=_validate_obj_not_no_input(curves_[1], "curves_[1]"), + disc_curve_ccy2=_validate_obj_not_no_input(curves_[3], "curves_[3]"), + fx=fx_, + base=base_, + vol=vol_, + ), + self._cashflow_periods[0].cashflows(curves_[1], curves_[3], fx_, base_), ] return DataFrame.from_records(seq) @@ -579,8 +604,8 @@ def analytic_greeks( fx: FX_ = NoInput(0), base: str_ = NoInput(0), local: bool = False, - vol: float = NoInput(0), - ): + vol: FXVol_ = NoInput(0), + ) -> dict[str, Any]: """ Return various pricing metrics of the *FX Option*. @@ -620,12 +645,12 @@ def analytic_greeks( # self._set_premium(curves, fx) return self._option_periods[0].analytic_greeks( - curves_[1], - curves_[3], - fx_, - base_, - local, - vol_, + disc_curve=_validate_obj_not_no_input(curves_[1], "curves_[1]"), + disc_curve_ccy2=_validate_obj_not_no_input(curves_[3], "curves_[3]"), + fx=_validate_fx_as_forwards(fx_), + base=base_, + local=local, + vol=vol_, premium=NoInput(0), ) @@ -637,8 +662,10 @@ def _plot_payoff( fx: FX_ = NoInput(0), base: str_ = NoInput(0), local: bool = False, - vol: float = NoInput(0), - ): + vol: FXVol_ = NoInput(0), + ) -> tuple[ + np.ndarray[tuple[int], np.dtype[np.float64]], np.ndarray[tuple[int], np.dtype[np.float64]] + ]: """ Mechanics to determine (x,y) coordinates for payoff at expiry plot. """ @@ -666,10 +693,10 @@ def plot_payoff( fx: FX_ = NoInput(0), base: str_ = NoInput(0), local: bool = False, - vol: float = NoInput(0), - ): + vol: float_ = NoInput(0), + ) -> PlotOutput: x, y = self._plot_payoff(range, curves, solver, fx, base, local, vol) - return plot(x, [y]) + return plot(x, [y]) # type: ignore class FXCall(FXOption): @@ -679,7 +706,7 @@ class FXCall(FXOption): For parameters see :class:`~rateslib.instruments.FXOption`. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._option_periods = ( FXCallPeriod( @@ -705,7 +732,7 @@ class FXPut(FXOption): For parameters see :class:`~rateslib.instruments.FXOption`. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._option_periods = ( FXPutPeriod( diff --git a/python/rateslib/periods.py b/python/rateslib/periods.py index 051d75aa..bc5b6c46 100644 --- a/python/rateslib/periods.py +++ b/python/rateslib/periods.py @@ -74,6 +74,7 @@ Curve_, CurveOption_, DualTypes, + DualTypes_, FXVolOption, FXVolOption_, Number, @@ -3334,10 +3335,10 @@ def cashflows( self, disc_curve: Curve, disc_curve_ccy2: Curve, - fx: float | FXRates | FXForwards | NoInput = NoInput(0), - base: str | NoInput = NoInput(0), + fx: FX_ = NoInput(0), + base: str_ = NoInput(0), local: bool = False, - vol: DualTypes | FXVols | NoInput = NoInput(0), + vol: FXVolOption_ = NoInput(0), ) -> dict[str, Any]: """ Return the properties of the period used in calculating cashflows. @@ -3473,11 +3474,11 @@ def rate( self, disc_curve: Curve, disc_curve_ccy2: Curve, - fx: float | FXRates | FXForwards | NoInput = NoInput(0), - base: str | NoInput = NoInput(0), + fx: FX_ = NoInput(0), + base: str_ = NoInput(0), local: bool = False, - vol: DualTypes | FXVols | NoInput = NoInput(0), - metric: str | NoInput = NoInput(0), + vol: FXVolOption_ = NoInput(0), + metric: str_ = NoInput(0), ) -> DualTypes: """ Return the pricing metric of the *FXOption*. @@ -3597,10 +3598,10 @@ def analytic_greeks( disc_curve: Curve, disc_curve_ccy2: Curve, fx: FXForwards, - base: str | NoInput = NoInput(0), + base: str_ = NoInput(0), local: bool = False, - vol: DualTypes | FXVols | NoInput = NoInput(0), - premium: DualTypes | NoInput = NoInput(0), # expressed in the payment currency + vol: FXVolOption_ = NoInput(0), + premium: DualTypes_ = NoInput(0), # expressed in the payment currency ) -> dict[str, Any]: r""" Return the different greeks for the *FX Option*. @@ -3890,7 +3891,7 @@ def _strike_and_index_from_atm( w_deli: DualTypes, w_spot: DualTypes, f: DualTypes, - t_e: float, + t_e: DualTypes, ) -> tuple[DualTypes, DualTypes | None]: # TODO this method branches depending upon eta0 and eta1, but depending upon the # type of vol these maybe automatcially set equal to each other. Refactorin this would @@ -3951,7 +3952,7 @@ def _strike_and_index_from_delta( w_deli: DualTypes, w_spot: DualTypes, f: DualTypes, - t_e: float, + t_e: DualTypes, ) -> tuple[DualTypes, DualTypes | None]: vol_delta_type = _get_vol_delta_type(vol, delta_type) @@ -3993,7 +3994,7 @@ def _strike_and_index_from_delta( _2: DualTypes | None = delta_idx return _1, _2 - def _moneyness_from_atm_delta_closed_form(self, vol: DualTypes, t_e: float) -> DualTypes: + def _moneyness_from_atm_delta_closed_form(self, vol: DualTypes, t_e: DualTypes) -> DualTypes: """ Return `u` given premium unadjusted `delta`, of either 'spot' or 'forward' type. @@ -4291,7 +4292,7 @@ def _moneyness_from_atm_delta_two_dimensional( self, delta_type: str, vol: FXDeltaVolSmile, - t_e: float, + t_e: DualTypes, z_w: DualTypes, ) -> tuple[DualTypes, DualTypes]: def root2d( @@ -4454,7 +4455,7 @@ def root3d( def _get_vol_maybe_from_obj( self, - vol: FXVols | DualTypes | NoInput, + vol: FXVolOption_, fx: FXForwards, disc_curve: Curve, ) -> DualTypes: diff --git a/python/rateslib/typing.py b/python/rateslib/typing.py index c3075a70..dec53fc1 100644 --- a/python/rateslib/typing.py +++ b/python/rateslib/typing.py @@ -98,6 +98,7 @@ bool_: TypeAlias = "bool | NoInput" int_: TypeAlias = "int | NoInput" datetime_: TypeAlias = "datetime | NoInput" +float_: TypeAlias = "float | NoInput" from rateslib.curves import Curve as Curve # noqa: E402 diff --git a/python/tests/test_instruments.py b/python/tests/test_instruments.py index 583210d8..e4d79083 100644 --- a/python/tests/test_instruments.py +++ b/python/tests/test_instruments.py @@ -4785,6 +4785,21 @@ def test_expiry_delivery_tenor_eom(self, evald, eom, expected) -> None: ) assert fxo.kwargs["expiry"] == expected + def test_single_vol_not_no_input(self, fxfo): + fxo = FXCall( + pair="eurusd", + expiry=dt(2023, 6, 16), + delivery_lag=dt(2023, 6, 20), + payment_lag=dt(2023, 6, 20), + curves=[None, fxfo.curve("eur", "eur"), None, fxfo.curve("usd", "eur")], + delta_type="forward", + premium_ccy="usd", + strike=1.1, + notional=1e6, + ) + with pytest.raises(ValueError, match="`vol` must be supplied. Got"): + fxo.rate(metric="vol", fx=fxfo) + class TestRiskReversal: @pytest.mark.parametrize(