[Feat] Adding GLOP model #182

Draft · wants to merge 8 commits into main
791 changes: 791 additions & 0 deletions examples/other/3-glop.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions rl4co/envs/__init__.py
@@ -19,6 +19,7 @@
    SPCTSPEnv,
    SVRPEnv,
    TSPEnv,
    SHPPEnv,
)

# Scheduling
@@ -43,6 +44,7 @@
    "tsp": TSPEnv,
    "smtwtp": SMTWTPEnv,
    "mdcpdp": MDCPDPEnv,
    "shpp": SHPPEnv,
}


3 changes: 2 additions & 1 deletion rl4co/envs/routing/__init__.py
@@ -11,6 +11,7 @@
from rl4co.envs.routing.spctsp.env import SPCTSPEnv
from rl4co.envs.routing.svrp.env import SVRPEnv
from rl4co.envs.routing.tsp.env import TSPEnv
from rl4co.envs.routing.shpp.env import SHPPEnv

from rl4co.envs.routing.atsp.generator import ATSPGenerator
from rl4co.envs.routing.cvrp.generator import CVRPGenerator
@@ -23,4 +24,4 @@
from rl4co.envs.routing.svrp.generator import SVRPGenerator
from rl4co.envs.routing.tsp.generator import TSPGenerator
from rl4co.envs.routing.mdcpdp.generator import MDCPDPGenerator

from rl4co.envs.routing.shpp.generator import SHPPGenerator
187 changes: 187 additions & 0 deletions rl4co/envs/routing/shpp/env.py
@@ -0,0 +1,187 @@
from typing import Optional

import torch

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.utils.ops import gather_by_index, get_tour_length
from rl4co.utils.pylogger import get_pylogger
from tensordict.tensordict import TensorDict
from torchrl.data import (
    BoundedTensorSpec,
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)

from .generator import SHPPGenerator
from .render import render

log = get_pylogger(__name__)


class SHPPEnv(RL4COEnvBase):
"""
Shortest Hamiltonian Path Problem (SHPP)
SHPP is referred to the open-loop Traveling Salesman Problem (TSP) in the literature.
The goal of the SHPP is to find the shortest Hamiltonian path in a given graph with
given fixed starting/terminating nodes (they can be different nodes). A Hamiltonian
path visits all other nodes exactly once. At each step, the agent chooses a city to visit.
The reward is 0 unless the agent visits all the cities. In that case, the reward is
(-)length of the path: maximizing the reward is equivalent to minimizing the path length.

Observation:
- locations of each customer
- starting node and terminating node
- the current location of the vehicle

Constraints:
- the first node is the starting node
- the last node is the terminating node
- each node is visited exactly once

Finish condition:
- the agent has visited all the customers and reached the terminating node

Reward:
- (minus) the length of the path

Args:
generator: SHPPGenerator instance as the generator
generator_params: parameters for the generator
"""

name = "shpp"

def __init__(
self,
generator: SHPPGenerator = None,
generator_params: dict = {},
**kwargs,
):
super().__init__(**kwargs)
if generator is None:
generator = SHPPGenerator(**generator_params)
self.generator = generator
self._make_spec(self.generator)

    @staticmethod
    def _step(td: TensorDict) -> TensorDict:
        current_node = td["action"]
        first_node = current_node if td["i"].all() == 0 else td["first_node"]

        # Mark the chosen node as visited (set its availability to 0)
        available = td["available"].scatter(
            -1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0
        )

        # The terminating node becomes available only once all other nodes are visited
        action_mask = available.clone()
        action_mask[..., -1] = ~available[..., :-1].any(dim=-1)

        # We are done when there are no unvisited locations left
        done = torch.sum(available, dim=-1) == 0

        # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
        reward = torch.zeros_like(done)

        td.update(
            {
                "first_node": first_node,
                "current_node": current_node,
                "i": td["i"] + 1,
                "available": available,
                "action_mask": action_mask,
                "reward": reward,
                "done": done,
            },
        )
        return td

    def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
        """Note: the first node is the starting node; the last node is the terminating node"""
        device = td.device
        locs = td["locs"]

        # We do not enforce loading from self for flexibility
        num_loc = locs.shape[-2]

        # Other variables
        current_node = torch.zeros((batch_size), dtype=torch.int64, device=device)
        last_node = torch.full(
            (batch_size), num_loc - 1, dtype=torch.int64, device=device
        )
        available = torch.ones(
            (*batch_size, num_loc), dtype=torch.bool, device=device
        )  # 1 means not visited, i.e. action is allowed
        action_mask = torch.zeros((*batch_size, num_loc), dtype=torch.bool, device=device)
        action_mask[..., 0] = 1  # Only the start node is available at the beginning
        i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device)

        return TensorDict(
            {
                "locs": locs,
                "first_node": current_node,
                "last_node": last_node,
                "current_node": current_node,
                "i": i,
                "available": available,
                "action_mask": action_mask,
                "reward": torch.zeros((*batch_size, 1), dtype=torch.float32),
            },
            batch_size=batch_size,
        )

    def _get_reward(self, td, actions) -> torch.Tensor:
        # Gather locations in the order of the tour and return the negative path length (the reward)
        locs_ordered = gather_by_index(td["locs"], actions)
        return -get_tour_length(locs_ordered)

    @staticmethod
    def check_solution_validity(td: TensorDict, actions: torch.Tensor):
        """Check that solution is valid: nodes are visited exactly once"""
        assert (
            torch.arange(actions.size(1), out=actions.data.new())
            .view(1, -1)
            .expand_as(actions)
            == actions.data.sort(1)[0]
        ).all(), "Invalid tour"

    @staticmethod
    def render(td: TensorDict, actions: torch.Tensor = None, ax=None):
        return render(td, actions, ax)

    def _make_spec(self, generator):
        """Make the observation and action specs from the parameters"""
        self.observation_spec = CompositeSpec(
            locs=BoundedTensorSpec(
                low=generator.min_loc,
                high=generator.max_loc,
                shape=(generator.num_loc, 2),
                dtype=torch.float32,
            ),
            first_node=UnboundedDiscreteTensorSpec(
                shape=(1),
                dtype=torch.int64,
            ),
            current_node=UnboundedDiscreteTensorSpec(
                shape=(1),
                dtype=torch.int64,
            ),
            i=UnboundedDiscreteTensorSpec(
                shape=(1),
                dtype=torch.int64,
            ),
            action_mask=UnboundedDiscreteTensorSpec(
                shape=(generator.num_loc),
                dtype=torch.bool,
            ),
            shape=(),
        )
        self.action_spec = BoundedTensorSpec(
            shape=(1,),
            dtype=torch.int64,
            low=0,
            high=generator.num_loc,
        )
        self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
        self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)
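
Reviewer note: a minimal smoke-test sketch for the new environment, assuming the usual RL4CO reset/step rollout pattern (the `num_loc` and batch-size values below are arbitrary):

```python
import torch

from rl4co.envs import SHPPEnv

# Small SHPP instances; locations are sampled by SHPPGenerator under the hood
env = SHPPEnv(generator_params={"num_loc": 20})
td = env.reset(batch_size=[4])

# Random feasible rollout: sample any action allowed by the current action_mask
actions = []
while not td["done"].all():
    action = torch.multinomial(td["action_mask"].float(), 1).squeeze(-1)
    td.set("action", action)
    td = env.step(td)["next"]
    actions.append(action)

# Reward is the negative Hamiltonian path length
actions = torch.stack(actions, dim=-1)
reward = env.get_reward(td, actions)
print(reward.shape, reward)
```

The first sampled action is always node 0 (the fixed starting node), and the terminating node only becomes feasible once every other node has been visited, so the loop ends exactly when the path is complete.
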
55 changes: 55 additions & 0 deletions rl4co/envs/routing/shpp/generator.py
@@ -0,0 +1,55 @@
from typing import Union, Callable

import torch

from torch.distributions import Uniform
from tensordict.tensordict import TensorDict

from rl4co.utils.pylogger import get_pylogger
from rl4co.envs.common.utils import get_sampler, Generator

log = get_pylogger(__name__)


class SHPPGenerator(Generator):
"""Data generator for the Shortest Hamiltonian Path Problem (SHPP).
Args:
num_loc: number of locations (customers) in the TSP
min_loc: minimum value for the location coordinates
max_loc: maximum value for the location coordinates
loc_distribution: distribution for the location coordinates

Returns:
A TensorDict with the following keys:
locs [batch_size, num_loc, 2]: locations of each customer
"""

    def __init__(
        self,
        num_loc: int = 20,
        min_loc: float = 0.0,
        max_loc: float = 1.0,
        loc_distribution: Union[int, float, str, type, Callable] = Uniform,
        **kwargs,
    ):
        self.num_loc = num_loc
        self.min_loc = min_loc
        self.max_loc = max_loc

        # Location distribution
        if kwargs.get("loc_sampler", None) is not None:
            self.loc_sampler = kwargs["loc_sampler"]
        else:
            self.loc_sampler = get_sampler(
                "loc", loc_distribution, min_loc, max_loc, **kwargs
            )

    def _generate(self, batch_size) -> TensorDict:
        # Sample locations
        locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2))

        return TensorDict(
            {
                "locs": locs,
            },
            batch_size=batch_size,
        )
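
The generator can also be exercised in isolation; a quick sketch, assuming the standard `Generator.__call__(batch_size)` interface from `rl4co.envs.common.utils` (sizes are illustrative):

```python
from rl4co.envs.routing.shpp.generator import SHPPGenerator

generator = SHPPGenerator(num_loc=50)
td = generator(batch_size=[8])

# Locations are uniform in [0, 1]^2 by default
print(td["locs"].shape)  # torch.Size([8, 50, 2])
```

By convention, the first and last indices of `locs` act as the fixed starting and terminating nodes, so no extra keys are needed in the generated TensorDict.
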
66 changes: 66 additions & 0 deletions rl4co/envs/routing/shpp/render.py
@@ -0,0 +1,66 @@
import torch
import numpy as np
import matplotlib.pyplot as plt

from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


def render(td, actions=None, ax=None):
    if ax is None:
        # Create a plot of the nodes
        _, ax = plt.subplots(figsize=(3, 3))

    td = td.detach().cpu()

    if actions is None:
        actions = td.get("action", None)

    # If the batch size is non-empty, select the first batch element
    if td.batch_size != torch.Size([]):
        td = td[0]
        actions = actions[0] if actions is not None else None

    locs = td["locs"]

    # Gather locs in order of the actions, if available
    if actions is None:
        log.warning("No action in TensorDict, rendering unsorted locs")
    else:
        actions = actions.detach().cpu()
        locs = gather_by_index(locs, actions, dim=0)

    start_x, start_y = locs[0, 0], locs[0, 1]
    end_x, end_y = locs[-1, 0], locs[-1, 1]
    city_x, city_y = locs[1:-1, 0], locs[1:-1, 1]
    x, y = locs[:, 0], locs[:, 1]

    # Plot the start and end nodes
    ax.scatter(start_x, start_y, color="tab:green", marker="s")
    ax.scatter(end_x, end_y, color="tab:red", marker="x")

    # Plot the intermediate nodes
    ax.scatter(city_x, city_y, color="tab:blue")

    # Add arrows between visited nodes as a quiver plot
    dx, dy = np.diff(x), np.diff(y)
    ax.quiver(
        x[:-1],
        y[:-1],
        dx,
        dy,
        scale_units="xy",
        angles="xy",
        scale=1,
        color="gray",
        width=0.003,
        headwidth=8,
    )

    # Set up limits and show
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
1 change: 1 addition & 0 deletions rl4co/models/__init__.py
@@ -35,3 +35,4 @@
from rl4co.models.zoo.pomo import POMO
from rl4co.models.zoo.ptrnet import PointerNetwork, PointerNetworkPolicy
from rl4co.models.zoo.symnco import SymNCO, SymNCOPolicy
from rl4co.models.zoo.glop import GLOP, GLOPPolicy
35 changes: 35 additions & 0 deletions rl4co/models/nn/env_embeddings/context.py
@@ -32,6 +32,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
"mtsp": MTSPContext,
"smtwtp": SMTWTPContext,
"mdcpdp": MDCPDPContext,
"shpp": SHPPContext,
}

if env_name not in embedding_registry:
@@ -313,3 +314,37 @@ def __init__(self, embed_dim):
    def forward(self, embeddings, td):
        cur_node_embedding = self._cur_node_embedding(embeddings, td).squeeze()
        return self.project_context(cur_node_embedding)


class SHPPContext(EnvContext):
    """Context embedding for the Shortest Hamiltonian Path Problem (SHPP).
    Project the following to the embedding space:
        - first node embedding
        - current node embedding
        - terminating node embedding
    """

    def __init__(self, embed_dim):
        super().__init__(embed_dim, 3 * embed_dim)
        self.W_placeholder = nn.Parameter(
            torch.Tensor(3 * self.embed_dim).uniform_(-1, 1)
        )

    def forward(self, embeds, td):
        batch_size = embeds.size(0)
        # By default, node_dim = -1 (we only have one node embedding per node)
        node_dim = (
            (-1,) if td["first_node"].dim() == 1 else (td["first_node"].size(-1), -1)
        )
        if td["i"][(0,) * td["i"].dim()].item() < 1:  # get first item fast
            context_embed = self.W_placeholder[None, :].expand(
                batch_size, self.W_placeholder.size(-1)
            )
        else:
            context_embed = gather_by_index(
                embeds,
                torch.stack(
                    [td["first_node"], td["current_node"], td["last_node"]], -1
                ).view(batch_size, -1),
            ).view(batch_size, *node_dim)
        return self.project_context(context_embed)
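
A quick shape check for the new context embedding; this is a hypothetical snippet with an arbitrary `embed_dim`, constructing the TensorDict fields (`first_node`, `current_node`, `last_node`, `i`) by hand rather than through the environment:

```python
import torch
from tensordict import TensorDict

from rl4co.models.nn.env_embeddings.context import SHPPContext

embed_dim, batch_size, num_loc = 128, 4, 20
context = SHPPContext(embed_dim)  # projects 3 * embed_dim -> embed_dim

embeds = torch.rand(batch_size, num_loc, embed_dim)
td = TensorDict(
    {
        "first_node": torch.zeros(batch_size, dtype=torch.int64),
        "current_node": torch.randint(num_loc, (batch_size,)),
        "last_node": torch.full((batch_size,), num_loc - 1),
        "i": torch.ones(batch_size, 1, dtype=torch.int64),  # i >= 1: gather node embeddings, not the placeholder
    },
    batch_size=[batch_size],
)

print(context(embeds, td).shape)  # torch.Size([4, 128])
```
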
1 change: 1 addition & 0 deletions rl4co/models/nn/env_embeddings/dynamic.py
@@ -30,6 +30,7 @@ def env_dynamic_embedding(env_name: str, config: dict) -> nn.Module:
"pdp": StaticEmbedding,
"mtsp": StaticEmbedding,
"smtwtp": StaticEmbedding,
"shpp": StaticEmbedding,
}

if env_name not in embedding_registry: