forked from LLNL/ygm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
alg_spmv.cpp
82 lines (66 loc) · 2.06 KB
/
alg_spmv.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// Copyright 2019-2021 Lawrence Livermore National Security, LLC and other YGM
// Project Developers. See the top-level COPYRIGHT file for details.
//
// SPDX-License-Identifier: MIT
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <ygm/comm.hpp>
#include <ygm/container/experimental/maptrix.hpp>
#include <ygm/container/map.hpp>
#include <ygm/io/line_parser.hpp>
#include <ygm/utility.hpp>
// Distributed sparse matrix-vector multiply (SpMV) example.
//
// Usage: alg_spmv <matrix_file> [vector_file]
//   matrix_file: one entry per line, "row col [value]"; value defaults to 1.
//                Lines without at least a row and a column index are skipped.
//   vector_file: optional; one entry per line, "index value". When omitted,
//                x gets the value 1 at every row and column index that
//                appears in the matrix.
// Builds A and x, runs ns_spmv::spmv with (+, *) and reports the elapsed
// SpMV time through world.cout0.
int main(int argc, char **argv) {
  ygm::comm world(&argc, &argv);

  using map_type     = ygm::container::map<size_t, int>;
  using maptrix_type = ygm::container::experimental::maptrix<size_t, int>;
  namespace ns_spmv  = ygm::container::experimental::detail::algorithms;

  if (argc == 1) {
    // Report once and return normally: std::exit would skip the destructor
    // of `world`, and a missing argument is an error, not success.
    world.cout0("Expected parameter arguments, exiting..");
    return 1;
  }

  std::vector<std::string> mat_files({argv[1]});
  std::vector<std::string> vec_files;
  bool                     read_vec{false};
  if (argc == 3) {
    vec_files.push_back(argv[2]);
    read_vec = true;
  }

  map_type     x(world);
  maptrix_type A(world);

  world.cout0("Reading maptrix");
  ygm::io::line_parser line_parser(world, mat_files);
  line_parser.for_all([&A, &x, read_vec](const auto &line) {
    size_t             src{};
    size_t             dst{};
    int                val{1};
    std::istringstream iss(line);
    // Skip lines that do not start with two indices (blank lines, headers,
    // comments); previously src/dst were read uninitialized in that case.
    if (!(iss >> src >> dst)) {
      return;
    }
    // The value column is optional; when absent, extraction fails at EOF
    // without touching val, so it stays 1.
    iss >> val;
    A.async_insert(src, dst, val);

    // Without a vector file, x is an all-ones vector over the matrix indices.
    if (!read_vec) {
      x.async_insert(src, 1);
      x.async_insert(dst, 1);
    }
  });
  world.barrier();

  if (read_vec) {
    world.cout0("Reading vector");
    ygm::io::line_parser vec_parser(world, vec_files);
    vec_parser.for_all([&x](const auto &line) {
      size_t index{};
      // Parse as int to match x's mapped type (map<size_t, int>); reading a
      // size_t here silently wrapped negative entries.
      int                val{};
      std::istringstream iss(line);
      if (iss >> index >> val) {
        x.async_insert(index, val);
      }
    });
    world.barrier();
  }

  world.cout0("Performing SpMV");
  ygm::timer spmv_timer{};
  // The result is not inspected; this example only times the operation.
  [[maybe_unused]] auto y =
      ns_spmv::spmv(A, x, std::plus<int>(), std::multiplies<int>());
  world.barrier();
  world.cout0("SpMV time: ", spmv_timer.elapsed(), " seconds");
  return 0;
}