diff --git a/exercises/ex03/.gitignore b/exercises/ex03/.gitignore new file mode 100644 index 0000000..f1f00e5 --- /dev/null +++ b/exercises/ex03/.gitignore @@ -0,0 +1,2 @@ +.idea +nlpwdlfw.egg-info \ No newline at end of file diff --git a/exercises/ex03/nlpwdlfw/__init__.py b/exercises/ex03/nlpwdlfw/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/exercises/ex03/nlpwdlfw/nodes.py b/exercises/ex03/nlpwdlfw/nodes.py index b7cb48a..9f915e5 100644 --- a/exercises/ex03/nlpwdlfw/nodes.py +++ b/exercises/ex03/nlpwdlfw/nodes.py @@ -15,7 +15,7 @@ def __init__(self, arguments: List['ScalarNode']) -> None: for arg in self._arguments: arg._parents.append(self) # --- TASK_6 --- - + self._cache = ScalarNodeCache() # --- TASK_6 --- def value(self) -> float: @@ -32,7 +32,8 @@ def find_self_position_in_parents_arguments(self, parent: 'ScalarNode') -> int: def global_derivative_wrt_self(self) -> float: # --- TASK_6 --- - + if self._cache.global_derivative_wrt_self is not None: + return self._cache.global_derivative_wrt_self # --- TASK_6 --- if len(self._parents) == 0: @@ -42,11 +43,16 @@ def global_derivative_wrt_self(self) -> float: result = 0.0 # --- TASK_5 --- # multiply and add (generalized chain rule) + for p in self._parents: + index_in_parents_arguments = self.find_self_position_in_parents_arguments(p) + parent_to_self_derivative = p.local_partial_derivatives_wrt_arguments()[index_in_parents_arguments] + parent_global_derivative = p.global_derivative_wrt_self() + result += parent_to_self_derivative * parent_global_derivative # --- TASK_5 --- # --- TASK_6 --- - + self._cache.global_derivative_wrt_self = result # --- TASK_6 --- return result @@ -66,18 +72,20 @@ class SumNode(ScalarNode): def value(self) -> float: # --- TASK_6 --- - + if self._cache.value is not None: + return self._cache.value # --- TASK_6 --- result = 0.0 # sum all arguments values # --- TASK_2 --- - + for arg in self._arguments: + result += arg.value() # --- TASK_2 --- # --- TASK_6 --- - + self._cache.value = result # --- TASK_6 --- return result @@ -90,7 +98,7 @@ def local_partial_derivatives_wrt_arguments(self) -> List[float]: # dy/dw_3 = 1 # --- TASK_2 --- - return None + return [1.0] * len(self._arguments) # --- TASK_2 --- @@ -98,25 +106,28 @@ class ProductNode(ScalarNode): def value(self) -> float: # --- TASK_6 --- - + if self._cache.value is not None: + return self._cache.value # --- TASK_6 --- result = 1.0 # multiply all arguments values # --- TASK_3 --- - + for arg in self._arguments: + result *= arg.value() # --- TASK_3 --- # --- TASK_6 --- - + self._cache.value = result # --- TASK_6 --- return result def local_partial_derivatives_wrt_arguments(self) -> List[float]: # --- TASK_6 --- - + if self._cache.local_partial_derivatives_wrt_arguments is not None: + return self._cache.local_partial_derivatives_wrt_arguments # --- TASK_6 --- # Partial derivative wrt. each argument is a product of all other arguments, for example @@ -129,11 +140,17 @@ def local_partial_derivatives_wrt_arguments(self) -> List[float]: result = [0.0] * len(self._arguments) # --- TASK_3 --- - + for i in range(len(self._arguments)): + ith_result = 1.0 + for j in range(len(self._arguments)): + if i != j: + j_value = self._arguments[j].value() + ith_result *= j_value + result[i] = ith_result # --- TASK_3 --- # --- TASK_6 --- - + self._cache.local_partial_derivatives_wrt_arguments = result # --- TASK_6 --- return result diff --git a/exercises/ex03/tests/__init__.py b/exercises/ex03/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/exercises/ex03/tests/test_nodes.py b/exercises/ex03/tests/test_nodes.py index 41205c9..7dca665 100644 --- a/exercises/ex03/tests/test_nodes.py +++ b/exercises/ex03/tests/test_nodes.py @@ -8,7 +8,7 @@ class TestNodes(TestCase): def test_task1(self): value = 2 # --- TASK_1 --- - + value = 1 # --- TASK_1 --- self.assertEqual(1, value) @@ -47,7 +47,12 @@ def test_task4(self): # dummy initialization a, b, one, r, s, e = [ConstantNode(1)] * 6 # --- TASK_4 --- - + a = ConstantNode(2) + b = ConstantNode(3) + one = ConstantNode(1) + r = SumNode([a, b]) + s = SumNode([b, one]) + e = ProductNode([r, s]) # --- TASK_4 --- # test that b has two parents