Skip to content

Commit

Permalink
Merge pull request #9 from m4tveevm/KruskalsAlgo
Browse files Browse the repository at this point in the history
small refactoring and adding new solution of kruskal's algorithm
  • Loading branch information
m4tveevm authored Dec 17, 2024
2 parents b1242a7 + 114ef9d commit 8438993
Show file tree
Hide file tree
Showing 18 changed files with 425 additions and 2 deletions.
File renamed without changes.
Empty file added src/kruskals/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions src/kruskals/disjoint_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
class DisjointSet:
def __init__(self, n: int):
self.parent = [i for i in range(n)]
self.rank = [0] * n

def find(self, x: int) -> int:
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]

def union(self, x: int, y: int) -> bool:
rx = self.find(x)
ry = self.find(y)
if rx == ry:
return False
if self.rank[rx] < self.rank[ry]:
self.parent[rx] = ry
elif self.rank[rx] > self.rank[ry]:
self.parent[ry] = rx
else:
self.parent[ry] = rx
self.rank[rx] += 1
return True
8 changes: 8 additions & 0 deletions src/kruskals/edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class Edge:

Check notice on line 5 in src/kruskals/edges.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Class has no `__init__` method

Class has no __init__ method
u: int
v: int
weight: float
32 changes: 32 additions & 0 deletions src/kruskals/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from copy import copy
from typing import List

from src.kruskals.edges import Edge


class Graph:
def __init__(self, vertices_count: int):
if vertices_count <= 0:
raise ValueError("Количество вершин <= 0")
self.vertices_count = vertices_count
self._edges: List[Edge] = []

@property
def edges(self) -> List[Edge]:
return copy(self._edges)

def add_edge(self, begin: int, end: int, weight: float):
if (
begin < 0
or end < 0
or begin >= self.vertices_count
or end >= self.vertices_count
):
raise ValueError("Неверный индекс вершины")
if begin == end:
raise ValueError("Петли не допускаются")
if weight <= 0 or weight > 1023:
raise ValueError(
"Вес ребра должен быть в диапазоне [1, ..., 1023]"
)
self._edges.append(Edge(begin, end, weight))
56 changes: 56 additions & 0 deletions src/kruskals/graph_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from src.kruskals.graph import Graph


def _process_edge(i: int, j: int, val: int, graph: Graph):
if val == 0:
return
if val < 1 or val > 1023:
raise ValueError("Вес ребра должен быть в диапазоне [1..1023].")
if j > i:
graph.add_edge(i, j, val)


def _read_matrix_row(file_obj, vertices_count: int):
line = file_obj.readline().strip()
if not line:
raise ValueError("Строк в матрице меньше, чем количество вершин.")

row_values = line.split()
if len(row_values) != vertices_count:
raise ValueError(
"Количество столбцов матрицы не соответствует количеству вершин."
)
return row_values


def _read_adjacency_matrix(file_obj, graph: Graph):
vertices_count = graph.vertices_count
for i in range(vertices_count):
row_values = _read_matrix_row(file_obj, vertices_count)
for j, val_str in enumerate(row_values):
val = int(val_str)
_process_edge(i, j, val, graph)


class GraphReader:
def __init__(self, max_vertices=50):
self.max_vertices = max_vertices

def read_graph_from_file(self, filename: str):
with open(filename, "r", encoding="utf-8") as f:
vertex_names = self._read_vertex_names(f)
graph = Graph(len(vertex_names))
_read_adjacency_matrix(f, graph)
return graph, vertex_names

def _read_vertex_names(self, file_obj):
first_line = file_obj.readline().strip()
if not first_line:
raise ValueError("File is corrupt")

vertex_names = first_line.split()
if len(vertex_names) > self.max_vertices:
raise ValueError(
f"Слишком много вершин. Максимум {self.max_vertices}."
)
return vertex_names
4 changes: 4 additions & 0 deletions src/kruskals/input.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
A B C
0 3 1
3 0 2
1 2 0
26 changes: 26 additions & 0 deletions src/kruskals/mst_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import List

from src.kruskals.disjoint_set import DisjointSet
from src.kruskals.edges import Edge


class MSTStrategy(ABC):
@abstractmethod
def find_mst(self, vertices_count: int, edges: List[Edge]) -> List[Edge]:
pass


class KruskalMST(MSTStrategy):
def find_mst(self, vertices_count: int, edges: List[Edge]) -> List[Edge]:
sorted_edges = sorted(edges, key=lambda e: e.weight)
dsu = DisjointSet(vertices_count)
mst = []
for edge in sorted_edges:
if dsu.union(edge.u, edge.v):
mst.append(edge)
if len(mst) == vertices_count - 1:
break
if len(mst) != vertices_count - 1:
return []
return mst
3 changes: 3 additions & 0 deletions src/kruskals/output.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
A C
B C
3
38 changes: 38 additions & 0 deletions src/kruskals/solution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from src.kruskals.graph_reader import GraphReader
from src.kruskals.mst_strategy import KruskalMST


def main():
input_file = "input.txt"
output_file = "output.txt"

reader = GraphReader()
graph, vertex_names = reader.read_graph_from_file(input_file)

kruskal = KruskalMST()
mst = kruskal.find_mst(graph.vertices_count, graph.edges)

if not mst:
with open(output_file, "w", encoding="utf-8") as out:
out.write("Остов не существует\n")
return

named_edges = []
for edge in mst:
v1_name = vertex_names[edge.u]
v2_name = vertex_names[edge.v]
if v1_name > v2_name:
v1_name, v2_name = v2_name, v1_name
named_edges.append((v1_name, v2_name, edge.weight))

