Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

small refactoring and adding new solution of kruskal's algorithm #9

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
Empty file added src/kruskals/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions src/kruskals/disjoint_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
class DisjointSet:
    """Disjoint-set (union-find) forest over the integers ``0 .. n - 1``.

    ``parent[i]`` is the parent of element ``i`` (a root is its own parent)
    and ``rank[i]`` is the rank used to decide which root survives a union.
    """

    def __init__(self, n: int):
        """Create ``n`` singleton sets, one per element ``0 .. n - 1``."""
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the representative (root) of the set containing ``x``.

        Every element visited on the way to the root is re-pointed directly
        at the root (path compression).

        Raises:
            IndexError: if ``x`` is not in ``0 .. n - 1``.
        """
        # Without this check a negative ``x`` would be accepted through
        # Python's negative indexing and silently alias another element.
        if not 0 <= x < len(self.parent):
            raise IndexError(f"element {x} is out of range")

        # First pass: walk up to the root without recursion.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: compress the path just walked.
        while self.parent[x] != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, x: int, y: int) -> bool:
        """Merge the sets containing ``x`` and ``y``.

        Returns:
            ``True`` if two distinct sets were merged, ``False`` if ``x``
            and ``y`` were already in the same set.

        Raises:
            IndexError: if ``x`` or ``y`` is not in ``0 .. n - 1``.
        """
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            return False
        # Make ``rx`` the root whose rank is not lower, so the lower-ranked
        # tree is always attached beneath it.
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        # Only a merge of equally ranked trees increases the rank.
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True
8 changes: 8 additions & 0 deletions src/kruskals/edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class Edge:
    """Immutable weighted edge between the vertices ``u`` and ``v``."""

    # NOTE: the CI inspection "Class has no `__init__` method" reported for
    # this class is expected; ``@dataclass`` generates ``__init__`` from the
    # annotated fields below.
    u: int  # one endpoint of the edge
    v: int  # the other endpoint of the edge
    weight: float  # weight of the edge
32 changes: 32 additions & 0 deletions src/kruskals/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from copy import copy
from typing import List

from src.kruskals.edges import Edge


class Graph:
    """Weighted graph on vertices ``0 .. vertices_count - 1``.

    Edges are kept in insertion order in a private list and exposed through
    the read-only ``edges`` property.
    """

    def __init__(self, vertices_count: int):
        """Create a graph with ``vertices_count`` vertices and no edges.

        Raises:
            ValueError: if ``vertices_count`` is not positive.
        """
        if vertices_count <= 0:
            raise ValueError("Количество вершин <= 0")
        self.vertices_count = vertices_count
        self._edges: List[Edge] = []

    @property
    def edges(self) -> List[Edge]:
        """Return a shallow copy of the edge list.

        A copy is returned so callers cannot modify the graph's own list.
        """
        return copy(self._edges)

    def add_edge(self, begin: int, end: int, weight: float):
        """Append the edge ``begin -- end`` with the given ``weight``.

        Raises:
            ValueError: if either endpoint is outside
                ``0 .. vertices_count - 1``, if ``begin == end`` (loop), or
                if ``weight`` is not in the half-open range ``(0, 1023]``.
        """
        if (
            begin < 0
            or end < 0
            or begin >= self.vertices_count
            or end >= self.vertices_count
        ):
            raise ValueError("Неверный индекс вершины")
        if begin == end:
            raise ValueError("Петли не допускаются")
        # Written as a negated chained comparison so that NaN, which compares
        # false against everything, is rejected instead of slipping through.
        # NOTE(review): the message advertises [1, ..., 1023] but fractional
        # weights in (0, 1) are still accepted - confirm which is intended.
        if not 0 < weight <= 1023:
            raise ValueError(
                "Вес ребра должен быть в диапазоне [1, ..., 1023]"
            )
        self._edges.append(Edge(begin, end, weight))
56 changes: 56 additions & 0 deletions src/kruskals/graph_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from src.kruskals.graph import Graph


def _process_edge(i: int, j: int, val: int, graph: Graph):
    """Handle one adjacency-matrix cell ``val`` found at row ``i``, column ``j``.

    A zero cell means "no edge" and is skipped. Any other value must lie in
    ``1 .. 1023``; it is added to ``graph`` only when the cell is above the
    main diagonal (``j > i``), so each pair of vertices yields one edge.

    Raises:
        ValueError: if ``val`` is non-zero and outside ``1 .. 1023``.
    """
    if val != 0:
        if val > 1023 or val < 1:
            raise ValueError("Вес ребра должен быть в диапазоне [1..1023].")
        # Cells on or below the diagonal are validated above but not added.
        if i < j:
            graph.add_edge(i, j, val)


def _read_matrix_row(file_obj, vertices_count: int):
    """Read the next line of ``file_obj`` as one row of the matrix.

    Returns:
        The whitespace-separated tokens of the line, as strings.

    Raises:
        ValueError: if the line is empty or blank (including end of file),
            or if it does not contain exactly ``vertices_count`` tokens.
    """
    # ``split()`` with no arguments ignores surrounding whitespace, so an
    # empty token list means the line was empty or whitespace-only.
    tokens = file_obj.readline().split()
    if not tokens:
        raise ValueError("Строк в матрице меньше, чем количество вершин.")
    if len(tokens) != vertices_count:
        raise ValueError(
            "Количество столбцов матрицы не соответствует количеству вершин."
        )
    return tokens


def _read_adjacency_matrix(file_obj, graph: Graph):
    """Read ``graph.vertices_count`` matrix rows and add their edges to ``graph``.

    Each cell is converted with ``int`` and passed to ``_process_edge``
    together with its row and column index.

    Raises:
        ValueError: if a row is missing or has the wrong length, if a cell is
            not an integer literal, or if a cell is rejected by
            ``_process_edge``.
    """
    size = graph.vertices_count
    for row_index in range(size):
        cells = _read_matrix_row(file_obj, size)
        for col_index, cell in enumerate(cells):
            _process_edge(row_index, col_index, int(cell), graph)


class GraphReader:
    """Reads a graph from a text file.

    The first line of the file holds the whitespace-separated vertex names;
    the following lines hold the adjacency matrix, one row per vertex.
    """

    def __init__(self, max_vertices=50):
        """Remember the largest number of vertices a file may declare."""
        self.max_vertices = max_vertices

    def read_graph_from_file(self, filename: str):
        """Parse ``filename`` (UTF-8) into a graph.

        Returns:
            A ``(graph, vertex_names)`` pair, where ``vertex_names`` is the
            list of names taken from the first line of the file.

        Raises:
            ValueError: if the header or the matrix is malformed.
        """
        with open(filename, "r", encoding="utf-8") as source:
            names = self._read_vertex_names(source)
            result = Graph(len(names))
            _read_adjacency_matrix(source, result)
        return result, names

    def _read_vertex_names(self, file_obj):
        """Read the header line of ``file_obj`` and return the vertex names.

        Raises:
            ValueError: if the line is empty or blank, or if it lists more
                than ``max_vertices`` names.
        """
        # An empty token list means the header was empty or whitespace-only.
        names = file_obj.readline().split()
        if not names:
            raise ValueError("File is corrupt")
        if len(names) > self.max_vertices:
            raise ValueError(
                f"Слишком много вершин. Максимум {self.max_vertices}."
            )
        return names
4 changes: 4 additions & 0 deletions src/kruskals/input.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
A B C
0 3 1
3 0 2
1 2 0
26 changes: 26 additions & 0 deletions src/kruskals/mst_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import List

from src.kruskals.disjoint_set import DisjointSet
from src.kruskals.edges import Edge


class MSTStrategy(ABC):
    """Abstract interface for algorithms that compute a spanning tree."""

    @abstractmethod
    def find_mst(self, vertices_count: int, edges: List[Edge]) -> List[Edge]:
        """Return the tree edges chosen from ``edges``.

        ``vertices_count`` is the number of vertices of the graph the edges
        belong to. Concrete strategies must override this method; the base
        implementation does nothing and returns ``None``.
        """


class KruskalMST(MSTStrategy):
    """Spanning-tree strategy implementing Kruskal's algorithm."""

    def find_mst(self, vertices_count: int, edges: List[Edge]) -> List[Edge]:
        """Pick edges in ascending weight order, skipping those closing a cycle.

        Returns:
            The ``vertices_count - 1`` selected edges in the order they were
            picked, or an empty list when that many edges could not be
            selected (the edges do not connect all vertices).
        """
        by_weight = sorted(edges, key=lambda e: e.weight)
        components = DisjointSet(vertices_count)
        needed = vertices_count - 1
        tree = []
        for candidate in by_weight:
            # ``union`` returns False when both endpoints are already in the
            # same component, i.e. the edge would close a cycle.
            if not components.union(candidate.u, candidate.v):
                continue
            tree.append(candidate)
            if len(tree) == needed:
                break
        return tree if len(tree) == needed else []
3 changes: 3 additions & 0 deletions src/kruskals/output.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
A C
B C
3
38 changes: 38 additions & 0 deletions src/kruskals/solution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from src.kruskals.graph_reader import GraphReader
from src.kruskals.mst_strategy import KruskalMST


def main(input_file: str = "input.txt", output_file: str = "output.txt"):
    """Read a graph, compute its spanning tree and write the result.

    Args:
        input_file: path of the graph description to read.
        output_file: path of the UTF-8 text file to (over)write.

    The output holds one ``"<name> <name>"`` line per tree edge, with the two
    names of an edge in ascending order and the lines sorted by those names,
    followed by a line with the total weight. If no tree is returned, the
    single line ``"Остов не существует"`` is written instead.
    """
    reader = GraphReader()
    graph, vertex_names = reader.read_graph_from_file(input_file)

    mst = KruskalMST().find_mst(graph.vertices_count, graph.edges)

    if not mst:
        # NOTE(review): an empty result is also what the tree of a
        # single-vertex graph would look like; confirm that reporting
        # "no spanning tree" is intended for that input.
        with open(output_file, "w", encoding="utf-8") as out:
            out.write("Остов не существует\n")
        return

    named_edges = []
    for edge in mst:
        # Order the two endpoint names so each edge prints smallest-first.
        first, second = sorted((vertex_names[edge.u], vertex_names[edge.v]))
        named_edges.append((first, second, edge.weight))

    named_edges.sort(key=lambda item: item[:2])
    total_weight = sum(weight for _, _, weight in named_edges)

    lines = [f"{u_name} {v_name}\n" for u_name, v_name, _ in named_edges]
    lines.append(f"{total_weight}\n")
    with open(output_file, "w", encoding="utf-8") as out:
        out.writelines(lines)


# Run the solver only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()
Empty file added src/shunting_yard/__init__.py
Empty file.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/tree/binary_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .traversal import TraversalStrategy
from src.tree.traversal import TraversalStrategy


class BinaryTree:
Expand Down
33 changes: 33 additions & 0 deletions test/test_disjoint_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

from src.kruskals.disjoint_set import DisjointSet


class TestDisjointSet(unittest.TestCase):
    """Unit tests for ``DisjointSet``."""

    def test_init(self):
        """A fresh structure has every element as its own representative."""
        size = 5
        dsu = DisjointSet(size)
        for element in range(size):
            self.assertEqual(dsu.find(element), element)

    def test_union_find(self):
        """Unions chain transitively and leave untouched elements apart."""
        dsu = DisjointSet(5)
        dsu.union(0, 1)
        self.assertEqual(dsu.find(0), dsu.find(1))
        dsu.union(1, 2)
        self.assertEqual(dsu.find(0), dsu.find(2))
        self.assertNotEqual(dsu.find(3), dsu.find(0))

    def test_large(self):
        """Linking 50 elements in a chain yields a single set."""
        size = 50
        dsu = DisjointSet(size)
        for left, right in zip(range(size - 1), range(1, size)):
            dsu.union(left, right)
        representative = dsu.find(0)
        for element in range(1, size):
            self.assertEqual(dsu.find(element), representative)

    def test_no_unions(self):
        """Without any union every element stays in its own set."""
        size = 10
        dsu = DisjointSet(size)
        for element in range(size):
            self.assertEqual(dsu.find(element), element)
50 changes: 50 additions & 0 deletions test/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest

from src.kruskals.graph import Graph


