From 114ef9d581f2f393fb6e6ed1c3e570f377149eb6 Mon Sep 17 00:00:00 2001 From: Mihhail Matvejev Date: Wed, 18 Dec 2024 00:08:05 +0300 Subject: [PATCH] small refactoring and adding new solution of kruskal's algorithm --- src/{ShuntingYard => }/__init__.py | 0 src/kruskals/__init__.py | 0 src/kruskals/disjoint_set.py | 23 ++++++ src/kruskals/edges.py | 8 ++ src/kruskals/graph.py | 32 ++++++++ src/kruskals/graph_reader.py | 56 ++++++++++++++ src/kruskals/input.txt | 4 + src/kruskals/mst_strategy.py | 26 +++++++ src/kruskals/output.txt | 3 + src/kruskals/solution.py | 38 +++++++++ src/shunting_yard/__init__.py | 0 .../solution.py | 0 src/tree/binary_tree.py | 2 +- test/test_disjoint_set.py | 33 ++++++++ test/test_graph.py | 50 ++++++++++++ test/test_graph_reader.py | 73 ++++++++++++++++++ test/test_mst_strategy.py | 77 +++++++++++++++++++ test/test_shunting_yard.py | 2 +- 18 files changed, 425 insertions(+), 2 deletions(-) rename src/{ShuntingYard => }/__init__.py (100%) create mode 100644 src/kruskals/__init__.py create mode 100644 src/kruskals/disjoint_set.py create mode 100644 src/kruskals/edges.py create mode 100644 src/kruskals/graph.py create mode 100644 src/kruskals/graph_reader.py create mode 100644 src/kruskals/input.txt create mode 100644 src/kruskals/mst_strategy.py create mode 100644 src/kruskals/output.txt create mode 100644 src/kruskals/solution.py create mode 100644 src/shunting_yard/__init__.py rename src/{ShuntingYard => shunting_yard}/solution.py (100%) create mode 100644 test/test_disjoint_set.py create mode 100644 test/test_graph.py create mode 100644 test/test_graph_reader.py create mode 100644 test/test_mst_strategy.py diff --git a/src/ShuntingYard/__init__.py b/src/__init__.py similarity index 100% rename from src/ShuntingYard/__init__.py rename to src/__init__.py diff --git a/src/kruskals/__init__.py b/src/kruskals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/kruskals/disjoint_set.py b/src/kruskals/disjoint_set.py new 
# ---------------------------------------------------------------------------
# src/kruskals/disjoint_set.py
# ---------------------------------------------------------------------------
class DisjointSet:
    """Disjoint-set (union-find) forest over the elements ``0 .. n - 1``.

    ``parent[i]`` is the forest link of element ``i`` (a root points to
    itself) and ``rank[i]`` is the rank used to decide which root survives
    a union.
    """

    def __init__(self, n: int):
        """Create ``n`` singleton sets, one per element."""
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the representative (root) of the set containing ``x``.

        Every element visited on the way up is re-linked directly to the
        root (path compression).

        Raises:
            IndexError: if ``x`` is not in ``0 .. n - 1``.
        """
        # List indexing would accept a negative index and silently alias
        # another element (-1 -> n - 1), so the range is checked explicitly.
        if not 0 <= x < len(self.parent):
            raise IndexError(f"element index out of range: {x}")

        # First pass: walk up to the root. Done iteratively so the lookup
        # does not depend on the interpreter's recursion limit.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: point every node on the path straight at the root.
        while x != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, x: int, y: int) -> bool:
        """Merge the sets containing ``x`` and ``y``.

        Returns:
            ``True`` if two different sets were merged, ``False`` if ``x``
            and ``y`` were already in the same set.
        """
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            return False
        # Make ``rx`` the root with the larger (or equal) rank; it survives.
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        # Only a merge of two equal-rank trees increases the rank.
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True


# ---------------------------------------------------------------------------
# src/kruskals/edges.py
# ---------------------------------------------------------------------------
from dataclasses import dataclass


@dataclass(frozen=True)
class Edge:
    """Immutable weighted edge between the vertices with indices ``u``, ``v``."""

    u: int
    v: int
    weight: float


# ---------------------------------------------------------------------------
# src/kruskals/graph.py
# (as a separate module this file imports Edge from src.kruskals.edges; here
# Edge is defined directly above)
# ---------------------------------------------------------------------------
from copy import copy
from typing import List


class Graph:
    """Edge-list graph over the vertices ``0 .. vertices_count - 1``."""

    def __init__(self, vertices_count: int):
        """Create a graph without edges.

        Raises:
            ValueError: if ``vertices_count`` is not positive.
        """
        if vertices_count <= 0:
            raise ValueError("Количество вершин <= 0")
        self.vertices_count = vertices_count
        self._edges: List[Edge] = []

    @property
    def edges(self) -> List[Edge]:
        """Return a shallow copy of the edge list.

        Mutating the returned list does not change the graph.
        """
        return copy(self._edges)

    def add_edge(self, begin: int, end: int, weight: float):
        """Append the edge ``begin -> end`` with the given weight.

        Raises:
            ValueError: if either index is outside ``0 .. vertices_count - 1``,
                if ``begin == end`` (self-loop), or if ``weight`` is not
                within ``[1, 1023]``.
        """
        indices_ok = (
            0 <= begin < self.vertices_count
            and 0 <= end < self.vertices_count
        )
        if not indices_ok:
            raise ValueError("Неверный индекс вершины")
        if begin == end:
            raise ValueError("Петли не допускаются")
        # Written as a positive range test so that NaN and values in (0, 1)
        # are rejected as well, matching the range stated in the message.
        if not 1 <= weight <= 1023:
            raise ValueError(
                "Вес ребра должен быть в диапазоне [1, ..., 1023]"
            )
        self._edges.append(Edge(begin, end, weight))
# ---------------------------------------------------------------------------
# src/kruskals/graph_reader.py
# ---------------------------------------------------------------------------
from typing import List, TextIO, Tuple

from src.kruskals.graph import Graph


def _process_edge(i: int, j: int, val: int, graph: Graph) -> None:
    """Handle the adjacency-matrix cell at row ``i``, column ``j``.

    ``0`` means "no edge" and is skipped. Any other value must lie in
    ``[1, 1023]``. Only cells above the main diagonal (``j > i``) are added
    to ``graph``; cells on or below it are range-checked and then ignored.

    NOTE(review): the lower triangle is never compared with the upper one,
    so an asymmetric matrix is accepted silently — confirm this is intended.

    Raises:
        ValueError: if ``val`` is non-zero and outside ``[1, 1023]``.
    """
    if val == 0:
        return
    if not 1 <= val <= 1023:
        raise ValueError("Вес ребра должен быть в диапазоне [1..1023].")
    if j > i:
        graph.add_edge(i, j, val)


def _read_matrix_row(file_obj: TextIO, vertices_count: int) -> List[str]:
    """Read the next line of ``file_obj`` and split it into matrix cells.

    Raises:
        ValueError: if the line is empty/missing or does not contain exactly
            ``vertices_count`` whitespace-separated values.
    """
    line = file_obj.readline().strip()
    if not line:
        raise ValueError("Строк в матрице меньше, чем количество вершин.")

    row_values = line.split()
    if len(row_values) != vertices_count:
        raise ValueError(
            "Количество столбцов матрицы не соответствует количеству вершин."
        )
    return row_values


def _read_adjacency_matrix(file_obj: TextIO, graph: Graph) -> None:
    """Read ``graph.vertices_count`` matrix rows and add their edges.

    Each cell is converted with ``int``; a non-integer token therefore
    raises ``ValueError`` as well.
    """
    vertices_count = graph.vertices_count
    for i in range(vertices_count):
        row_values = _read_matrix_row(file_obj, vertices_count)
        for j, val_str in enumerate(row_values):
            _process_edge(i, j, int(val_str), graph)


class GraphReader:
    """Read a graph from a text file.

    The first line holds whitespace-separated vertex names; it is followed
    by one adjacency-matrix row per vertex.
    """

    def __init__(self, max_vertices: int = 50):
        """Remember the largest number of vertices a file may declare."""
        self.max_vertices = max_vertices

    def read_graph_from_file(self, filename: str) -> Tuple[Graph, List[str]]:
        """Parse ``filename`` (UTF-8) and return ``(graph, vertex_names)``.

        Raises:
            ValueError: if the file is empty, declares too many vertices or
                contains a malformed matrix.
        """
        with open(filename, "r", encoding="utf-8") as f:
            vertex_names = self._read_vertex_names(f)
            graph = Graph(len(vertex_names))
            _read_adjacency_matrix(f, graph)
        return graph, vertex_names

    def _read_vertex_names(self, file_obj: TextIO) -> List[str]:
        """Read the header line and return the list of vertex names.

        Raises:
            ValueError: if the line is empty or holds more than
                ``self.max_vertices`` names.
        """
        first_line = file_obj.readline().strip()
        if not first_line:
            raise ValueError("File is corrupt")

        vertex_names = first_line.split()
        if len(vertex_names) > self.max_vertices:
            raise ValueError(
                f"Слишком много вершин. Максимум {self.max_vertices}."
            )
        return vertex_names


# ---------------------------------------------------------------------------
# src/kruskals/mst_strategy.py
# ---------------------------------------------------------------------------
from abc import ABC, abstractmethod

from src.kruskals.disjoint_set import DisjointSet
from src.kruskals.edges import Edge


class MSTStrategy(ABC):
    """Interface of a minimum-spanning-tree algorithm."""

    @abstractmethod
    def find_mst(self, vertices_count: int, edges: List[Edge]) -> List[Edge]:
        """Return the edges of a minimum spanning tree."""


class KruskalMST(MSTStrategy):
    """Kruskal's algorithm on top of a disjoint-set structure."""

    def find_mst(self, vertices_count: int, edges: List[Edge]) -> List[Edge]:
        """Return the ``vertices_count - 1`` edges of a minimum spanning tree.

        Edges are taken in ascending weight order and kept when they join
        two different components.

        Returns:
            The chosen edges, or ``[]`` when fewer than
            ``vertices_count - 1`` edges could be chosen. A one-vertex graph
            also yields ``[]`` (its tree has no edges), so callers must look
            at ``vertices_count`` to tell the two cases apart.
        """
        sorted_edges = sorted(edges, key=lambda e: e.weight)
        dsu = DisjointSet(vertices_count)
        mst: List[Edge] = []
        for edge in sorted_edges:
            if dsu.union(edge.u, edge.v):
                mst.append(edge)
                if len(mst) == vertices_count - 1:
                    break
        if len(mst) != vertices_count - 1:
            return []
        return mst


