-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
146 lines (117 loc) · 6.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
import logging
from pathlib import Path
from typing import Optional, Any
from clustering.DBScan.dbscan_cluster_builder import DBScanClusterBuilder
from clustering.KMeans.kmeans_cluster_builder import KMeansClusterBuilder
from clustering.SimDim.simdim_cluster_builder import SimDimClusterBuilder
from clustering.abstract_cluster_builder import AbstractClusterBuilder
from util.filesystem_validators import WriteableDirectory, ReadableFile
from writers.abstract_cluster_writer import AbstractClusterWriter
from writers.csv_cluster_writer import CSVClusterWriter
from writers.text_cluster_writer import TextClusterWriter
# Output formats accepted by --output-mode; each one is mapped to a
# concrete writer in _create_writer.
VALID_OUTPUT_MODES = [
    "csv",
    "text",
]
def main() -> None:
    """Parse the command line, build clusters with the chosen algorithm and write them out.

    Prints the usage line and returns when no sub-command was given.
    Exits with status 1 if the sub-command does not map to a known cluster builder.
    """
    logging.basicConfig(format="%(asctime)s : [%(threadName)s] %(levelname)s : %(message)s", level=logging.INFO)
    parser: argparse.ArgumentParser = _initialize_parser()
    args: Any = parser.parse_args()

    # Each sub-parser sets `action` via set_defaults; it is absent when no
    # sub-command was supplied on the command line.
    if "action" not in args or not args.action:
        parser.print_usage()
        return

    # Resolve the writer first so an invalid --output-mode is reported before
    # the cluster builder is constructed and run, not after clustering finished.
    writer: AbstractClusterWriter = _create_writer(args)

    cluster_builder: Optional[AbstractClusterBuilder]
    if args.action == "kmeans":
        cluster_builder = KMeansClusterBuilder(args.input, args.threads, args.k)
    elif args.action == "dbscan":
        cluster_builder = DBScanClusterBuilder(args.input, args.threads, args.eps)
    elif args.action == "simdim":
        cluster_builder = SimDimClusterBuilder(args.input, args.threads)
    else:
        cluster_builder = None

    if cluster_builder is None:
        # The bare `exit` builtin is injected by the `site` module and is not
        # guaranteed to exist (e.g. `python -S`); SystemExit is always available.
        raise SystemExit(1)

    cluster_builder.build_clusters()
    writer.write(cluster_builder, args.output if "output" in args else None)
def _create_writer(args) -> AbstractClusterWriter:
    """Return a new cluster writer matching ``args.output_mode``.

    :param args: parsed command-line namespace; only ``output_mode`` is read.
    :return: a ``TextClusterWriter`` for ``"text"`` or a ``CSVClusterWriter`` for ``"csv"``.
    :raises ValueError: if ``output_mode`` is not one of ``VALID_OUTPUT_MODES``.
        ``ValueError`` subclasses ``Exception``, so existing broad handlers still catch it.
    """
    output_mode = args.output_mode
    if output_mode == "text":
        return TextClusterWriter()
    if output_mode == "csv":
        return CSVClusterWriter()
    raise ValueError(
        f"Invalid output type {output_mode!r} supplied. Choose from {', '.join(VALID_OUTPUT_MODES)}"
    )
def _initialize_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser with one sub-command per clustering algorithm."""
    parser = argparse.ArgumentParser(description="Clustering trained entity embeddings")
    commands = parser.add_subparsers()
    # Register sub-commands in a fixed order: kmeans, dbscan, simdim.
    for register in (_initialize_kmeans_parser, _initialize_dbscan_parser, _initialize_simdim_parser):
        register(commands)
    return parser
def _initialize_kmeans_parser(subparsers) -> None:
    """Register the ``kmeans`` sub-command and its arguments.

    :param subparsers: the object returned by ``ArgumentParser.add_subparsers()``.
    """
    kmeans_parser = subparsers.add_parser("kmeans",
                                          help="Use k-means for clustering")
    # main() dispatches on this value to pick the cluster builder.
    kmeans_parser.set_defaults(action="kmeans")
    kmeans_parser.add_argument("--input",
                               help="gensim model containing embedded entities",
                               type=Path,
                               action=ReadableFile,
                               required=True)
    kmeans_parser.add_argument("--k",
                               help="number of clusters to build",
                               required=True,
                               type=int)
    kmeans_parser.add_argument("--output",
                               help="Desired location for storing cluster information",
                               type=Path,
                               action=WriteableDirectory,
                               required=True)
    # `choices` makes argparse reject an unknown mode at parse time, before
    # any clustering work starts.
    kmeans_parser.add_argument("--output-mode",
                               help=f"Define the type of output. Choose from: {', '.join(VALID_OUTPUT_MODES)}",
                               choices=VALID_OUTPUT_MODES,
                               required=True)
    kmeans_parser.add_argument("--threads",
                               help="Number of threads to use",
                               type=int,
                               default=8)
def _initialize_dbscan_parser(subparsers) -> None:
    """Register the ``dbscan`` sub-command and its arguments.

    :param subparsers: the object returned by ``ArgumentParser.add_subparsers()``.
    """
    dbscan_parser = subparsers.add_parser("dbscan",
                                          help="Use DBSCAN for clustering")
    # main() dispatches on this value to pick the cluster builder.
    dbscan_parser.set_defaults(action="dbscan")
    dbscan_parser.add_argument("--input",
                               help="gensim model containing embedded entities",
                               type=Path,
                               action=ReadableFile,
                               required=True)
    dbscan_parser.add_argument("--eps",
                               help="eps for expanding clusters",
                               type=float,
                               required=True)
    dbscan_parser.add_argument("--output",
                               help="Desired location for storing cluster information",
                               type=Path,
                               action=WriteableDirectory,
                               required=True)
    # `choices` makes argparse reject an unknown mode at parse time, before
    # any clustering work starts.
    dbscan_parser.add_argument("--output-mode",
                               help=f"Define the type of output. Choose from: {', '.join(VALID_OUTPUT_MODES)}",
                               choices=VALID_OUTPUT_MODES,
                               required=True)
    dbscan_parser.add_argument("--threads",
                               help="Number of threads to use",
                               type=int,
                               default=8)
def _initialize_simdim_parser(subparsers) -> None:
    """Register the ``simdim`` sub-command and its arguments.

    :param subparsers: the object returned by ``ArgumentParser.add_subparsers()``.
    """
    simdim_parser = subparsers.add_parser("simdim",
                                          help="Use SimDim for clustering")
    # main() dispatches on this value to pick the cluster builder.
    simdim_parser.set_defaults(action="simdim")
    simdim_parser.add_argument("--input",
                               help="gensim model containing embedded entities",
                               type=Path,
                               action=ReadableFile,
                               required=True)
    simdim_parser.add_argument("--output",
                               help="Desired location for storing cluster information",
                               type=Path,
                               action=WriteableDirectory,
                               required=True)
    # `choices` makes argparse reject an unknown mode at parse time, before
    # any clustering work starts.
    simdim_parser.add_argument("--output-mode",
                               help=f"Define the type of output. Choose from: {', '.join(VALID_OUTPUT_MODES)}",
                               choices=VALID_OUTPUT_MODES,
                               required=True)
    simdim_parser.add_argument("--threads",
                               help="Number of threads to use",
                               type=int,
                               default=8)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()