class TestGraph(unittest.TestCase):
    """Unit tests for ``Graph`` construction and ``add_edge`` validation."""

    def test_init(self):
        """Zero vertices is rejected; a positive count is stored."""
        with self.assertRaises(ValueError):
            Graph(0)
        self.assertEqual(Graph(3).vertices_count, 3)

    def test_add_edge_valid(self):
        """A valid edge ends up in the edge list."""
        graph = Graph(3)
        graph.add_edge(0, 1, 1)
        self.assertEqual(len(graph.edges), 1)

    def test_add_edge_invalid_index(self):
        """Endpoints below zero or past the last vertex are rejected."""
        graph = Graph(3)
        for begin, end in ((-1, 0), (0, 3)):
            with self.assertRaises(ValueError):
                graph.add_edge(begin, end, 1)

    def test_add_loop_edge(self):
        """An edge from a vertex to itself is rejected."""
        graph = Graph(3)
        with self.assertRaises(ValueError):
            graph.add_edge(0, 0, 1)

    def test_add_negative_or_zero_weight(self):
        """Negative and zero weights are rejected."""
        graph = Graph(3)
        for weight in (-1, 0):
            with self.assertRaises(ValueError):
                graph.add_edge(0, 1, weight)

    def test_max_weight(self):
        """1023 is the largest accepted weight; 1024 is rejected."""
        graph = Graph(3)
        graph.add_edge(0, 1, 1023)
        self.assertEqual(len(graph.edges), 1)
        self.assertEqual(graph.edges[0].weight, 1023)
        with self.assertRaises(ValueError):
            graph.add_edge(1, 2, 1024)

    def test_large_graph(self):
        """A 50-vertex path graph stores all of its 49 edges."""
        size = 50
        graph = Graph(size)
        for vertex in range(size - 1):
            graph.add_edge(vertex, vertex + 1, vertex + 1)
        self.assertEqual(len(graph.edges), size - 1)
73 changes: 73 additions & 0 deletions test/test_graph_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import tempfile
import unittest

from src.kruskals.graph_reader import GraphReader


class TestGraphReader(unittest.TestCase):
    """Unit tests for ``GraphReader`` parsing and input validation."""

    # NOTE: CI reports ``setUp``/``tearDown`` as PEP 8 naming violations; the
    # camelCase names are the hook names ``unittest.TestCase`` looks up, so
    # they have to stay as they are.
    def setUp(self):
        """Point ``test_filename`` at a file inside a private temp directory.

        Writing ``input.txt`` into the current working directory (and
        deleting it afterwards) could clobber a real ``input.txt`` that
        happens to live there, so each test gets its own directory instead.
        """
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        self.test_filename = os.path.join(tmp_dir.name, "input.txt")

    def tearDown(self):
        """Remove the fixture file if the test created it."""
        if os.path.exists(self.test_filename):
            os.remove(self.test_filename)

    def _write_input(self, content: str):
        """Write ``content`` to the fixture file as UTF-8, like the reader."""
        with open(self.test_filename, "w", encoding="utf-8") as file:
            file.write(content)

    def test_empty_file(self):
        """An empty file has no header line and is rejected."""
        self._write_input("")
        reader = GraphReader()
        with self.assertRaises(ValueError):
            reader.read_graph_from_file(self.test_filename)

    def test_wrong_format(self):
        """A matrix row with too few columns is rejected."""
        self._write_input("A B C\n0 1 2\n0 0")
        reader = GraphReader()
        with self.assertRaises(ValueError):
            reader.read_graph_from_file(self.test_filename)

    def test_too_many_vertices(self):
        """51 vertex names exceed the default limit of 50."""
        names = " ".join(f"V{i}" for i in range(51))
        self._write_input(names + "\n")
        reader = GraphReader()
        with self.assertRaises(ValueError):
            reader.read_graph_from_file(self.test_filename)

    def test_invalid_weight(self):
        """A weight of 1024 is outside the accepted range."""
        self._write_input("A B\n0 1024\n1024 0\n")
        reader = GraphReader()
        with self.assertRaises(ValueError):
            reader.read_graph_from_file(self.test_filename)

    def test_valid_graph(self):
        """A symmetric 3x3 matrix yields three edges with its weights."""
        self._write_input("A B C\n0 3 1\n3 0 2\n1 2 0\n")
        reader = GraphReader()
        graph, names = reader.read_graph_from_file(self.test_filename)
        self.assertEqual(len(names), 3)
        self.assertEqual(graph.vertices_count, 3)
        self.assertEqual(len(graph.edges), 3)
        weights = sorted(edge.weight for edge in graph.edges)
        self.assertEqual(weights, [1, 2, 3])

    def test_full_graph(self):
        """
        Complete graph on 4 vertices.
        Vertices: A B C D
        Matrix:
        0 1 2 3
        1 0 4 5
        2 4 0 6
        3 5 6 0
        """
        self._write_input(
            "A B C D\n"
            "0 1 2 3\n"
            "1 0 4 5\n"
            "2 4 0 6\n"
            "3 5 6 0\n"
        )
        reader = GraphReader()
        g, names = reader.read_graph_from_file(self.test_filename)
        self.assertEqual(len(g.edges), 6)
        self.assertEqual(sorted(names), ["A", "B", "C", "D"])
Loading
Loading