-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathpreprocess_split.go
179 lines (145 loc) · 5.78 KB
/
preprocess_split.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package mlpack
/*
#cgo CFLAGS: -I./capi -Wall
#cgo LDFLAGS: -L. -lmlpack_go_preprocess_split
#include <capi/preprocess_split.h>
#include <stdlib.h>
*/
import "C"
import "gonum.org/v1/gonum/mat"
// PreprocessSplitOptionalParam collects the optional parameters accepted by
// PreprocessSplit. PreprocessSplitOptions returns an instance populated with
// the default value of every field.
type PreprocessSplitOptionalParam struct {
	// InputLabels is the matrix containing labels. When nil, no labels are
	// handed to PreprocessSplit's underlying call.
	InputLabels *mat.Dense

	// NoShuffle avoids shuffling the data before splitting.
	NoShuffle bool

	// Seed is the random seed (0 for std::time(NULL)).
	Seed int

	// StratifyData stratifies the data according to labels.
	StratifyData bool

	// TestRatio is the ratio of the test set; the default is 0.2.
	TestRatio float64

	// Verbose displays informational messages and the full list of
	// parameters and timers at the end of execution.
	Verbose bool
}
// PreprocessSplitOptions returns a freshly allocated
// PreprocessSplitOptionalParam holding the default value of every optional
// parameter: no input labels, shuffling enabled, seed 0, no stratification,
// a test ratio of 0.2 and verbose output disabled.
func PreprocessSplitOptions() *PreprocessSplitOptionalParam {
	// Every default except TestRatio coincides with the Go zero value, so
	// only that field needs an explicit assignment.
	defaults := new(PreprocessSplitOptionalParam)
	defaults.TestRatio = 0.2
	return defaults
}
/*
PreprocessSplit takes a dataset and optionally labels and splits them into a
training set and a test set. Before the split, the points in the dataset are
randomly reordered. The percentage of the dataset to be used as the test set
can be specified with the "TestRatio" parameter; the default is 0.2 (20%).

The output training and test matrices may be saved with the "Training" and
"Test" output parameters.

Optionally, labels can also be split along with the data by specifying the
"InputLabels" parameter. Splitting labels works the same way as splitting the
data. The output training and test labels may be saved with the
"TrainingLabels" and "TestLabels" output parameters, respectively.

So, a simple example where we want to split the dataset X into X_train and
X_test with 60% of the data in the training set and 40% of the dataset in the
test set, we could run

	// Initialize optional parameters for PreprocessSplit().
	param := mlpack.PreprocessSplitOptions()
	param.TestRatio = 0.4
	X_test, _, X_train, _ := mlpack.PreprocessSplit(X, param)

Also by default the dataset is shuffled and split; you can provide the
"NoShuffle" option to avoid shuffling the data; an example to avoid shuffling
of data is:

	// Initialize optional parameters for PreprocessSplit().
	param := mlpack.PreprocessSplitOptions()
	param.TestRatio = 0.4
	param.NoShuffle = true
	X_test, _, X_train, _ := mlpack.PreprocessSplit(X, param)

If we had a dataset X and associated labels y, and we wanted to split these
into X_train, y_train, X_test, and y_test, with 30% of the data in the test
set, we could run

	// Initialize optional parameters for PreprocessSplit().
	param := mlpack.PreprocessSplitOptions()
	param.InputLabels = y
	param.TestRatio = 0.3
	X_test, y_test, X_train, y_train := mlpack.PreprocessSplit(X, param)

To maintain the ratio of each class in the train and test sets, the
"StratifyData" option can be used.

	// Initialize optional parameters for PreprocessSplit().
	param := mlpack.PreprocessSplitOptions()
	param.TestRatio = 0.4
	param.StratifyData = true
	X_test, _, X_train, _ := mlpack.PreprocessSplit(X, param)

A nil param is treated as the default option set returned by
PreprocessSplitOptions().

Input parameters:

  - input (*mat.Dense): Matrix containing data.
  - InputLabels (*mat.Dense): Matrix containing labels.
  - NoShuffle (bool): Avoid shuffling the data before splitting.
  - Seed (int): Random seed (0 for std::time(NULL)). Default value 0.
  - StratifyData (bool): Stratify the data according to labels.
  - TestRatio (float64): Ratio of test set; if not set, the ratio defaults
    to 0.2. Default value 0.2.
  - Verbose (bool): Display informational messages and the full list of
    parameters and timers at the end of execution.

Output parameters (in return order):

  - test (*mat.Dense): Matrix to save test data to.
  - testLabels (*mat.Dense): Matrix to save test labels to.
  - training (*mat.Dense): Matrix to save training data to.
  - trainingLabels (*mat.Dense): Matrix to save train labels to.
*/
func PreprocessSplit(input *mat.Dense, param *PreprocessSplitOptionalParam) (*mat.Dense, *mat.Dense, *mat.Dense, *mat.Dense) {
	// A nil option set would be dereferenced below and panic; fall back to
	// the documented defaults instead.
	if param == nil {
		param = PreprocessSplitOptions()
	}

	params := getParams("preprocess_split")
	timers := getTimers()

	disableBacktrace()
	disableVerbose()

	// The input matrix is required: always hand it over and mark it as passed.
	gonumToArmaMat(params, "input", input, false)
	setPassed(params, "input")

	// Each optional parameter is set and marked as passed only when it
	// differs from its default value.
	if param.InputLabels != nil {
		gonumToArmaUmat(params, "input_labels", param.InputLabels)
		setPassed(params, "input_labels")
	}

	if param.NoShuffle {
		setParamBool(params, "no_shuffle", param.NoShuffle)
		setPassed(params, "no_shuffle")
	}

	if param.Seed != 0 {
		setParamInt(params, "seed", param.Seed)
		setPassed(params, "seed")
	}

	if param.StratifyData {
		setParamBool(params, "stratify_data", param.StratifyData)
		setPassed(params, "stratify_data")
	}

	if param.TestRatio != 0.2 {
		setParamDouble(params, "test_ratio", param.TestRatio)
		setPassed(params, "test_ratio")
	}

	if param.Verbose {
		setParamBool(params, "verbose", param.Verbose)
		setPassed(params, "verbose")
		// Re-enable the verbose output that was switched off above.
		enableVerbose()
	}

	// Mark all output options as passed.
	setPassed(params, "test")
	setPassed(params, "test_labels")
	setPassed(params, "training")
	setPassed(params, "training_labels")

	// Call the mlpack program.
	C.mlpackPreprocessSplit(params.mem, timers.mem)

	// Fetch every output before params is cleaned up below.
	var testPtr mlpackArma
	test := testPtr.armaToGonumMat(params, "test")
	var testLabelsPtr mlpackArma
	testLabels := testLabelsPtr.armaToGonumUmat(params, "test_labels")
	var trainingPtr mlpackArma
	training := trainingPtr.armaToGonumMat(params, "training")
	var trainingLabelsPtr mlpackArma
	trainingLabels := trainingLabelsPtr.armaToGonumUmat(params, "training_labels")

	// Clean memory.
	cleanParams(params)
	cleanTimers(timers)

	// Return output(s).
	return test, testLabels, training, trainingLabels
}