From 1d1601f57fa4eff6eb0a9888f7c015f14c7d8446 Mon Sep 17 00:00:00 2001 From: David Angulo Date: Wed, 20 Mar 2024 13:28:05 +0800 Subject: [PATCH] Use shape finder when it's not provided --- kneed/knee_locator.py | 19 +++++++++++++++---- tests/test_sample.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/kneed/knee_locator.py b/kneed/knee_locator.py index a83ed22..85083d5 100644 --- a/kneed/knee_locator.py +++ b/kneed/knee_locator.py @@ -2,6 +2,7 @@ from scipy import interpolate from scipy.signal import argrelextrema from typing import Tuple, Optional, Iterable +from shape_detector import find_shape VALID_CURVE = ["convex", "concave"] VALID_DIRECTION = ["increasing", "decreasing"] @@ -30,9 +31,9 @@ class KneeLocator(object): :type S: float :param curve: If 'concave', algorithm will detect knees. If 'convex', it will detect elbows. - :type curve: str + :type curve: Optional[str] :param direction: one of {"increasing", "decreasing"} - :type direction: str + :type direction: Optional[str] :param interp_method: one of {"interp1d", "polynomial"} :type interp_method: str :param online: kneed will correct old knee points if True, will return first knee if False @@ -133,8 +134,8 @@ def __init__( x: Iterable[float], y: Iterable[float], S: float = 1.0, - curve: str = "concave", - direction: str = "increasing", + curve: Optional[str] = None, + direction: Optional[str] = None, interp_method: str = "interp1d", online: bool = False, polynomial_degree: int = 7, @@ -153,6 +154,16 @@ def __init__( self.online = online self.polynomial_degree = polynomial_degree + # Use find_shape if it's not provided + if self.curve is None or self.direction is None: + direction, curve = find_shape(self.x, self.y) + + if self.curve is None: + self.curve = curve + + if self.direction is None: + self.direction = direction + # I'm implementing Look Before You Leap (LBYL) validation for direction # and curve arguments. This is not preferred in Python. The motivation # is that the logic inside the conditional once y_difference[j] is less diff --git a/tests/test_sample.py b/tests/test_sample.py index e47ef11..cf664cc 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -560,3 +560,39 @@ def test_find_shape(): direction, curve = find_shape(x, y) assert direction == "increasing" assert curve == "convex" + + +def test_missing_curve(): + """Test that find_shape is used when curve is not provided""" + x, y = dg.concave_increasing() + kl = KneeLocator(x, y) + + assert kl.curve is not None + assert kl.curve == "concave" + + +def test_missing_direction(): + """Test that find_shape is used when direction is not provided""" + x, y = dg.concave_increasing() + kl = KneeLocator(x, y) + + assert kl.direction is not None + assert kl.direction == "increasing" + + +def test_provided_curve(): + """Test that find_shape is not used when curve is provided""" + x, y = dg.concave_increasing() + kl = KneeLocator(x, y, curve="convex") + + assert kl.curve is not None + assert kl.curve == "convex" + + +def test_provided_direction(): + """Test that find_shape is not used when direction is provided""" + x, y = dg.concave_increasing() + kl = KneeLocator(x, y, direction="decreasing") + + assert kl.direction is not None + assert kl.direction == "decreasing"