Skip to content

Commit

Permalink
Ex 03 solution
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan Habernal committed Nov 2, 2024
1 parent 6805f28 commit 9084ca5
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 15 deletions.
2 changes: 2 additions & 0 deletions exercises/ex03/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.idea
nlpwdlfw.egg-info
Empty file.
43 changes: 30 additions & 13 deletions exercises/ex03/nlpwdlfw/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, arguments: List['ScalarNode']) -> None:
for arg in self._arguments:
arg._parents.append(self)
# --- TASK_6 ---

self._cache = ScalarNodeCache()
# --- TASK_6 ---

def value(self) -> float:
Expand All @@ -32,7 +32,8 @@ def find_self_position_in_parents_arguments(self, parent: 'ScalarNode') -> int:

def global_derivative_wrt_self(self) -> float:
# --- TASK_6 ---

if self._cache.global_derivative_wrt_self is not None:
return self._cache.global_derivative_wrt_self
# --- TASK_6 ---

if len(self._parents) == 0:
Expand All @@ -42,11 +43,16 @@ def global_derivative_wrt_self(self) -> float:
result = 0.0
# --- TASK_5 ---
# multiply and add (generalized chain rule)
for p in self._parents:
index_in_parents_arguments = self.find_self_position_in_parents_arguments(p)
parent_to_self_derivative = p.local_partial_derivatives_wrt_arguments()[index_in_parents_arguments]
parent_global_derivative = p.global_derivative_wrt_self()

result += parent_to_self_derivative * parent_global_derivative
# --- TASK_5 ---

# --- TASK_6 ---

self._cache.global_derivative_wrt_self = result
# --- TASK_6 ---

return result
Expand All @@ -66,18 +72,20 @@ class SumNode(ScalarNode):

def value(self) -> float:
# --- TASK_6 ---

if self._cache.value is not None:
return self._cache.value
# --- TASK_6 ---

result = 0.0
# sum all arguments values

# --- TASK_2 ---

for arg in self._arguments:
result += arg.value()
# --- TASK_2 ---

# --- TASK_6 ---

self._cache.value = result
# --- TASK_6 ---

return result
Expand All @@ -90,33 +98,36 @@ def local_partial_derivatives_wrt_arguments(self) -> List[float]:
# dy/dw_3 = 1

# --- TASK_2 ---
return None
return [1.0] * len(self._arguments)
# --- TASK_2 ---


class ProductNode(ScalarNode):

def value(self) -> float:
# --- TASK_6 ---

if self._cache.value is not None:
return self._cache.value
# --- TASK_6 ---

result = 1.0
# multiply all arguments values

# --- TASK_3 ---

for arg in self._arguments:
result *= arg.value()
# --- TASK_3 ---

# --- TASK_6 ---

self._cache.value = result
# --- TASK_6 ---

return result

def local_partial_derivatives_wrt_arguments(self) -> List[float]:
# --- TASK_6 ---

if self._cache.local_partial_derivatives_wrt_arguments is not None:
return self._cache.local_partial_derivatives_wrt_arguments
# --- TASK_6 ---

# Partial derivative wrt. each argument is a product of all other arguments, for example
Expand All @@ -129,11 +140,17 @@ def local_partial_derivatives_wrt_arguments(self) -> List[float]:
result = [0.0] * len(self._arguments)

# --- TASK_3 ---

for i in range(len(self._arguments)):
ith_result = 1.0
for j in range(len(self._arguments)):
if i != j:
j_value = self._arguments[j].value()
ith_result *= j_value
result[i] = ith_result
# --- TASK_3 ---

# --- TASK_6 ---

self._cache.local_partial_derivatives_wrt_arguments = result
# --- TASK_6 ---

return result
Expand Down
Empty file removed exercises/ex03/tests/__init__.py
Empty file.
9 changes: 7 additions & 2 deletions exercises/ex03/tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TestNodes(TestCase):
def test_task1(self):
value = 2
# --- TASK_1 ---

value = 1
# --- TASK_1 ---
self.assertEqual(1, value)

Expand Down Expand Up @@ -47,7 +47,12 @@ def test_task4(self):
# dummy initialization
a, b, one, r, s, e = [ConstantNode(1)] * 6
# --- TASK_4 ---

a = ConstantNode(2)
b = ConstantNode(3)
one = ConstantNode(1)
r = SumNode([a, b])
s = SumNode([b, one])
e = ProductNode([r, s])
# --- TASK_4 ---

# test that b has two parents
Expand Down

0 comments on commit 9084ca5

Please sign in to comment.