From aa9ed29586032b32a865fcde1051aa9ac7baf9bd Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Thu, 16 Jan 2025 15:22:04 -0800 Subject: [PATCH] Add to_dict method to models for serialization Closes #26 --- minfraud/models.py | 75 +++++++++++++++++++++++++++++++------------- tests/test_models.py | 54 ++++++++++++++++++++++++++----- 2 files changed, 100 insertions(+), 29 deletions(-) diff --git a/minfraud/models.py b/minfraud/models.py index ce6c2b9..3d6efe8 100644 --- a/minfraud/models.py +++ b/minfraud/models.py @@ -14,7 +14,31 @@ import geoip2.records -class IPRiskReason(SimpleEquality): +class _Serializable(SimpleEquality): + def to_dict(self): + """Returns a dict of the object suitable for serialization""" + result = {} + for key, value in self.__dict__.items(): + if hasattr(value, "to_dict") and callable(value.to_dict): + result[key] = value.to_dict() + elif hasattr(value, "raw"): + # geoip2 uses "raw" for historical reasons + result[key] = value.raw + elif isinstance(value, list): + result[key] = [ + ( + item.to_dict() + if hasattr(item, "to_dict") and callable(item.to_dict) + else item + ) + for item in value + ] + else: + result[key] = value + return result + + +class IPRiskReason(_Serializable): """Reason for the IP risk. This class provides both a machine-readable code and a human-readable @@ -202,7 +226,7 @@ class IPAddress(geoip2.models.Insights): def __init__( self, - locales: Sequence[str], + locales: Optional[Sequence[str]], *, country: Optional[Dict] = None, location: Optional[Dict] = None, @@ -210,15 +234,24 @@ def __init__( risk_reasons: Optional[List[Dict]] = None, **kwargs, ) -> None: - - super().__init__(kwargs, locales=list(locales)) + # For raw attribute + if country is not None: + kwargs["country"] = country + if location is not None: + kwargs["location"] = location + if risk is not None: + kwargs["risk"] = risk + if risk_reasons is not None: + kwargs["risk_reasons"] = risk_reasons + + super().__init__(kwargs, locales=list(locales or [])) self.country = GeoIP2Country(locales, **(country or {})) self.location = GeoIP2Location(**(location or {})) self.risk = risk self.risk_reasons = [IPRiskReason(**x) for x in risk_reasons or []] -class ScoreIPAddress(SimpleEquality): +class ScoreIPAddress(_Serializable): """Information about the IP address for minFraud Score. .. attribute:: risk @@ -235,7 +268,7 @@ def __init__(self, *, risk: Optional[float] = None, **_): self.risk = risk -class Issuer(SimpleEquality): +class Issuer(_Serializable): """Information about the credit card issuer. .. attribute:: name @@ -293,7 +326,7 @@ def __init__( self.matches_provided_phone_number = matches_provided_phone_number -class Device(SimpleEquality): +class Device(_Serializable): """Information about the device associated with the IP address. In order to receive device output from minFraud Insights or minFraud @@ -353,7 +386,7 @@ def __init__( self.local_time = local_time -class Disposition(SimpleEquality): +class Disposition(_Serializable): """Information about disposition for the request as set by custom rules. In order to receive a disposition, you must be use the minFraud custom @@ -402,7 +435,7 @@ def __init__( self.rule_label = rule_label -class EmailDomain(SimpleEquality): +class EmailDomain(_Serializable): """Information about the email domain passed in the request. .. attribute:: first_seen @@ -421,7 +454,7 @@ def __init__(self, *, first_seen: Optional[str] = None, **_): self.first_seen = first_seen -class Email(SimpleEquality): +class Email(_Serializable): """Information about the email address passed in the request. .. attribute:: domain @@ -484,7 +517,7 @@ def __init__( self.is_high_risk = is_high_risk -class CreditCard(SimpleEquality): +class CreditCard(_Serializable): """Information about the credit card based on the issuer ID number. .. attribute:: country @@ -578,7 +611,7 @@ def __init__( self.type = type -class BillingAddress(SimpleEquality): +class BillingAddress(_Serializable): """Information about the billing address. .. attribute:: distance_to_ip_location @@ -644,7 +677,7 @@ def __init__( self.is_in_ip_country = is_in_ip_country -class ShippingAddress(SimpleEquality): +class ShippingAddress(_Serializable): """Information about the shipping address. .. attribute:: distance_to_ip_location @@ -733,7 +766,7 @@ def __init__( self.distance_to_billing_address = distance_to_billing_address -class Phone(SimpleEquality): +class Phone(_Serializable): """Information about the billing or shipping phone number. .. attribute:: country @@ -790,7 +823,7 @@ def __init__( self.number_type = number_type -class ServiceWarning(SimpleEquality): +class ServiceWarning(_Serializable): """Warning from the web service. .. attribute:: code @@ -837,7 +870,7 @@ def __init__( self.input_pointer = input_pointer -class Subscores(SimpleEquality): +class Subscores(_Serializable): """Risk factor scores used in calculating the overall risk score. .. deprecated:: 2.12.0 @@ -1081,7 +1114,7 @@ def __init__( self.time_of_day = time_of_day -class Reason(SimpleEquality): +class Reason(_Serializable): """The risk score reason for the multiplier. This class provides both a machine-readable code and a human-readable @@ -1174,7 +1207,7 @@ def __init__( self.reason = reason -class RiskScoreReason(SimpleEquality): +class RiskScoreReason(_Serializable): """The risk score multiplier and the reasons for that multiplier. .. attribute:: multiplier @@ -1209,7 +1242,7 @@ def __init__( self.reasons = [Reason(**x) for x in reasons or []] -class Factors(SimpleEquality): +class Factors(_Serializable): """Model for Factors response. .. attribute:: id @@ -1397,7 +1430,7 @@ def __init__( ] -class Insights(SimpleEquality): +class Insights(_Serializable): """Model for Insights response. .. attribute:: id @@ -1557,7 +1590,7 @@ def __init__( self.warnings = [ServiceWarning(**x) for x in warnings or []] -class Score(SimpleEquality): +class Score(_Serializable): """Model for Score response. .. attribute:: id diff --git a/tests/test_models.py b/tests/test_models.py index 83e0b10..b2792e2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,6 +4,9 @@ class TestModels(unittest.TestCase): + def setUp(self): + self.maxDiff = 20_000 + def test_billing_address(self): address = BillingAddress(**self.address_dict) self.check_address(address) @@ -261,14 +264,15 @@ def test_risk_score_reason(self): def test_score(self): id = "b643d445-18b2-4b9d-bad4-c9c4366e402a" - score = Score( - id=id, - funds_remaining=10.01, - queries_remaining=123, - risk_score=0.01, - ip_address={"risk": 99}, - warnings=[{"code": "INVALID_INPUT"}], - ) + response = { + "id": id, + "funds_remaining": 10.01, + "queries_remaining": 123, + "risk_score": 0.01, + "ip_address": {"risk": 99}, + "warnings": [{"code": "INVALID_INPUT"}], + } + score = Score(**response) self.assertEqual(id, score.id) self.assertEqual(10.01, score.funds_remaining) @@ -277,11 +281,15 @@ def test_score(self): self.assertEqual("INVALID_INPUT", score.warnings[0].code) self.assertEqual(99, score.ip_address.risk) + self.assertEqual(response, self._remove_empty_values(score.to_dict())) + def test_insights(self): response = self.factors_response() + del response["risk_score_reasons"] del response["subscores"] insights = Insights(None, **response) self.check_insights_data(insights, response["id"]) + self.assertEqual(response, self._remove_empty_values(insights.to_dict())) def test_factors(self): response = self.factors_response() @@ -313,6 +321,8 @@ def test_factors(self): ) self.assertEqual(0.17, factors.subscores.time_of_day) + self.assertEqual(response, self._remove_empty_values(factors.to_dict())) + def factors_response(self): return { "id": "b643d445-18b2-4b9d-bad4-c9c4366e402a", @@ -399,3 +409,31 @@ def check_risk_score_reasons_data(self, reasons): self.assertEqual( "Risk due to IP being an Anonymous IP", reasons[0].reasons[0].reason ) + + def _remove_empty_values(self, data): + if isinstance(data, dict): + m = {} + for k, v in data.items(): + v = self._remove_empty_values(v) + if self._is_not_empty(v): + m[k] = v + return m + + if isinstance(data, list): + ls = [] + for e in data: + e = self._remove_empty_values(e) + if self._is_not_empty(e): + ls.append(e) + return ls + + return data + + def _is_not_empty(self, v): + if v is None: + return False + if isinstance(v, dict) and not v: + return False + if isinstance(v, list) and not v: + return False + return True