Skip to content

Commit

Permalink
Merge pull request #194 from ai4co/fjsp
Browse files Browse the repository at this point in the history
Fjsp
  • Loading branch information
LTluttmann authored Jun 12, 2024
2 parents 45e9893 + 5c2c8a6 commit 810af10
Showing 1 changed file with 29 additions and 7 deletions.
36 changes: 29 additions & 7 deletions rl4co/envs/scheduling/fjsp/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
max_processing_time: int = 20,
min_eligible_ma_per_op: int = 1,
max_eligible_ma_per_op: int = None,
same_mean_per_op: bool = True,
**unused_kwargs,
):
self.num_jobs = num_jobs
Expand All @@ -58,7 +59,7 @@ def __init__(
# determines whether to use a fixed number of total operations or let it vary between instances
# NOTE: due to the way rl4co builds datasets, we need a fixed size here
self.n_ops_max = max_ops_per_job * num_jobs

self.same_mean_per_op = same_mean_per_op
# FFSP environment doen't have any other kwargs
if len(unused_kwargs) > 0:
log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}")
Expand Down Expand Up @@ -86,12 +87,33 @@ def _simulate_processing_times(
ma_ops_edges = ma_ops_edges_unshuffled.gather(2, idx).transpose(1, 2)

# (bs, max_ops, machines)
proc_times = torch.ones((bs, n_ops_max, self.num_mas))
proc_times = torch.randint(
self.min_processing_time,
self.max_processing_time + 1,
size=(bs, self.num_mas, n_ops_max),
)
if self.same_mean_per_op:
proc_times = torch.ones((bs, self.num_mas, n_ops_max))
proc_time_means = torch.randint(
self.min_processing_time, self.max_processing_time, (bs, n_ops_max)
)
low_bounds = torch.maximum(
torch.full_like(proc_times, self.min_processing_time),
(proc_time_means * (1 - 0.2)).round().unsqueeze(1),
)
high_bounds = (
torch.minimum(
torch.full_like(proc_times, self.max_processing_time),
(proc_time_means * (1 + 0.2)).round().unsqueeze(1),
)
+ 1
)
proc_times = (
torch.randint(2**63 - 1, size=proc_times.shape)
% (high_bounds - low_bounds)
+ low_bounds
)
else:
proc_times = torch.randint(
self.min_processing_time,
self.max_processing_time + 1,
size=(bs, self.num_mas, n_ops_max),
)

# remove proc_times for which there is no corresponding ma-ops connection
proc_times = proc_times * ma_ops_edges
Expand Down

0 comments on commit 810af10

Please sign in to comment.