Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Search bug fixes #4673

Merged
merged 9 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/train-sets/ref/search_dep_parser_cost_to_go.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ average since instance current true current predicted
loss last counter output prefix output prefix pass pol made hits gener beta
88.000000 88.000000 1 [43:1 5:2 5:2 5:2 1..] [0:8 1:1 2:1 3:1 4:..] 0 0 144 0 141 0.000000
46.500000 5.000000 2 [2:2 3:5 0:8 3:7 99..] [2:2 0:8 4:2 2:3 99..] 0 0 153 0 150 0.001409
30.750000 15.000000 4 [2:2 3:5 0:8 3:7 99..] [2:2 3:5 0:8 3:7 99..] 1 0 306 0 300 0.002906
17.125000 3.500000 8 [2:2 3:5 0:8 3:7 99..] [2:2 3:5 0:8 3:7 99..] 3 0 606 0 600 0.005893
29.250000 12.000000 4 [2:2 3:5 0:8 3:7 99..] [2:2 3:5 0:8 3:7 99..] 1 0 306 0 300 0.002906
16.375000 3.500000 8 [2:2 3:5 0:8 3:7 99..] [2:2 3:5 0:8 3:7 99..] 3 0 606 0 600 0.005893

finished run
number of examples per pass = 2
passes used = 6
weighted example sum = 12.000000
weighted label sum = 0.000000
average loss = 11.416667
total feature number = 270404
average loss = 10.916667
total feature number = 269977
4 changes: 2 additions & 2 deletions test/train-sets/ref/search_dep_parser_one_learner.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
89.000000 89.000000 1 [43:1 5:2 5:2 5:2 1..] [20:9 20:9 20:9 20:..] 0 0 96 0 94 0.000930
47.000000 5.000000 2 [2:2 3:5 0:8 3:7 99..] [2:12 3:7 4:4 0:8 9..] 0 0 102 0 100 0.000990
46.500000 4.000000 2 [2:2 3:5 0:8 3:7 99..] [2:5 3:5 4:5 0:8 99..] 0 0 102 0 100 0.000990

finished run
number of examples = 2
weighted example sum = 2.000000
weighted label sum = 0.000000
average loss = 47.000000
average loss = 46.500000
total feature number = 28636
10 changes: 5 additions & 5 deletions test/train-sets/ref/search_wsj.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
30.000000 30.000000 1 [1 2 3 1 4 5 6 7 8 ..] [1 1 1 1 1 1 1 1 1 ..] 0 0 37 0 37 0.000036
23.500000 17.000000 2 [11 2 3 11 11 11 15..] [1 2 1 1 4 1 2 1 1 ..] 0 0 64 0 64 0.000063
16.000000 8.500000 4 [3 4 6 3 ] [11 11 2 3 ] 0 0 97 0 97 0.000096
8.000000 0.000000 8 [3 4 6 3 ] [3 4 6 3 ] 1 0 194 0 194 0.000193
4.000000 0.000000 16 [3 4 6 3 ] [3 4 6 3 ] 3 0 388 0 388 0.000387
24.500000 19.000000 2 [11 2 3 11 11 11 15..] [1 2 1 1 4 1 12 9 1..] 0 0 64 0 64 0.000063
16.250000 8.000000 4 [3 4 6 3 ] [1 4 6 3 ] 0 0 97 0 97 0.000096
8.125000 0.000000 8 [3 4 6 3 ] [3 4 6 3 ] 1 0 194 0 194 0.000193
4.062500 0.000000 16 [3 4 6 3 ] [3 4 6 3 ] 3 0 388 0 388 0.000387

finished run
number of examples per pass = 4
passes used = 6
weighted example sum = 24.000000
weighted label sum = 0.000000
average loss = 2.666667
average loss = 2.708333
total feature number = 52110
4 changes: 2 additions & 2 deletions test/train-sets/ref/search_wsj2.dat.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
30.000000 30.000000 1 [1 2 3 1 4 5 6 7 8 ..] [1 1 1 1 1 1 1 1 1 ..] 0 0 37 0 37 0.000036
24.000000 18.000000 2 [11 2 3 11 11 11 15..] [1 2 3 1 4 1 2 1 1 ..] 0 0 64 0 64 0.000063
16.750000 9.500000 4 [3 4 6 3 ] [11 11 11 11 ] 0 0 97 0 97 0.000096
24.500000 19.000000 2 [11 2 3 11 11 11 15..] [1 2 1 1 1 1 12 12 ..] 0 0 64 0 64 0.000063
16.750000 9.000000 4 [3 4 6 3 ] [1 4 6 6 ] 0 0 97 0 97 0.000096
8.375000 0.000000 8 [3 4 6 3 ] [3 4 6 3 ] 1 0 194 0 194 0.000193
4.187500 0.000000 16 [3 4 6 3 ] [3 4 6 3 ] 3 1 388 0 388 0.000387

Expand Down
12 changes: 1 addition & 11 deletions test/train-sets/ref/sequence_data.ldf.beam.test.predict
Original file line number Diff line number Diff line change
@@ -1,11 +1 @@
5 4 3 2 1 0.000242054
5 4 4 3 2 1.00016
5 4 3 3 2 1.00018
5 4 3 2 4 1.00018
5 4 5 4 3 1.00019
5 4 3 4 3 1.00026
5 4 2 1 4 1.60554
5 3 2 1 4 1.60557
5 4 1 4 3 1.60563
4 3 2 1 4 1.60563

5 4 3 2 1
2 changes: 1 addition & 1 deletion test/train-sets/ref/sequence_data.ldf.beam.test.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Input label = MULTICLASS
Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
0.000000 0.000000 1 [5 4 3 2 1 ] [5 4 3 2 1 0.00024..] 0 0 26 0 0 0.000000
0.000000 0.000000 1 [5 4 3 2 1 ] [5 4 3 2 1 ] 0 0 26 0 0 0.000000

finished run
number of examples = 1
Expand Down
12 changes: 1 addition & 11 deletions test/train-sets/ref/sequence_data.nonldf.beam.test.predict
Original file line number Diff line number Diff line change
@@ -1,11 +1 @@
5 4 3 2 1 8.34465e-07
5 4 3 5 4 1
5 4 3 2 4 1
5 5 4 3 2 1.02138
5 4 3 2 3 1.02816
5 4 3 2 2 1.03424
5 3 2 1 1 1.76761
5 4 3 1 1 1.79503
4 3 2 1 1 1.79576
5 2 1 1 1 2.53521

5 4 3 2 1
2 changes: 1 addition & 1 deletion test/train-sets/ref/sequence_data.nonldf.beam.test.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Input label = MULTICLASS
Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
0.000000 0.000000 1 [5 4 3 2 1 ] [5 4 3 2 1 8.34465..] 0 0 24 0 0 0.000000
0.000000 0.000000 1 [5 4 3 2 1 ] [5 4 3 2 1 ] 0 0 24 0 0 0.000000

finished run
number of examples = 1
Expand Down
12 changes: 6 additions & 6 deletions vowpalwabbit/core/include/vw/core/reductions/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ class search
public: // INTERFACE
// for managing task-specific data that you want on the heap:
template <class T>
void set_task_data(T* data)
void set_task_data(std::shared_ptr<T> data)
{
task_data = std::shared_ptr<T>(data);
task_data = std::move(data);
}
template <class T>
T* get_task_data()
Expand All @@ -98,9 +98,9 @@ class search

// for managing metatask-specific data
template <class T>
void set_metatask_data(T* data)
void set_metatask_data(std::shared_ptr<T> data)
{
metatask_data = std::shared_ptr<T>(data);
metatask_data = std::move(data);
}
template <class T>
T* get_metatask_data()
Expand Down Expand Up @@ -218,7 +218,7 @@ class search
BaseTask base_task(VW::multi_ex& ec) { return BaseTask(this, ec); }

// internal data that you don't get to see!
search_private* priv = nullptr;
std::shared_ptr<search_private> priv = nullptr;
std::shared_ptr<void> task_data = nullptr; // your task data!
std::shared_ptr<void> metatask_data = nullptr; // your metatask data!
const char* task_name = nullptr;
Expand All @@ -227,8 +227,8 @@ class search
VW::workspace& get_vw_pointer_unsafe(); // although you should rarely need this, some times you need a pointer to the
// vw data structure :(
void set_force_oracle(bool force); // if the library wants to force search to use the oracle, set this to true

search();
~search();
};

// for defining new tasks, you must fill out a search_task
Expand Down
Loading
Loading