# ---------------------------------------------------------------------------
# src/kruskals/solution.py
# (as a separate module this file imports GraphReader from
# src.kruskals.graph_reader and KruskalMST from src.kruskals.mst_strategy;
# here both are defined directly above)
# ---------------------------------------------------------------------------
def main(input_file: str = "input.txt", output_file: str = "output.txt") -> None:
    """Read a graph, build its minimum spanning tree and write the result.

    The output lists one tree edge per line as two vertex names (the
    smaller name first, lines sorted by name) followed by the total weight.
    When the graph has more than one vertex and no spanning tree, a single
    message line is written instead.

    Example shipped with the module (``input.txt`` -> ``output.txt``)::

        A B C          A C
        0 3 1          B C
        3 0 2          3
        1 2 0

    Args:
        input_file: path of the graph description; defaults to the
            previously hard-coded ``input.txt``.
        output_file: path of the result file; defaults to the previously
            hard-coded ``output.txt``.
    """
    reader = GraphReader()
    graph, vertex_names = reader.read_graph_from_file(input_file)

    mst = KruskalMST().find_mst(graph.vertices_count, graph.edges)

    # An empty result is a failure only when at least one edge is required:
    # a single vertex has a valid spanning tree with no edges and weight 0.
    if not mst and graph.vertices_count > 1:
        lines = ["Остов не существует\n"]
    else:
        named_edges = []
        for edge in mst:
            first, second = sorted((vertex_names[edge.u], vertex_names[edge.v]))
            named_edges.append((first, second, edge.weight))
        named_edges.sort(key=lambda item: (item[0], item[1]))
        total_weight = sum(weight for _, _, weight in named_edges)
        lines = [f"{first} {second}\n" for first, second, _ in named_edges]
        lines.append(f"{total_weight}\n")

    with open(output_file, "w", encoding="utf-8") as out:
        out.writelines(lines)


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# test/test_disjoint_set.py
# ---------------------------------------------------------------------------
import unittest


