forked from asparagus/search
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsearch.py
458 lines (344 loc) · 12.3 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""Algorithms for search."""
import abc
import sys
import six
import math
import time
import heapq
import utils
import random
import collections
class TimeoutError(Exception):
    """
    Signal that a search exceeded its time budget.

    NOTE: on Python 3 this name shadows the built-in ``TimeoutError``
    inside this module; this class derives directly from ``Exception``.
    """
@six.add_metaclass(abc.ABCMeta)
class Search:
    """
    Skeleton of a frontier-driven state-space search.

    Subclasses provide the frontier through ``create_queue``, ``push`` and
    ``pop``; the order in which ``pop`` returns states defines the strategy.
    Those three methods raise ``NotImplementedError`` instead of being
    declared abstract, so subclasses that replace ``solve`` entirely (see
    GreedySearch and SimulatedAnnealing in this module) stay instantiable.
    """

    def create_queue(self):
        """Create a queue for storing the states in the search."""
        raise NotImplementedError()

    # NOTE: the third parameter is misspelled ('prbolem'). It is kept as-is
    # so that any caller passing it by keyword keeps working.
    def is_new(self, state, seen, prbolem):
        """Return True when ``state`` has not been recorded in ``seen``."""
        return state not in seen

    def add_to_seen(self, state, seen, problem):
        """Record ``state`` in the ``seen`` structure."""
        seen.add(state)

    def create_seen_set(self):
        """Create the structure that remembers already-seen states."""
        return set()

    def push_if_new(self, queue, state, seen, problem):
        """
        Push ``state`` on the queue and mark it as seen, unless already seen.

        Novelty is decided through ``is_new`` so that a subclass overriding
        that hook is honoured here as well.
        """
        if self.is_new(state, seen, problem):
            self.push(queue, state)
            self.add_to_seen(state, seen, problem)

    def push(self, queue, state):
        """Add a state to the queue."""
        raise NotImplementedError()

    def pop(self, queue):
        """Get the next state from the queue."""
        raise NotImplementedError()

    def branch(self, problem, state):
        """
        Expand ``state`` into its successor states.

        Every action returned by ``problem.actions(state)`` is called with
        the state; the resulting state gets an ``_action_history`` list
        holding the parent's history followed by the applied action.

        Returns:
            list: the successor states, in the order of the actions.
        """
        new_states = []
        for action in problem.actions(state):
            new_state = action(state)
            try:
                history = state._action_history + [action]
            except (AttributeError, TypeError):
                # The parent carries no usable history (e.g. it is the
                # initial state): start a fresh one.
                history = [action]
            new_state._action_history = history
            new_states.append(new_state)
        return new_states

    def solve(self, problem, initial_state=None, timeout=None):
        """
        Search for a state accepted by ``problem.is_solution``.

        Args:
            problem: object providing ``initial_state()``, ``actions(state)``
                and ``is_solution(state)``.
            initial_state: state to start from; when None,
                ``problem.initial_state()`` is used.
            timeout: time budget in seconds; a falsy value (None or 0)
                means no limit.

        Returns:
            The first solution state popped from the queue, or None when the
            queue runs empty without finding one.

        Raises:
            TimeoutError: when the budget is exceeded before a solution is
                found.
        """
        if not timeout:
            timeout = float('inf')
        start = time.time()
        # Compare against None explicitly: a legitimate state may be falsy.
        if initial_state is None:
            initial_state = problem.initial_state()
        queue = self.create_queue()
        seen = self.create_seen_set()
        self.push_if_new(queue, initial_state, seen, problem)
        while len(queue) > 0:
            if time.time() - start > timeout:
                raise TimeoutError()
            state = self.pop(queue)
            if problem.is_solution(state):
                return state
            for branched_state in self.branch(problem, state):
                self.push_if_new(queue, branched_state, seen, problem)
        return None
class BreadthFirstSearch(Search):
    """
    Search that expands states in the order they were discovered.

    >>> bfs = BreadthFirstSearch()
    >>> q = bfs.create_queue()
    >>> bfs.push(q, 1)
    >>> bfs.push(q, 2)
    >>> len(q)
    2
    >>> bfs.pop(q)
    1
    >>> len(q)
    1
    """

    def create_queue(self):
        """Return an empty deque that serves as the FIFO frontier."""
        frontier = collections.deque()
        return frontier

    def push(self, queue, state):
        """Append ``state`` at the back of the frontier."""
        queue.append(state)

    def pop(self, queue):
        """Remove and return the oldest state in the frontier."""
        oldest = queue.popleft()
        return oldest
class DepthFirstSearch(Search):
    """
    Search that always expands the most recently discovered state.

    >>> dfs = DepthFirstSearch()
    >>> q = dfs.create_queue()
    >>> dfs.push(q, 1)
    >>> dfs.push(q, 2)
    >>> len(q)
    2
    >>> dfs.pop(q)
    2
    >>> len(q)
    1
    """

    def create_queue(self):
        """Return an empty list that serves as the LIFO frontier."""
        frontier = []
        return frontier

    def push(self, queue, state):
        """Put ``state`` on top of the stack."""
        queue.append(state)

    def pop(self, queue):
        """Remove and return the state on top of the stack."""
        newest = queue.pop()
        return newest
class BestFirstSearch(Search):
    """
    An optimal search.

    States are expanded in increasing order of
    ``f = state.value + heuristic(state)``.
    """

    def __init__(self, heuristic=None):
        """
        Initialize an instance of A* search.

        The instance requires an heuristic function which
        receives a state and outputs an expected delta for the solution.
        The heuristic must be admissible to ensure an optimal solution.
        If no heuristic is provided, the Zero Heuristic is used.

        >>> a = BestFirstSearch()
        >>> a.heuristic(1)
        0
        >>> a.heuristic("Sample")
        0
        """
        self.heuristic = heuristic or ZeroHeuristic()
        # Maps a priority value to the stack of states sharing that value.
        self.value_states_dict = {}

    def create_queue(self):
        """
        Create a priority queue for storing the states in the search.

        The value -> stack index is reset together with the queue. Without
        the reset, stacks left over from a previous search that stopped
        early would swallow new states without ever being placed on the
        new heap.
        """
        self.value_states_dict = {}
        return []

    def push(self, queue, state, value=None):
        """
        Add a state to the priority queue.

        Manage the priority queue as a heap of (value, stack),
        where stack contains all the states with the same value.
        This minimizes the number of times the heap is used.
        States are retrieved using LIFO in order to try a depth first
        approach. The [value] => stack relationships are stored in
        ``self.value_states_dict``.

        Args:
            queue: heap list created by ``create_queue``.
            state: state to store; must expose ``value`` when ``value`` is
                not supplied.
            value: priority to use; when None it is computed as
                ``state.value + self.heuristic(state)``.
        """
        if value is None:
            f = state.value + self.heuristic(state)
        else:
            f = value
        stack = self.value_states_dict.get(f)
        if stack is None:
            # First state with this priority: register a new bucket in both
            # the index and the heap.
            stack = []
            self.value_states_dict[f] = stack
            heapq.heappush(queue, (f, stack))
        stack.append(state)

    def pop(self, queue):
        """
        Get the next state from the priority queue.

        Takes the most recently pushed state from the lowest-valued stack
        and drops that stack from the heap and the index once it is empty.
        """
        value, stack = queue[0]
        element = stack.pop()
        if not stack:
            heapq.heappop(queue)
            del self.value_states_dict[value]
        return element
class IterativeDepthFirstSearch(BestFirstSearch):
    """
    An optimal iterative search.

    This algorithm iteratively improves the found solution until it's optimal
    or time runs out. Each iteration (``run``) dives from one queued state
    down to a solution, parking the siblings it skips on the priority queue
    for later iterations.
    """

    def sort_states(self, problem, states):
        """
        Sort branched states before insertion.

        Hook for subclasses: ``states`` is a list and is expected to be
        reordered in place. The default implementation leaves it untouched.
        """
        pass

    def pop(self, queue):
        """
        Get the next state from the priority queue.

        Unlike the parent class this returns a ``(value, state)`` pair so
        that ``solve`` can compare the priority against the best solution.
        """
        value, stack = queue[0]
        element = stack.pop()
        if not stack:
            heapq.heappop(queue)
            del self.value_states_dict[value]
        return value, element

    def solve(self, problem, initial_state=None,
              timeout=None, soft_timeout=None):
        """
        Get a solution to the problem.

        Args:
            problem: object providing ``initial_state()``, ``actions(state)``
                and ``is_solution(state)``.
            initial_state: state to start from; when None,
                ``problem.initial_state()`` is used.
            timeout: hard time budget in seconds; a falsy value means no
                limit.
            soft_timeout: budget after which the search stops as soon as it
                holds any solution; a falsy value means "same as timeout".

        Returns:
            The best solution found (lowest ``value``), or None when the
            queue runs empty without any solution.

        Raises:
            TimeoutError: when ``timeout`` is exceeded and no solution has
                been found yet.
        """
        if not timeout:
            timeout = float('inf')
        if not soft_timeout:
            soft_timeout = timeout
        start = time.time()
        # Compare against None explicitly: a legitimate state may be falsy.
        if initial_state is None:
            initial_state = problem.initial_state()
        queue = self.create_queue()
        seen = self.create_seen_set()
        self.push_if_new(queue, initial_state, seen, problem)
        best_solution = None
        best_value = float('inf')
        while len(queue) > 0:
            value, state = self.pop(queue)
            # The queue is ordered by value, so nothing left on it can beat
            # the solution already held.
            if value >= best_value:
                break
            ellapsed_time = time.time() - start
            if best_solution is not None and ellapsed_time > soft_timeout:
                break
            elif ellapsed_time > timeout:
                raise TimeoutError()
            remaining_time = timeout - ellapsed_time
            new_solution = self.run(
                problem, state, queue, seen, remaining_time)
            if new_solution is not None:
                new_value = new_solution.value
                if best_solution is None or new_value < best_value:
                    best_solution = new_solution
                    best_value = new_value
        return best_solution

    def run(self, problem, initial_state, queue, seen,
            timeout=None):
        """
        Get a temporary solution.

        Multiple calls to run function will improve on the initial solution.

        Starting at ``initial_state`` the run repeatedly branches the current
        state, continues through one successor and pushes the remaining new
        successors on ``queue`` (marking them in ``seen``).

        Args:
            problem: the problem being solved.
            initial_state: state the dive starts from.
            queue: priority queue shared with ``solve``.
            seen: seen-structure shared with ``solve``.
            timeout: budget in seconds for this run; a falsy value means no
                limit.

        Returns:
            The solution state reached, or None when the run timed out or
            hit a state without new successors.
        """
        if not timeout:
            timeout = float('inf')
        start = time.time()
        reached_solution = False
        current_state = initial_state
        while True:
            if time.time() - start > timeout:
                break
            if problem.is_solution(current_state):
                reached_solution = True
                break
            # Build a real list: sort_states reorders in place, and on
            # Python 3 filter() would only give a one-shot iterator.
            branched_states = [
                candidate
                for candidate in self.branch(problem, current_state)
                if self.is_new(candidate, seen, problem)
            ]
            self.sort_states(problem, branched_states)
            # Drop duplicates among the siblings themselves.
            new_seen = self.create_seen_set()
            new_states = []
            for state in branched_states:
                if not new_seen or self.is_new(state, new_seen, problem):
                    new_states.append(state)
                    self.add_to_seen(state, new_seen, problem)
            if not new_states:
                # No way to continue this run
                break
            new_states.reverse()
            # Continue the run through the first state
            current_state = new_states[0]
            # Store the rest for another run
            for state in new_states[1:]:
                value = state.value + self.heuristic(state)
                self.push(queue, state, value=value)
                self.add_to_seen(state, seen, problem)
        return current_state if reached_solution else None
class GreedySearch(Search):
    """A greedy search: keep moving to the lowest-valued neighbour."""

    def solve(self, problem, initial_state=None, timeout=None):
        """
        Get a solution to the problem.

        From the current state, the neighbour with the smallest ``value``
        is picked (ties broken by ``utils.argmin_random_tie``). The search
        moves there only if it is strictly better than the best state so
        far; otherwise it stops.

        Args:
            problem: object providing ``initial_state()`` and
                ``actions(state)``.
            initial_state: state to start from; when None,
                ``problem.initial_state()`` is used.
            timeout: time budget in seconds; a falsy value means no limit.
                On timeout the best state found so far is returned, no
                exception is raised.

        Returns:
            The best (lowest ``value``) state reached.
        """
        if not timeout:
            timeout = float('inf')
        start = time.time()
        # Compare against None explicitly: a legitimate state may be falsy.
        if initial_state is None:
            initial_state = problem.initial_state()
        best_solution = initial_state
        best_value = initial_state.value
        while True:
            if time.time() - start > timeout:
                break
            neighbours = self.branch(problem, best_solution)
            if not neighbours:
                break
            next_state = utils.argmin_random_tie(
                neighbours,
                key=lambda state: state.value
            )
            if next_state is None or next_state.value >= best_value:
                # Local minimum: no neighbour improves on the current state.
                break
            # Record the improvement; without this the initial state would
            # always be returned.
            best_solution = next_state
            best_value = next_state.value
        return best_solution
class SimulatedAnnealing(Search):
    """SimulatedAnnealing search implementation taken from aima-python."""

    def __init__(self, schedule=utils.exp_schedule()):
        """
        Initialize the Simulated Annealing with a schedule function.

        Args:
            schedule: callable mapping the iteration number to a
                temperature; the search stops when it returns 0.
        """
        self.schedule = schedule

    def solve(self, problem, initial_state=None, timeout=None):
        """
        Get a solution to the problem.

        Each iteration picks a random neighbour. It is accepted when
        ``delta_e`` is not positive, or otherwise when
        ``utils.probability(exp(-delta_e / temperature))`` is truthy.
        The lowest-valued state visited is remembered and returned.

        Args:
            problem: object providing ``initial_state()`` and
                ``actions(state)``.
            initial_state: state to start from; when None,
                ``problem.initial_state()`` is used.
            timeout: time budget in seconds; a falsy value means no limit.
                On timeout the best state found so far is returned, no
                exception is raised.

        Returns:
            The best (lowest ``value``) state visited.
        """
        if not timeout:
            timeout = float('inf')
        start = time.time()
        # Compare against None explicitly: a legitimate state may be falsy.
        if initial_state is None:
            initial_state = problem.initial_state()
        current_state = initial_state
        best_solution = current_state
        best_value = current_state.value
        # six.moves.range is lazy on both Python 2 (xrange) and Python 3.
        for _t in six.moves.range(sys.maxsize):
            st = self.schedule(_t)
            current = time.time()
            if current - start > timeout or st == 0:
                break
            neighbours = self.branch(problem, current_state)
            if not neighbours:
                break
            next_state = random.choice(neighbours)
            # NOTE(review): presumably diff() is positive when next_state is
            # worse than current_state -- confirm against the value type.
            delta_e = next_state.value.diff(current_state.value)
            if delta_e <= 0 or utils.probability(math.exp(-delta_e / st)):
                current_state = next_state
                if current_state.value < best_value:
                    best_solution = current_state
                    best_value = current_state.value
        return best_solution
@six.add_metaclass(abc.ABCMeta)
class Heuristic:
    """Abstract callable that scores a state for heuristic purposes."""

    @abc.abstractmethod
    def __call__(self, state):
        """Return the evaluation of ``state``; subclasses must override."""
        raise NotImplementedError()
class ZeroHeuristic(Heuristic):
    """
    Heuristic that rates every state as zero.

    >>> h = ZeroHeuristic()
    >>> h(1)
    0
    >>> h("state")
    0
    """

    def __call__(self, state):
        """Ignore ``state`` and return 0."""
        return 0
def unit_test():
    """
    Run the module's doctests.

    Calls ``doctest.testmod`` with its default target, which is the
    ``__main__`` module.
    """
    from doctest import testmod
    testmod()
# Executing the file as a script runs the doctests; importing it does not.
if __name__ == '__main__':
    unit_test()