-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from m4tveevm/KruskalsAlgo
small refactoring and adding new solution of kruskal's algorithm
- Loading branch information
Showing
18 changed files
with
425 additions
and
2 deletions.
There are no files selected for viewing
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
class DisjointSet: | ||
def __init__(self, n: int): | ||
self.parent = [i for i in range(n)] | ||
self.rank = [0] * n | ||
|
||
def find(self, x: int) -> int: | ||
if self.parent[x] != x: | ||
self.parent[x] = self.find(self.parent[x]) | ||
return self.parent[x] | ||
|
||
def union(self, x: int, y: int) -> bool: | ||
rx = self.find(x) | ||
ry = self.find(y) | ||
if rx == ry: | ||
return False | ||
if self.rank[rx] < self.rank[ry]: | ||
self.parent[rx] = ry | ||
elif self.rank[rx] > self.rank[ry]: | ||
self.parent[ry] = rx | ||
else: | ||
self.parent[ry] = rx | ||
self.rank[rx] += 1 | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Edge: | ||
u: int | ||
v: int | ||
weight: float |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from copy import copy | ||
from typing import List | ||
|
||
from src.kruskals.edges import Edge | ||
|
||
|
||
class Graph: | ||
def __init__(self, vertices_count: int): | ||
if vertices_count <= 0: | ||
raise ValueError("Количество вершин <= 0") | ||
self.vertices_count = vertices_count | ||
self._edges: List[Edge] = [] | ||
|
||
@property | ||
def edges(self) -> List[Edge]: | ||
return copy(self._edges) | ||
|
||
def add_edge(self, begin: int, end: int, weight: float): | ||
if ( | ||
begin < 0 | ||
or end < 0 | ||
or begin >= self.vertices_count | ||
or end >= self.vertices_count | ||
): | ||
raise ValueError("Неверный индекс вершины") | ||
if begin == end: | ||
raise ValueError("Петли не допускаются") | ||
if weight <= 0 or weight > 1023: | ||
raise ValueError( | ||
"Вес ребра должен быть в диапазоне [1, ..., 1023]" | ||
) | ||
self._edges.append(Edge(begin, end, weight)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from src.kruskals.graph import Graph | ||
|
||
|
||
def _process_edge(i: int, j: int, val: int, graph: Graph): | ||
if val == 0: | ||
return | ||
if val < 1 or val > 1023: | ||
raise ValueError("Вес ребра должен быть в диапазоне [1..1023].") | ||
if j > i: | ||
graph.add_edge(i, j, val) | ||
|
||
|
||
def _read_matrix_row(file_obj, vertices_count: int): | ||
line = file_obj.readline().strip() | ||
if not line: | ||
raise ValueError("Строк в матрице меньше, чем количество вершин.") | ||
|
||
row_values = line.split() | ||
if len(row_values) != vertices_count: | ||
raise ValueError( | ||
"Количество столбцов матрицы не соответствует количеству вершин." | ||
) | ||
return row_values | ||
|
||
|
||
def _read_adjacency_matrix(file_obj, graph: Graph): | ||
vertices_count = graph.vertices_count | ||
for i in range(vertices_count): | ||
row_values = _read_matrix_row(file_obj, vertices_count) | ||
for j, val_str in enumerate(row_values): | ||
val = int(val_str) | ||
_process_edge(i, j, val, graph) | ||
|
||
|
||
class GraphReader: | ||
def __init__(self, max_vertices=50): | ||
self.max_vertices = max_vertices | ||
|
||
def read_graph_from_file(self, filename: str): | ||
with open(filename, "r", encoding="utf-8") as f: | ||
vertex_names = self._read_vertex_names(f) | ||
graph = Graph(len(vertex_names)) | ||
_read_adjacency_matrix(f, graph) | ||
return graph, vertex_names | ||
|
||
def _read_vertex_names(self, file_obj): | ||
first_line = file_obj.readline().strip() | ||
if not first_line: | ||
raise ValueError("File is corrupt") | ||
|
||
vertex_names = first_line.split() | ||
if len(vertex_names) > self.max_vertices: | ||
raise ValueError( | ||
f"Слишком много вершин. Максимум {self.max_vertices}." | ||
) | ||
return vertex_names |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
A B C | ||
0 3 1 | ||
3 0 2 | ||
1 2 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
from src.kruskals.disjoint_set import DisjointSet | ||
from src.kruskals.edges import Edge | ||
|
||
|
||
class MSTStrategy(ABC): | ||
@abstractmethod | ||
def find_mst(self, vertices_count: int, edges: List[Edge]) -> List[Edge]: | ||
pass | ||
|
||
|
||
class KruskalMST(MSTStrategy): | ||
def find_mst(self, vertices_count: int, edges: List[Edge]) -> List[Edge]: | ||
sorted_edges = sorted(edges, key=lambda e: e.weight) | ||
dsu = DisjointSet(vertices_count) | ||
mst = [] | ||
for edge in sorted_edges: | ||
if dsu.union(edge.u, edge.v): | ||
mst.append(edge) | ||
if len(mst) == vertices_count - 1: | ||
break | ||
if len(mst) != vertices_count - 1: | ||
return [] | ||
return mst |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
A C | ||
B C | ||
3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from src.kruskals.graph_reader import GraphReader | ||
from src.kruskals.mst_strategy import KruskalMST | ||
|
||
|
||
def main(): | ||
input_file = "input.txt" | ||
output_file = "output.txt" | ||
|
||
reader = GraphReader() | ||
graph, vertex_names = reader.read_graph_from_file(input_file) | ||
|
||
kruskal = KruskalMST() | ||
mst = kruskal.find_mst(graph.vertices_count, graph.edges) | ||
|
||
if not mst: | ||
with open(output_file, "w", encoding="utf-8") as out: | ||
out.write("Остов не существует\n") | ||
return | ||
|
||
named_edges = [] | ||
for edge in mst: | ||
v1_name = vertex_names[edge.u] | ||
v2_name = vertex_names[edge.v] | ||
if v1_name > v2_name: | ||
v1_name, v2_name = v2_name, v1_name | ||
named_edges.append((v1_name, v2_name, edge.weight)) | ||
|
||
named_edges.sort(key=lambda x: (x[0], x[1])) | ||
total_weight = sum(edge[2] for edge in named_edges) | ||
|
||
with open(output_file, "w", encoding="utf-8") as out: | ||
for u_name, v_name, w in named_edges: | ||
out.write(f"{u_name} {v_name}\n") | ||
out.write(f"{total_weight}\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Empty file.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .traversal import TraversalStrategy | ||
from src.tree.traversal import TraversalStrategy | ||
|
||
|
||
class BinaryTree: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import unittest | ||
|
||
from src.kruskals.disjoint_set import DisjointSet | ||
|
||
|
||
class TestDisjointSet(unittest.TestCase): | ||
def test_init(self): | ||
dsu = DisjointSet(5) | ||
for i in range(5): | ||
self.assertEqual(dsu.find(i), i) | ||
|
||
def test_union_find(self): | ||
dsu = DisjointSet(5) | ||
dsu.union(0, 1) | ||
self.assertEqual(dsu.find(0), dsu.find(1)) | ||
dsu.union(1, 2) | ||
self.assertEqual(dsu.find(0), dsu.find(2)) | ||
self.assertNotEqual(dsu.find(3), dsu.find(0)) | ||
|
||
def test_large(self): | ||
n = 50 | ||
dsu = DisjointSet(n) | ||
for i in range(n - 1): | ||
dsu.union(i, i + 1) | ||
root = dsu.find(0) | ||
for i in range(1, n): | ||
self.assertEqual(dsu.find(i), root) | ||
|
||
def test_no_unions(self): | ||
n = 10 | ||
dsu = DisjointSet(n) | ||
for i in range(n): | ||
self.assertEqual(dsu.find(i), i) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import unittest | ||
|
||
from src.kruskals.graph import Graph | ||
|
||
|
||
class TestGraph(unittest.TestCase): | ||
def test_init(self): | ||
with self.assertRaises(ValueError): | ||
Graph(0) | ||
graph = Graph(3) | ||
self.assertEqual(graph.vertices_count, 3) | ||
|
||
def test_add_edge_valid(self): | ||
graph = Graph(3) | ||
graph.add_edge(0, 1, 1) | ||
self.assertEqual(len(graph.edges), 1) | ||
|
||
def test_add_edge_invalid_index(self): | ||
graph = Graph(3) | ||
with self.assertRaises(ValueError): | ||
graph.add_edge(-1, 0, 1) | ||
with self.assertRaises(ValueError): | ||
graph.add_edge(0, 3, 1) | ||
|
||
def test_add_loop_edge(self): | ||
graph = Graph(3) | ||
with self.assertRaises(ValueError): | ||
graph.add_edge(0, 0, 1) | ||
|
||
def test_add_negative_or_zero_weight(self): | ||
graph = Graph(3) | ||
with self.assertRaises(ValueError): | ||
graph.add_edge(0, 1, -1) | ||
with self.assertRaises(ValueError): | ||
graph.add_edge(0, 1, 0) | ||
|
||
def test_max_weight(self): | ||
graph = Graph(3) | ||
graph.add_edge(0, 1, 1023) | ||
self.assertEqual(len(graph.edges), 1) | ||
self.assertEqual(graph.edges[0].weight, 1023) | ||
with self.assertRaises(ValueError): | ||
graph.add_edge(1, 2, 1024) | ||
|
||
def test_large_graph(self): | ||
n = 50 | ||
graph = Graph(n) | ||
for i in range(n - 1): | ||
graph.add_edge(i, i + 1, i + 1) | ||
self.assertEqual(len(graph.edges), n - 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
import unittest | ||
|
||
from src.kruskals.graph_reader import GraphReader | ||
|
||
|
||
class TestGraphReader(unittest.TestCase): | ||
def setUp(self): | ||
self.test_filename = "input.txt" | ||
|
||
def tearDown(self): | ||
if os.path.exists(self.test_filename): | ||
os.remove(self.test_filename) | ||
|
||
def test_empty_file(self): | ||
open(self.test_filename, "w").close() | ||
reader = GraphReader() | ||
with self.assertRaises(ValueError): | ||
reader.read_graph_from_file(self.test_filename) | ||
|
||
def test_wrong_format(self): | ||
with open(self.test_filename, "w") as file: | ||
file.write("A B C\n0 1 2\n0 0") | ||
reader = GraphReader() | ||
with self.assertRaises(ValueError): | ||
reader.read_graph_from_file(self.test_filename) | ||
|
||
def test_too_many_vertices(self): | ||
names = " ".join([f"V{i}" for i in range(51)]) | ||
with open(self.test_filename, "w") as file: | ||
file.write(names + "\n") | ||
reader = GraphReader() | ||
with self.assertRaises(ValueError): | ||
reader.read_graph_from_file(self.test_filename) | ||
|
||
def test_invalid_weight(self): | ||
with open(self.test_filename, "w") as file: | ||
file.write("A B\n0 1024\n1024 0\n") | ||
reader = GraphReader() | ||
with self.assertRaises(ValueError): | ||
reader.read_graph_from_file(self.test_filename) | ||
|
||
def test_valid_graph(self): | ||
with open(self.test_filename, "w") as file: | ||
file.write("A B C\n0 3 1\n3 0 2\n1 2 0\n") | ||
reader = GraphReader() | ||
graph, names = reader.read_graph_from_file(self.test_filename) | ||
self.assertEqual(len(names), 3) | ||
self.assertEqual(graph.vertices_count, 3) | ||
self.assertEqual(len(graph.edges), 3) | ||
weights = sorted(edge.weight for edge in graph.edges) | ||
self.assertEqual(weights, [1, 2, 3]) | ||
|
||
def test_full_graph(self): | ||
""" | ||
Полный граф из 4 вершин | ||
Вершины: A B C D | ||
Матрица: | ||
0 1 2 3 | ||
1 0 4 5 | ||
2 4 0 6 | ||
""" | ||
|
||
with open(self.test_filename, "w") as file: | ||
file.write("A B C D\n") | ||
file.write("0 1 2 3\n") | ||
file.write("1 0 4 5\n") | ||
file.write("2 4 0 6\n") | ||
file.write("3 5 6 0\n") | ||
reader = GraphReader() | ||
g, names = reader.read_graph_from_file(self.test_filename) | ||
self.assertEqual(len(g.edges), 6) | ||
self.assertEqual(sorted(names), ["A", "B", "C", "D"]) |
Oops, something went wrong.