Skip to content

Commit

Permalink
REF: amend cache_states to decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
attack68 committed Jan 28, 2025
1 parent 7690b71 commit 577dafc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 49 deletions.
3 changes: 2 additions & 1 deletion python/rateslib/curves/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_new_state_post,
_validate_states,
_WithState,
_WithCache,
)
from rateslib.rs import Modifier, index_left_f64
from rateslib.rs import from_json as from_json_rs
Expand All @@ -56,7 +57,7 @@
# Contact rateslib at gmail.com if this code is observed outside its intended sphere.


class Curve(_WithState):
class Curve(_WithState, _WithCache):
"""
Curve based on DF parametrisation at given node dates with interpolation.
Expand Down
49 changes: 8 additions & 41 deletions python/rateslib/fx_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_clear_cache_post,
_new_state_post,
_WithState,
_WithCache,
)
from rateslib.rs import index_left_f64
from rateslib.splines import PPSplineDual, PPSplineDual2, PPSplineF64, evaluate
Expand All @@ -45,7 +46,7 @@
TERMINAL_DATE = dt(2100, 1, 1)


class FXDeltaVolSmile(_WithState):
class FXDeltaVolSmile(_WithState, _WithCache):
r"""
Create an *FX Volatility Smile* at a given expiry indexed by delta percent.
Expand Down Expand Up @@ -562,26 +563,6 @@ def plot(
return plot(x_as_u, y, labels)
return plot(x, y, labels)

# Cache management

def _clear_cache(self) -> None:
"""
Clear the cache of values on a *Smile* type.
Returns
-------
None
Notes
-----
This should be used if any modification has been made to the *Smile*.
Users are advised against making direct modification to *Curve* classes once
constructed to avoid the issue of un-cleared caches returning erroneous values.
Alternatively the curve caching as a feature can be set to *False* in ``defaults``.
"""
self._cache: dict[float, DualTypes] = dict()

# Mutation

def __set_nodes__(self, nodes: dict[float, DualTypes], ad: int) -> None:
Expand Down Expand Up @@ -793,7 +774,7 @@ def update_node(self, key: float, value: DualTypes) -> None:
# Serialization


class FXDeltaVolSurface(_WithState):
class FXDeltaVolSurface(_WithState, _WithCache):
r"""
Create an *FX Volatility Surface* parametrised by cross-sectional *Smiles* at different
expiries.
Expand Down Expand Up @@ -895,24 +876,9 @@ def __init__(

self._set_ad_order(ad) # includes csolve on each smile

@_new_state_post
def _clear_cache(self) -> None:
"""
Clear the cache of cross-sectional *Smiles* on a *Surface* type.
Returns
-------
None
Notes
-----
This should be used if any modification has been made to the *Surface*.
Users are advised against making direct modification to *Surface* classes once
constructed to avoid the issue of un-cleared caches returning erroneous values.
Alternatively set ``defaults.curve_caching`` to *False* to turn off global
caching in general.
"""
self._cache: dict[datetime, FXDeltaVolSmile] = dict()
self._set_new_state()
super()._clear_cache()

def _get_composited_state(self) -> int:
return hash(smile._state for smile in self.smiles)
Expand All @@ -926,20 +892,21 @@ def _maybe_add_to_cache(self, date: datetime, val: FXDeltaVolSmile) -> None:
if defaults.curve_caching:
self._cache[date] = val

@_clear_cache_post
def _set_ad_order(self, order: int) -> None:
self.ad = order
for smile in self.smiles:
smile._set_ad_order(order)
self._clear_cache()

@_new_state_post
@_clear_cache_post
def _set_node_vector(
self, vector: np.ndarray[tuple[int, ...], np.dtype[np.object_]], ad: int
) -> None:
m = len(self.delta_indexes)
for i in range(int(len(vector) / m)):
# smiles are indexed by expiry, shortest first
self.smiles[i]._set_node_vector(vector[i * m : i * m + m], ad)
self._clear_cache()

def _get_node_vector(self) -> np.ndarray[tuple[int, ...], np.dtype[np.object_]]:
"""Get a 1d array of variables associated with nodes of this object updated by Solver"""
Expand Down
14 changes: 7 additions & 7 deletions python/rateslib/mutability/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@

from __future__ import annotations

import os
from collections import OrderedDict
from collections.abc import Callable
from typing import ParamSpec, TypeVar
from typing import ParamSpec, TypeVar, Generic

from rateslib import defaults

P = ParamSpec("P")
R = TypeVar("R")
KT = TypeVar("KT")
VT = TypeVar("VT")


def _validate_states(func: Callable[P, R]) -> Callable[P, R]:
Expand All @@ -38,7 +35,7 @@ def _clear_cache_post(func: Callable[P, R]) -> Callable[P, R]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
self = args[0]
result = func(*args, **kwargs)
self._clear_cache() # type: ignore[attr-defined]
self._clear_cache() # type: ignore[attr-defined]
return result

return wrapper
Expand All @@ -59,7 +56,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return wrapper


class _WithState:
class _WithState[KT, VT]:
"""
Record and manage the `state_id` of mutable classes.
Expand Down Expand Up @@ -95,7 +92,10 @@ def _get_composited_state(self) -> int:
objects and set this as the object's own state."""
raise NotImplementedError("Must be implemented for 'mutable by association' types")

def _cached_value(self, key: KT , val: VT) -> VT:

class _WithCache[KT, VT]:

def _cached_value(self, key: KT, val: VT) -> VT:
"""Used to add a value to the cache and control memory size when returning some
parameter from an object using cache and state management."""
if defaults.curve_caching and key not in self._cache:
Expand Down

0 comments on commit 577dafc

Please sign in to comment.