class TestDisjointSet(unittest.TestCase):
    """Unit tests for DisjointSet."""

    def test_init(self):
        """Every element starts as its own representative."""
        dsu = DisjointSet(5)
        for i in range(5):
            self.assertEqual(dsu.find(i), i)

    def test_union_find(self):
        """Unions connect elements transitively and report whether they merged."""
        dsu = DisjointSet(5)
        self.assertTrue(dsu.union(0, 1))
        self.assertEqual(dsu.find(0), dsu.find(1))
        self.assertTrue(dsu.union(1, 2))
        self.assertEqual(dsu.find(0), dsu.find(2))
        self.assertNotEqual(dsu.find(3), dsu.find(0))
        # Joining two members of the same set changes nothing.
        self.assertFalse(dsu.union(2, 0))

    def test_large(self):
        """Chaining 50 elements leaves a single set."""
        n = 50
        dsu = DisjointSet(n)
        for i in range(n - 1):
            dsu.union(i, i + 1)
        root = dsu.find(0)
        for i in range(1, n):
            self.assertEqual(dsu.find(i), root)

    def test_no_unions(self):
        """Without unions all representatives are distinct."""
        n = 10
        dsu = DisjointSet(n)
        for i in range(n):
            self.assertEqual(dsu.find(i), i)
        self.assertEqual(len({dsu.find(i) for i in range(n)}), n)