named_edges.sort(key=lambda x: (x[0], x[1]))
total_weight = sum(edge[2] for edge in named_edges)

with open(output_file, "w", encoding="utf-8") as out:
for u_name, v_name, w in named_edges:
out.write(f"{u_name} {v_name}\n")
out.write(f"{total_weight}\n")


if __name__ == "__main__":
main()
Empty file added src/shunting_yard/__init__.py
Empty file.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/tree/binary_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .traversal import TraversalStrategy
from src.tree.traversal import TraversalStrategy


class BinaryTree:
Expand Down
33 changes: 33 additions & 0 deletions test/test_disjoint_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

from src.kruskals.disjoint_set import DisjointSet


class TestDisjointSet(unittest.TestCase):
def test_init(self):
dsu = DisjointSet(5)
for i in range(5):
self.assertEqual(dsu.find(i), i)

def test_union_find(self):
dsu = DisjointSet(5)
dsu.union(0, 1)
self.assertEqual(dsu.find(0), dsu.find(1))
dsu.union(1, 2)
self.assertEqual(dsu.find(0), dsu.find(2))
self.assertNotEqual(dsu.find(3), dsu.find(0))

def test_large(self):
n = 50
dsu = DisjointSet(n)
for i in range(n - 1):
dsu.union(i, i + 1)
root = dsu.find(0)
for i in range(1, n):
self.assertEqual(dsu.find(i), root)

def test_no_unions(self):
n = 10
dsu = DisjointSet(n)
for i in range(n):
self.assertEqual(dsu.find(i), i)
50 changes: 50 additions & 0 deletions test/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest

from src.kruskals.graph import Graph


class TestGraph(unittest.TestCase):
def test_init(self):
with self.assertRaises(ValueError):
Graph(0)
graph = Graph(3)
self.assertEqual(graph.vertices_count, 3)

def test_add_edge_valid(self):
graph = Graph(3)
graph.add_edge(0, 1, 1)
self.assertEqual(len(graph.edges), 1)

def test_add_edge_invalid_index(self):
graph = Graph(3)
with self.assertRaises(ValueError):
graph.add_edge(-1, 0, 1)
with self.assertRaises(ValueError):
graph.add_edge(0, 3, 1)

def test_add_loop_edge(self):
graph = Graph(3)
with self.assertRaises(ValueError):
graph.add_edge(0, 0, 1)

def test_add_negative_or_zero_weight(self):
graph = Graph(3)
with self.assertRaises(ValueError):
graph.add_edge(0, 1, -1)
with self.assertRaises(ValueError):
graph.add_edge(0, 1, 0)

def test_max_weight(self):
graph = Graph(3)
graph.add_edge(0, 1, 1023)
self.assertEqual(len(graph.edges), 1)
self.assertEqual(graph.edges[0].weight, 1023)
with self.assertRaises(ValueError):
graph.add_edge(1, 2, 1024)

def test_large_graph(self):
n = 50
graph = Graph(n)
for i in range(n - 1):
graph.add_edge(i, i + 1, i + 1)
self.assertEqual(len(graph.edges), n - 1)
73 changes: 73 additions & 0 deletions test/test_graph_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import unittest

from src.kruskals.graph_reader import GraphReader


class TestGraphReader(unittest.TestCase):
def setUp(self):

Check notice on line 8 in test/test_graph_reader.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Function name should be lowercase
self.test_filename = "input.txt"

def tearDown(self):

Check notice on line 11 in test/test_graph_reader.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Function name should be lowercase
if os.path.exists(self.test_filename):
os.remove(self.test_filename)

def test_empty_file(self):
open(self.test_filename, "w").close()
reader = GraphReader()
with self.assertRaises(ValueError):
reader.read_graph_from_file(self.test_filename)

def test_wrong_format(self):
with open(self.test_filename, "w") as file:
file.write("A B C\n0 1 2\n0 0")
reader = GraphReader()
with self.assertRaises(ValueError):
reader.read_graph_from_file(self.test_filename)

def test_too_many_vertices(self):
names = " ".join([f"V{i}" for i in range(51)])
with open(self.test_filename, "w") as file:
file.write(names + "\n")
reader = GraphReader()
with self.assertRaises(ValueError):
reader.read_graph_from_file(self.test_filename)

def test_invalid_weight(self):
with open(self.test_filename, "w") as file:
file.write("A B\n0 1024\n1024 0\n")
reader = GraphReader()
with self.assertRaises(ValueError):
reader.read_graph_from_file(self.test_filename)

def test_valid_graph(self):
with open(self.test_filename, "w") as file:
file.write("A B C\n0 3 1\n3 0 2\n1 2 0\n")
reader = GraphReader()
graph, names = reader.read_graph_from_file(self.test_filename)
self.assertEqual(len(names), 3)
self.assertEqual(graph.vertices_count, 3)
self.assertEqual(len(graph.edges), 3)
weights = sorted(edge.weight for edge in graph.edges)
self.assertEqual(weights, [1, 2, 3])

def test_full_graph(self):
"""
Полный граф из 4 вершин
Вершины: A B C D
Матрица:
0 1 2 3
1 0 4 5
2 4 0 6
"""

with open(self.test_filename, "w") as file:
file.write("A B C D\n")
file.write("0 1 2 3\n")
file.write("1 0 4 5\n")
file.write("2 4 0 6\n")
file.write("3 5 6 0\n")
reader = GraphReader()
g, names = reader.read_graph_from_file(self.test_filename)
self.assertEqual(len(g.edges), 6)
self.assertEqual(sorted(names), ["A", "B", "C", "D"])
Loading

0 comments on commit 8438993

Please sign in to comment.