
Commit

7 add standard A* (omron-sinicx#8)
* Add pq astar

* Refactor neural A*

* Add test

* Update notebooks and results
yonetaniryo authored Dec 17, 2022
1 parent dc01899 commit e3e85c4
Showing 7 changed files with 357 additions and 104 deletions.
139 changes: 77 additions & 62 deletions example.ipynb

Large diffs are not rendered by default.

Binary file modified model/mazes_032_moore_c8/best.pt
Binary file not shown.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
"torch==1.12.1",
"torchvision==0.13.1",
"segmentation-models-pytorch==0.3.1",
"pqdict==1.2.0",
"hydra-core==1.2.0",
"numpy>=1.19.2",
"tensorboard>=2.5",
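The new pqdict dependency backs the standard A* added in src/neural_astar/planner/pq_astar.py. As a hedged illustration only (not the repository's pq_astar implementation; the function name grid_astar, the 8-connected neighborhood, and the 1 = traversable map convention are assumptions), a minimal grid A* driven by an indexed priority queue could look like this:

# Illustrative sketch only -- not the repository's pq_astar implementation.
# Assumes a 2D numpy map where 1 marks traversable cells and 0 marks obstacles.
import numpy as np
from pqdict import pqdict

def grid_astar(obstacles, start, goal, g_ratio=0.5):
    """Standard A* on an 8-connected grid using an indexed priority queue."""
    h = lambda v: np.hypot(v[0] - goal[0], v[1] - goal[1])   # Euclidean heuristic
    open_list = pqdict({start: (1 - g_ratio) * h(start)})    # node -> weighted f-value
    g = {start: 0.0}
    parent = {}
    while open_list:
        v, _ = open_list.popitem()                           # pop the node with the smallest f
        if v == goal:
            break
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                u = (v[0] + dy, v[1] + dx)
                if (dy, dx) == (0, 0):
                    continue
                if not (0 <= u[0] < obstacles.shape[0] and 0 <= u[1] < obstacles.shape[1]):
                    continue
                if obstacles[u] == 0:
                    continue                                 # skip obstacle cells
                g_new = g[v] + np.hypot(dy, dx)
                if g_new < g.get(u, np.inf):
                    g[u] = g_new
                    parent[u] = v
                    # f(v) = g_ratio * g(v) + (1 - g_ratio) * h(v), mirroring the g_ratio argument
                    open_list[u] = g_ratio * g_new + (1 - g_ratio) * h(u)
    if goal not in parent and goal != start:
        return None                                          # no path found
    path, v = [goal], goal
    while v != start:
        v = parent[v]
        path.append(v)
    return path[::-1]

Updating a node's priority is a plain dictionary assignment here, which is what makes an indexed priority queue such as pqdict a natural fit compared with a heapq-based open list.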
146 changes: 105 additions & 41 deletions src/neural_astar/planner/astar.py
@@ -4,73 +4,64 @@
"""
from __future__ import annotations

from typing import Optional
from functools import partial

import torch
import torch.nn as nn

from . import encoder
from .differentiable_astar import AstarOutput, DifferentiableAstar
from .pq_astar import pq_astar


class NeuralAstar(nn.Module):
class VanillaAstar(nn.Module):
def __init__(
self,
g_ratio: float = 0.5,
Tmax: float = 1.0,
encoder_input: str = "m+",
encoder_arch: str = "CNN",
encoder_depth: int = 4,
learn_obstacles: bool = False,
const: Optional[float] = None,
use_differentiable_astar: bool = True,
):
"""
Neural A* search
Vanilla A* search
Args:
g_ratio (float, optional): ratio between g(v) and h(v). Set to 0 to perform best-first search. Defaults to 0.5.
Tmax (float, optional): how much of the map the model explores during training. Set a small value (0.25) when training the model. Defaults to 1.0.
encoder_input (str, optional): input format. Set "m+" to use the concatenation of map_design and (start_map + goal_map). Set "m" to use map_design only. Defaults to "m+".
encoder_backbone (str, optional): encoder architecture. Defaults to "vgg16_bn".
encoder_depth (int, optional): depth of the encoder. Defaults to 4.
learn_obstacles (bool, optional): whether obstacles are invisible to the model. Defaults to False.
const (Optional[float], optional): learnable weight by which h(v) is multiplied. Defaults to None.
use_differentiable_astar (bool, optional): whether differentiable A* is used instead of standard A*. Defaults to True.
Examples:
>>> planner = NeuralAstar()
>>> planner = VanillaAstar()
>>> outputs = planner(map_designs, start_maps, goal_maps)
>>> histories = outputs.histories
>>> paths = outputs.paths
Note:
To perform inference on a large map, set use_differentiable_astar = False to run a faster A* with a priority queue
"""

super().__init__()
self.astar = DifferentiableAstar(
g_ratio=g_ratio,
Tmax=Tmax,
Tmax=1.0,
)
self.encoder_input = encoder_input
encoder_arch = getattr(encoder, encoder_arch)
self.encoder = encoder_arch(len(self.encoder_input), encoder_depth, const)
self.learn_obstacles = learn_obstacles
if self.learn_obstacles:
print("WARNING: learn_obstacles has been set to True")
self.g_ratio = g_ratio
self.use_differentiable_astar = use_differentiable_astar

def forward(
def perform_astar(
self,
map_designs: torch.tensor,
start_maps: torch.tensor,
goal_maps: torch.tensor,
obstacles_maps: torch.tensor,
store_intermediate_results: bool = False,
) -> AstarOutput:
inputs = map_designs
if "+" in self.encoder_input:
inputs = torch.cat((inputs, start_maps + goal_maps), dim=1)
pred_cost_maps = self.encoder(inputs)
obstacles_maps = (
map_designs if not self.learn_obstacles else torch.ones_like(map_designs)

astar = (
self.astar
if self.use_differentiable_astar
else partial(pq_astar, g_ratio=self.g_ratio)
)

astar_outputs = self.astar(
pred_cost_maps,
astar_outputs = astar(
map_designs,
start_maps,
goal_maps,
obstacles_maps,
@@ -79,30 +70,86 @@ def forward(

return astar_outputs

def forward(
self,
map_designs: torch.tensor,
start_maps: torch.tensor,
goal_maps: torch.tensor,
store_intermediate_results: bool = False,
) -> AstarOutput:
"""
Perform A* search
class VanillaAstar(nn.Module):
Args:
map_designs (torch.tensor): map designs (obstacle maps or raw image)
start_maps (torch.tensor): one-hot binary maps indicating the start locations
goal_maps (torch.tensor): one-hot binary maps indicating the goal locations
store_intermediate_results (bool, optional): whether to store intermediate search results in the A* output. Defaults to False.
Returns:
AstarOutput: search histories and solution paths, and optionally intermediate search results.
"""

cost_maps = map_designs
obstacles_maps = map_designs

return self.perform_astar(
cost_maps,
start_maps,
goal_maps,
obstacles_maps,
store_intermediate_results,
)


class NeuralAstar(VanillaAstar):
def __init__(
self,
g_ratio: float = 0.5,
Tmax: float = 1.0,
encoder_input: str = "m+",
encoder_arch: str = "CNN",
encoder_depth: int = 4,
learn_obstacles: bool = False,
const: float = None,
use_differentiable_astar: bool = True,
):
"""
Vanilla A* search
Neural A* search
Args:
g_ratio (float, optional): ratio between g(v) and h(v). Set to 0 to perform best-first search. Defaults to 0.5.
Tmax (float, optional): how much of the map the model explores during training. Set a small value (0.25) when training the model. Defaults to 1.0.
encoder_input (str, optional): input format. Set "m+" to use the concatenation of map_design and (start_map + goal_map). Set "m" to use map_design only. Defaults to "m+".
encoder_arch (str, optional): encoder architecture. Defaults to "CNN".
encoder_depth (int, optional): depth of the encoder. Defaults to 4.
learn_obstacles (bool, optional): whether obstacles are invisible to the model. Defaults to False.
const (float, optional): learnable weight by which h(v) is multiplied. Defaults to None.
use_differentiable_astar (bool, optional): whether differentiable A* is used instead of standard A*. Defaults to True.
Examples:
>>> planner = VanillaAstar()
>>> planner = NeuralAstar()
>>> outputs = planner(map_designs, start_maps, goal_maps)
>>> histories = outputs.histories
>>> paths = outputs.paths
Note:
To perform inference on a large map, set use_differentiable_astar = False to run a faster A* with a priority queue
"""

super().__init__()
self.astar = DifferentiableAstar(
g_ratio=g_ratio,
Tmax=1.0,
Tmax=Tmax,
)
self.encoder_input = encoder_input
encoder_arch = getattr(encoder, encoder_arch)
self.encoder = encoder_arch(len(self.encoder_input), encoder_depth, const)
self.learn_obstacles = learn_obstacles
if self.learn_obstacles:
print("WARNING: learn_obstacles has been set to True")
self.g_ratio = g_ratio
self.use_differentiable_astar = use_differentiable_astar

def forward(
self,
Expand All @@ -111,14 +158,31 @@ def forward(
goal_maps: torch.tensor,
store_intermediate_results: bool = False,
) -> AstarOutput:
obstacles_maps = map_designs
"""
Perform neural A* search
astar_outputs = self.astar(
map_designs,
Args:
map_designs (torch.tensor): map designs (obstacle maps or raw image)
start_maps (torch.tensor): one-hot binary maps indicating the start locations
goal_maps (torch.tensor): one-hot binary maps indicating the goal locations
store_intermediate_results (bool, optional): whether to store intermediate search results in the A* output. Defaults to False.
Returns:
AstarOutput: search histories and solution paths, and optionally intermediate search results.
"""

inputs = map_designs
if "+" in self.encoder_input:
inputs = torch.cat((inputs, start_maps + goal_maps), dim=1)
cost_maps = self.encoder(inputs)
obstacles_maps = (
map_designs if not self.learn_obstacles else torch.ones_like(map_designs)
)

return self.perform_astar(
cost_maps,
start_maps,
goal_maps,
obstacles_maps,
store_intermediate_results,
)

return astar_outputs
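Reading the refactor end to end: VanillaAstar now feeds map_designs directly as both cost and obstacle maps, while NeuralAstar predicts cost maps with its encoder before calling the shared perform_astar. A hedged usage sketch follows (the import path, the 1 = traversable convention, and the untrained-encoder behaviour are assumptions; shapes follow the ndim == 4 asserts in differentiable_astar.py):

import torch
from neural_astar.planner.astar import NeuralAstar, VanillaAstar

# One 32x32 map in a batch of one, shape (B, 1, H, W); assumed 1 = traversable.
map_designs = torch.ones(1, 1, 32, 32)
start_maps = torch.zeros_like(map_designs)
start_maps[..., 0, 0] = 1      # one-hot start location
goal_maps = torch.zeros_like(map_designs)
goal_maps[..., -1, -1] = 1     # one-hot goal location

vanilla = VanillaAstar(use_differentiable_astar=False)   # standard A* with a priority queue
neural = NeuralAstar(encoder_arch="CNN")                 # encoder untrained here, for shape checks only

va_out = vanilla(map_designs, start_maps, goal_maps)
na_out = neural(map_designs, start_maps, goal_maps)
print(va_out.paths.shape, na_out.histories.shape)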
14 changes: 14 additions & 0 deletions src/neural_astar/planner/differentiable_astar.py
@@ -155,6 +155,20 @@ def forward(
obstacles_maps: torch.tensor,
store_intermediate_results: bool = False,
) -> AstarOutput:
"""
Perform differentiable A* search
Args:
cost_maps (torch.tensor): cost maps
start_maps (torch.tensor): one-hot binary maps indicating the start locations
goal_maps (torch.tensor): one-hot binary maps indicating the goal locations
obstacles_maps (torch.tensor): binary maps indicating obstacle locations
store_intermediate_results (bool, optional): whether to store intermediate search results in the A* output. Defaults to False.
Returns:
AstarOutput: search histories and solution paths, and optionally intermediate search results.
"""

assert cost_maps.ndim == 4
assert start_maps.ndim == 4
assert goal_maps.ndim == 4
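For completeness, a minimal hedged call pattern for DifferentiableAstar that satisfies the ndim == 4 assertions above; the constructor arguments mirror those shown earlier in astar.py, and the map conventions are assumptions:

import torch
from neural_astar.planner.differentiable_astar import DifferentiableAstar

dastar = DifferentiableAstar(g_ratio=0.5, Tmax=1.0)

# All four inputs are (B, 1, H, W), matching the assertions on ndim == 4.
cost_maps = torch.rand(1, 1, 32, 32)       # per-cell guidance costs
obstacles_maps = torch.ones(1, 1, 32, 32)  # assumed convention: 1 = traversable
start_maps = torch.zeros(1, 1, 32, 32)
start_maps[..., 0, 0] = 1
goal_maps = torch.zeros(1, 1, 32, 32)
goal_maps[..., -1, -1] = 1

out = dastar(cost_maps, start_maps, goal_maps, obstacles_maps)
print(out.histories.shape, out.paths.shape)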