# ---------------------------------------------------------------------------
# test/test_graph.py
# ---------------------------------------------------------------------------
import unittest

from src.kruskals.graph import Graph


class TestGraph(unittest.TestCase):
    """Unit tests for Graph construction and add_edge validation."""

    def test_init(self):
        """A non-positive vertex count is rejected; a valid one is stored."""
        with self.assertRaises(ValueError):
            Graph(0)
        graph = Graph(3)
        self.assertEqual(graph.vertices_count, 3)

    def test_add_edge_valid(self):
        graph = Graph(3)
        graph.add_edge(0, 1, 1)
        self.assertEqual(len(graph.edges), 1)

    def test_add_edge_invalid_index(self):
        graph = Graph(3)
        with self.assertRaises(ValueError):
            graph.add_edge(-1, 0, 1)
        with self.assertRaises(ValueError):
            graph.add_edge(0, 3, 1)

    def test_add_loop_edge(self):
        graph = Graph(3)
        with self.assertRaises(ValueError):
            graph.add_edge(0, 0, 1)

    def test_add_negative_or_zero_weight(self):
        graph = Graph(3)
        with self.assertRaises(ValueError):
            graph.add_edge(0, 1, -1)
        with self.assertRaises(ValueError):
            graph.add_edge(0, 1, 0)

    def test_max_weight(self):
        """1023 is the largest accepted weight; 1024 is rejected."""
        graph = Graph(3)
        graph.add_edge(0, 1, 1023)
        self.assertEqual(len(graph.edges), 1)
        self.assertEqual(graph.edges[0].weight, 1023)
        with self.assertRaises(ValueError):
            graph.add_edge(1, 2, 1024)

    def test_large_graph(self):
        n = 50
        graph = Graph(n)
        for i in range(n - 1):
            graph.add_edge(i, i + 1, i + 1)
        self.assertEqual(len(graph.edges), n - 1)


# ---------------------------------------------------------------------------
# test/test_graph_reader.py
# ---------------------------------------------------------------------------
import os
import tempfile

from src.kruskals.graph_reader import GraphReader


