forked from official-stockfish/WDL_model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscoreWDLstat.cpp
471 lines (379 loc) · 14.7 KB
/
scoreWDLstat.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
#include <mutex>
#include <optional>
#include <regex>
#include <string>
#include <string_view>
#include <thread>
#include <unordered_map>
#include <vector>

#include "external/chess.hpp"
#include "external/json.hpp"
#include "external/threadpool.hpp"
// Short alias for the standard filesystem namespace.
namespace fs = std::filesystem;
// JSON value type used for reading book metadata and writing the result file.
using json = nlohmann::json;
// Makes the names of the chess namespace usable unqualified in this file
// (presumably Board, Color, Game, ... come from external/chess.hpp -- confirm).
using namespace chess;
/// @brief Outcome of a game. Each enumerator's value is the character that represents it
/// when a Key is rendered as a string.
enum class Result {
    WIN = 'W',
    DRAW = 'D',
    LOSS = 'L',
};
/// @brief Result of one game as seen by each colour.
struct ResultKey {
    Result white, black;  // outcome from White's and from Black's point of view
};
/// @brief One counted data point: (outcome, move, material, score).
///
/// Used as the key of the position map. The conversions to std::size_t (hash value) and
/// std::string (textual form) are explicit so that a Key never silently decays into an
/// integer or a string; request them with static_cast.
struct Key {
    Result outcome;             // game outcome from PoV of side to move
    int move, material, score;  // move number, material count, engine's eval

    /// @brief Two keys are equal when all four fields are equal.
    bool operator==(const Key &k) const {
        return outcome == k.outcome && move == k.move && material == k.material && score == k.score;
    }

    /// @brief Hash value of the key, combining all four fields.
    explicit operator std::size_t() const {
        // golden ratio hashing, thus 0x9e3779b9
        std::uint32_t hash = static_cast<int>(outcome);
        hash ^= move + 0x9e3779b9 + (hash << 6) + (hash >> 2);
        hash ^= material + 0x9e3779b9 + (hash << 6) + (hash >> 2);
        hash ^= score + 0x9e3779b9 + (hash << 6) + (hash >> 2);
        return hash;
    }