class TestGraphReader(unittest.TestCase):
    """Unit tests for GraphReader.read_graph_from_file."""

    def setUp(self):
        # The fixture lives in a private temporary directory, so running the
        # tests never overwrites or deletes an ``input.txt`` that happens to
        # sit in the current working directory.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        self.test_filename = os.path.join(tmp_dir.name, "input.txt")
        self.reader = GraphReader()

    def _write_input(self, text: str) -> None:
        """Write ``text`` to the fixture file using the reader's encoding."""
        with open(self.test_filename, "w", encoding="utf-8") as file:
            file.write(text)

    def test_empty_file(self):
        self._write_input("")
        with self.assertRaises(ValueError):
            self.reader.read_graph_from_file(self.test_filename)

    def test_wrong_format(self):
        """A row with too few columns is rejected."""
        self._write_input("A B C\n0 1 2\n0 0")
        with self.assertRaises(ValueError):
            self.reader.read_graph_from_file(self.test_filename)

    def test_too_many_vertices(self):
        """51 names exceed the default limit of 50 vertices."""
        names = " ".join(f"V{i}" for i in range(51))
        self._write_input(names + "\n")
        with self.assertRaises(ValueError):
            self.reader.read_graph_from_file(self.test_filename)

    def test_invalid_weight(self):
        self._write_input("A B\n0 1024\n1024 0\n")
        with self.assertRaises(ValueError):
            self.reader.read_graph_from_file(self.test_filename)

    def test_valid_graph(self):
        self._write_input("A B C\n0 3 1\n3 0 2\n1 2 0\n")
        graph, names = self.reader.read_graph_from_file(self.test_filename)
        self.assertEqual(len(names), 3)
        self.assertEqual(graph.vertices_count, 3)
        self.assertEqual(len(graph.edges), 3)
        weights = sorted(edge.weight for edge in graph.edges)
        self.assertEqual(weights, [1, 2, 3])

    def test_full_graph(self):
        """
        Complete graph on 4 vertices
        Vertices: A B C D
        Matrix:
        0 1 2 3
        1 0 4 5
        2 4 0 6
        3 5 6 0
        """

        self._write_input(
            "A B C D\n"
            "0 1 2 3\n"
            "1 0 4 5\n"
            "2 4 0 6\n"
            "3 5 6 0\n"
        )
        g, names = self.reader.read_graph_from_file(self.test_filename)
        self.assertEqual(len(g.edges), 6)
        self.assertEqual(sorted(names), ["A", "B", "C", "D"])


# ---------------------------------------------------------------------------
# test/test_mst_strategy.py
# ---------------------------------------------------------------------------
from src.kruskals.mst_strategy import KruskalMST


class TestKruskalMST(unittest.TestCase):
    """Unit tests for KruskalMST.find_mst."""

    @staticmethod
    def _mst(graph):
        """Run Kruskal's algorithm on ``graph`` and return the edge list."""
        return KruskalMST().find_mst(graph.vertices_count, graph.edges)

    def test_empty_graph(self):
        """Several vertices without any edge have no spanning tree."""
        graph = Graph(3)
        self.assertEqual(self._mst(graph), [])

    def test_single_vertex(self):
        """A single vertex yields an empty edge list."""
        graph = Graph(1)
        self.assertEqual(self._mst(graph), [])

    def test_disconnected_graph(self):
        graph = Graph(4)
        graph.add_edge(0, 1, 1)
        graph.add_edge(2, 3, 2)
        self.assertEqual(self._mst(graph), [])

    def test_normal_graph(self):
        graph = Graph(4)
        graph.add_edge(0, 1, 1)
        graph.add_edge(0, 2, 2)
        graph.add_edge(1, 2, 4)
        graph.add_edge(2, 3, 3)
        mst = self._mst(graph)
        self.assertEqual(len(mst), 3)
        self.assertEqual(sum(e.weight for e in mst), 6.0)

    def test_equal_weights(self):
        graph = Graph(3)
        graph.add_edge(0, 1, 1)
        graph.add_edge(1, 2, 1)
        graph.add_edge(0, 2, 1)
        mst = self._mst(graph)
        self.assertEqual(len(mst), 2)
        self.assertEqual(sum(e.weight for e in mst), 2)

    def test_large_line_graph(self):
        n = 50
        graph = Graph(n)
        for i in range(n - 1):
            graph.add_edge(i, i + 1, i + 1)
        mst = self._mst(graph)
        self.assertEqual(len(mst), n - 1)
        self.assertEqual(sum(e.weight for e in mst), sum(range(1, n)))

    def test_large_complete_graph(self):
        """
        Complete graph on 5 vertices (0..4), C(5,2) = 10 edges
        Weights 1..10 are assigned in the order (0,1), (0,2), ..., (3,4)
        The four lightest edges are exactly those incident to vertex 0,
        so they form the MST with total weight 1 + 2 + 3 + 4 = 10
        """

        graph = Graph(5)
        weight = 1
        for i in range(5):
            for j in range(i + 1, 5):
                graph.add_edge(i, j, weight)
                weight += 1
        mst = self._mst(graph)
        self.assertEqual(len(mst), 4)
        self.assertEqual(sum(e.weight for e in mst), 10)