    /// @brief Textual form "('<outcome char>', <move>, <material>, <score>)".
    explicit operator std::string() const {
        return "('" + std::string(1, static_cast<char>(outcome)) + "', " + std::to_string(move) +
               ", " + std::to_string(material) + ", " + std::to_string(score) + ")";
    }
};
// Specialise std::hash for Key so that Key can serve as an unordered_map key; the hash value
// is whatever Key's conversion to std::size_t yields.
template <>
struct std::hash<Key> {
    std::size_t operator()(const Key &key) const {
        const auto digest = static_cast<std::size_t>(key);
        return digest;
    }
};
// unordered map to count (outcome, move, material, score) tuples in pgns
using map_t = std::unordered_map<Key, int>;
// Number of file chunks whose analysis has finished; incremented by the worker tasks in
// process() and read there when the progress line is printed.
std::atomic<std::size_t> total_chunks = 0;
namespace analysis {
/// @brief Locale-independent stand-in for stof; it can be dropped once clang supports
/// std::from_chars for floats.
/// @param str Text of the form [+|-]digits[.digits]; parsing stops at the first character
/// that does not fit this form.
/// @return The parsed value, zero if no digits were read.
float fast_stof(const char *str) {
    const char *cursor = str;

    // Optional leading sign
    bool negative = false;
    if (*cursor == '-' || *cursor == '+') {
        negative = (*cursor == '-');
        ++cursor;
    }

    const auto is_digit = [](char c) { return c >= '0' && c <= '9'; };

    // The digits before and after the decimal point are accumulated into a single number;
    // `scale` records the power of ten contributed by the fractional digits.
    float digits = 0.0f;
    float scale = 1.0f;
    for (; is_digit(*cursor); ++cursor) {
        digits = digits * 10.0f + (*cursor - '0');
    }

    bool has_point = false;
    if (*cursor == '.') {
        has_point = true;
        for (++cursor; is_digit(*cursor); ++cursor) {
            digits = digits * 10.0f + (*cursor - '0');
            scale *= 10.0f;
        }
    }

    // Apply the sign first, then shift the decimal point back
    float value = negative ? -digits : digits;
    if (has_point) {
        value /= scale;
    }
    return value;
}
/// @brief Number of keys the position maps are pre-sized for: a magic value for fishtest
/// pgns, ~1.2 million keys.
static constexpr int map_size = 1'200'000;
/// @brief Analyze a single game and update the position map, apply filter if present
///
/// For every position (up to full move 200) whose move comment carries an engine score, the
/// counter of the tuple (outcome for the side to move, move number, material, score) is
/// incremented. Games without a "Result" header, or with a result other than "1-0", "0-1"
/// or "1/2-1/2", are ignored.
///
/// @param pos_map Counter map that is updated.
/// @param game Parsed game; must hold a value (game.value() is called).
/// @param regex_engine Regex matched against the "White" and "Black" headers. If empty, no
/// filtering is done. If exactly one side matches, only positions with that side to move are
/// counted; if both sides match, all positions are counted; if neither matches (or one of
/// the two headers is missing), nothing is counted.
void ana_game(map_t &pos_map, const std::optional<Game> &game, const std::string &regex_engine) {
    const auto &headers = game.value().headers();
    const auto has_header = [&headers](const char *name) {
        return headers.find(name) != headers.end();
    };

    if (!has_header("Result")) {
        return;
    }

    bool do_filter = !regex_engine.empty();
    Color filter_side = Color::NONE;

    if (do_filter) {
        if (!has_header("White") || !has_header("Black")) {
            return;
        }

        const std::regex regex(regex_engine);
        const bool white_matches = std::regex_match(headers.at("White"), regex);
        const bool black_matches = std::regex_match(headers.at("Black"), regex);

        if (white_matches && black_matches) {
            // both engines are of interest: count every position
            do_filter = false;
        } else if (white_matches) {
            filter_side = Color::WHITE;
        } else if (black_matches) {
            filter_side = Color::BLACK;
        } else {
            // neither engine matches, so no position of this game can be counted;
            // skip replaying the moves
            return;
        }
    }

    const auto result = headers.at("Result");

    ResultKey resultkey;
    if (result == "1-0") {
        resultkey.white = Result::WIN;
        resultkey.black = Result::LOSS;
    } else if (result == "0-1") {
        resultkey.white = Result::LOSS;
        resultkey.black = Result::WIN;
    } else if (result == "1/2-1/2") {
        resultkey.white = Result::DRAW;
        resultkey.black = Result::DRAW;
    } else {
        return;
    }

    Board board = Board();

    if (has_header("FEN")) {
        board.setFen(headers.at("FEN"));
    }

    if (has_header("Variant") && headers.at("Variant") == "fischerandom") {
        board.set960(true);
    }

    // Sentinel for "no score found": real scores lie in [-1001, 1001]
    constexpr int no_score = 1002;

    for (const auto &move : game.value().moves()) {
        if (board.fullMoveNumber() > 200) {
            break;
        }

        // The score is the part of the comment in front of the first '/'
        const std::size_t delimiter_pos = move.comment.find('/');

        Key key;
        key.score = no_score;

        if (!do_filter || filter_side == board.sideToMove()) {
            if (delimiter_pos != std::string::npos && move.comment != "book") {
                const auto match_score = move.comment.substr(0, delimiter_pos);

                // Mate scores have 'M' as second character; the size check keeps the
                // index in bounds for empty score text
                if (match_score.size() > 1 && match_score[1] == 'M') {
                    key.score = match_score[0] == '+' ? 1001 : -1001;
                } else {
                    // Clamp before converting to int so that absurdly large values cannot
                    // overflow the conversion; the conversion truncates toward zero
                    const int score = static_cast<int>(
                        std::clamp(100 * fast_stof(match_score.c_str()), -1000.0f, 1000.0f));

                    key.score = static_cast<int>(std::floor(score / 5.0)) * 5;  // reduce precision
                }
            }
        }

        if (key.score != no_score) {  // a score was found
            key.outcome = board.sideToMove() == Color::WHITE ? resultkey.white : resultkey.black;
            key.move = board.fullMoveNumber();

            const auto knights = builtin::popcount(board.pieces(PieceType::KNIGHT));
            const auto bishops = builtin::popcount(board.pieces(PieceType::BISHOP));
            const auto rooks = builtin::popcount(board.pieces(PieceType::ROOK));
            const auto queens = builtin::popcount(board.pieces(PieceType::QUEEN));
            const auto pawns = builtin::popcount(board.pieces(PieceType::PAWN));

            key.material = 9 * queens + 5 * rooks + 3 * bishops + 3 * knights + pawns;

            pos_map[key] += 1;
        }

        board.makeMove(move.move);
    }
}
/// @brief Analyze every game of the given pgn files and accumulate the counts in a map.
/// @param map Counter map that is updated; it is pre-sized to map_size buckets' worth of keys.
/// @param files Paths of the pgn files to read.
/// @param regex_engine Engine-name filter handed on to ana_game (empty: no filtering).
void ana_files(map_t &map, const std::vector<std::string> &files, const std::string &regex_engine) {
    map.reserve(map_size);

    for (const auto &file : files) {
        std::ifstream pgn_file(file);
        if (!pgn_file.is_open()) {
            // nothing can be read from a file that failed to open
            continue;
        }

        while (true) {
            auto game = pgn::readGame(pgn_file);

            if (!game.has_value()) {
                break;
            }

            ana_game(map, game, regex_engine);
        }
        // pgn_file is closed by its destructor
    }
}
} // namespace analysis
/// @brief Collect the paths of all ".pgn" files in a directory.
/// @param path Directory to scan.
/// @param recursive When true, subdirectories are scanned as well.
/// @return Paths of the regular files with extension ".pgn", in directory iteration order;
/// the files of a subdirectory appear at the position of that subdirectory.
[[nodiscard]] std::vector<std::string> get_files(const std::string &path, bool recursive = false) {
    std::vector<std::string> found;

    for (const auto &entry : fs::directory_iterator(path)) {
        if (fs::is_regular_file(entry)) {
            if (entry.path().extension() == ".pgn") {
                found.push_back(entry.path().string());
            }
            continue;
        }

        if (!recursive || !fs::is_directory(entry)) {
            continue;
        }

        const auto nested = get_files(entry.path().string(), true);
        found.insert(found.end(), nested.begin(), nested.end());
    }

    return found;
}
bool is_matching_book(const std::string &json_filename, const std::regex ®ex) {
std::ifstream json_file(json_filename);
if (!json_file.is_open()) {
return false;
}
json metadata;
json_file >> metadata;
json_file.close();
if (metadata.find("book") != metadata.end()) {
std::string book = metadata["book"];
return std::regex_match(book, regex);
}
return false;
}
/// @brief Filter a list of pgn files by the book named in their json metadata.
///
/// The metadata file of a pgn is looked up as the pgn path up to (not including) its last
/// '-', with ".json" appended.
///
/// @param file_list List of pgn paths; filtered in place.
/// @param regex Pattern the book name is matched against.
/// @param invert When false, files whose book does not match are removed; when true, files
/// whose book matches are removed.
void filter_files(std::vector<std::string> &file_list, const std::regex &regex, bool invert) {
    const auto should_drop = [&regex, invert](const std::string &pgn_filename) {
        const std::string json_filename =
            pgn_filename.substr(0, pgn_filename.find_last_of('-')) + ".json";
        const bool match = is_matching_book(json_filename, regex);
        return invert ? match : !match;
    };

    file_list.erase(std::remove_if(file_list.begin(), file_list.end(), should_drop),
                    file_list.end());
}
/// @brief Split a list of pgns into successive, equally sized chunks.
///
/// The chunk size is ceil(pgns.size() / target_chunks); the last chunk may be shorter, and
/// fewer than target_chunks chunks are produced when there are not enough files.
///
/// @param pgns Paths to distribute; their order is preserved.
/// @param target_chunks Desired number of chunks; values below 1 are treated as 1.
/// @return The chunks; empty if pgns is empty.
[[nodiscard]] std::vector<std::vector<std::string>> split_chunks(
    const std::vector<std::string> &pgns, int target_chunks) {
    // Guard against division by zero and against negative values wrapping around
    const std::size_t wanted = target_chunks > 0 ? static_cast<std::size_t>(target_chunks) : 1;
    const std::size_t chunk_size = (pgns.size() + wanted - 1) / wanted;

    std::vector<std::vector<std::string>> chunks;
    if (chunk_size == 0) {
        return chunks;  // no input files
    }

    chunks.reserve((pgns.size() + chunk_size - 1) / chunk_size);

    for (std::size_t start = 0; start < pgns.size(); start += chunk_size) {
        const std::size_t stop = std::min(start + chunk_size, pgns.size());
        chunks.emplace_back(pgns.begin() + static_cast<std::ptrdiff_t>(start),
                            pgns.begin() + static_cast<std::ptrdiff_t>(stop));
    }

    return chunks;
}
/// @brief Analyze all pgn files in parallel and accumulate the counts in pos_map.
///
/// The files are split into chunks; each chunk is analyzed by a thread-pool task into a
/// task-local map, which is then merged into pos_map under a mutex. Progress is printed to
/// standard output.
///
/// @param files_pgn Paths of the pgn files to analyze.
/// @param pos_map Counter map receiving the merged result.
/// @param regex_engine Engine-name filter handed on to the analysis (empty: no filtering).
void process(const std::vector<std::string> &files_pgn, map_t &pos_map,
             const std::string &regex_engine) {
    // hardware_concurrency() may return 0 if the value is not computable; always use at
    // least one worker, for the chunk count as well as for the pool
    const unsigned int concurrency = std::max(1u, std::thread::hardware_concurrency());

    // Create more chunks than threads to prevent threads from idling.
    const int target_chunks = 4 * static_cast<int>(concurrency);

    auto files_chunked = split_chunks(files_pgn, target_chunks);

    std::cout << "Found " << files_pgn.size() << " pgn files, creating " << files_chunked.size()
              << " chunks for processing." << std::endl;

    // Mutex for pos_map access
    std::mutex map_mutex;

    // Create a thread pool
    ThreadPool pool(concurrency);

    // Print progress
    std::cout << "\rProgress: " << total_chunks << "/" << files_chunked.size() << std::flush;

    for (const auto &files : files_chunked) {
        pool.enqueue([&files, &regex_engine, &map_mutex, &pos_map, &files_chunked]() {
            map_t map;

            analysis::ana_files(map, files, regex_engine);

            total_chunks++;

            // Limit the scope of the lock
            {
                const std::lock_guard<std::mutex> lock(map_mutex);

                for (const auto &[key, count] : map) {
                    pos_map[key] += count;
                }

                // Print progress
                std::cout << "\rProgress: " << total_chunks << "/" << files_chunked.size()
                          << std::flush;
            }
        });
    }

    // Wait for all threads to finish
    pool.wait();
}
/// @brief Save the position map to a json file.
///
/// Every key is written in its string form with its count as value. On success the total
/// number of counted positions is reported on standard output; if the file cannot be opened
/// or written, an error is reported on standard error instead.
///
/// @param pos_map Counter map to save.
/// @param json_filename Path of the output file.
void save(const map_t &pos_map, const std::string &json_filename) {
    std::ofstream out_file(json_filename);

    if (!out_file.is_open()) {
        std::cerr << "Could not open " << json_filename << " for writing." << std::endl;
        return;
    }

    std::uint64_t total = 0;

    json j;

    for (const auto &[key, count] : pos_map) {
        const auto key_text = static_cast<std::string>(key);
        j[key_text] = count;
        total += count;
    }

    // save json to file
    out_file << j.dump(2);
    out_file.close();

    // close() flushes; a failure here means the data did not reach the file
    if (!out_file) {
        std::cerr << "Failed to write " << json_filename << "." << std::endl;
        return;
    }

    std::cout << "Wrote " << total << " scored positions to " << json_filename << " for analysis."
              << std::endl;
}
/// @brief Look up a command line option.
/// @param args Command line arguments to search.
/// @param pos Set to the position of the first occurrence of arg, or to args.end().
/// @param arg Option to look for.
/// @param without_parameter True for options that take no value.
/// @return true if the option is present and, unless without_parameter is set, is followed
/// by another argument.
bool find_argument(const std::vector<std::string> &args,
                   std::vector<std::string>::const_iterator &pos, std::string_view arg,
                   bool without_parameter = false) {
    const auto last = args.end();
    pos = std::find(args.begin(), last, arg);

    if (pos == last) {
        return false;
    }
    if (without_parameter) {
        return true;
    }
    return std::next(pos) != last;
}
/// @brief Print the command line help to standard output.
/// @param program_name Name shown in the usage line.
void print_usage(char const *program_name) {
    // One entry per option, printed in this order
    static constexpr const char *option_lines[] = {
        " --file <path> Path to pgn file",
        " --dir <path> Path to directory containing pgns",
        " -r Search for pgns recursively in subdirectories",
        " --matchEngine <regex> Filter data based on engine name",
        " --matchBook <regex> Filter data based on book name",
        " --matchBookInvert Invert the filter",
        " -o <path> Path to output json file",
    };

    std::cout << "Usage: " << program_name << " [options]" << std::endl;
    std::cout << "Options:" << std::endl;

    for (const char *line : option_lines) {
        std::cout << line << std::endl;
    }
}
/// @brief Entry point: collect pgn files, count scored positions and write them to json.
///
/// Errors raised as exceptions on the main thread (for example an invalid regex or a
/// directory that cannot be read) are reported on standard error.
///
/// @param argc Number of command line arguments.
/// @param argv Possible ones are --file, --dir, -r, --matchEngine, --matchBook, --matchBookInvert
/// and -o (plus --help)
/// @return 0 on success or after printing the help, 1 if an error was reported.
int main(int argc, char const *argv[]) {
    const std::vector<std::string> args(argv + 1, argv + argc);

    if (std::find(args.begin(), args.end(), "--help") != args.end()) {
        print_usage(argv[0]);
        return 0;
    }

    try {
        std::vector<std::string> files_pgn;
        std::string regex_engine, regex_book, json_filename = "scoreWDLstat.json";

        std::vector<std::string>::const_iterator pos;

        if (find_argument(args, pos, "--file")) {
            files_pgn = {*std::next(pos)};
        } else {
            std::string path = "./pgns";

            if (find_argument(args, pos, "--dir")) {
                path = *std::next(pos);
            }

            const bool recursive = find_argument(args, pos, "-r", true);
            std::cout << "Looking " << (recursive ? "(recursively) " : "") << "for pgn files in "
                      << path << std::endl;

            files_pgn = get_files(path, recursive);
        }

        if (find_argument(args, pos, "--matchBook")) {
            regex_book = *std::next(pos);

            if (!regex_book.empty()) {
                const bool invert = find_argument(args, pos, "--matchBookInvert", true);
                std::cout << "Filtering pgn files " << (invert ? "not " : "")
                          << "matching the book name " << regex_book << std::endl;
                const std::regex regex(regex_book);
                filter_files(files_pgn, regex, invert);
            }
        }

        if (find_argument(args, pos, "--matchEngine")) {
            regex_engine = *std::next(pos);
        }

        if (find_argument(args, pos, "-o")) {
            json_filename = *std::next(pos);
        }

        map_t pos_map;
        pos_map.reserve(analysis::map_size);

        const auto t0 = std::chrono::high_resolution_clock::now();

        process(files_pgn, pos_map, regex_engine);

        const auto t1 = std::chrono::high_resolution_clock::now();

        std::cout << "\nTime taken: "
                  << std::chrono::duration_cast<std::chrono::seconds>(t1 - t0).count() << "s"
                  << std::endl;

        save(pos_map, json_filename);
    } catch (const std::exception &e) {
        // Report the problem instead of letting the exception terminate the process
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}