Skip to content

Commit

Permalink
Add reset test in sc scene and fix the problem found by reset test
Browse files Browse the repository at this point in the history
Add reset test in sc scene and fix the problem found by reset test
  • Loading branch information
v-heli committed Jun 16, 2022
1 parent ab118b8 commit daecf0c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@ class DistributionDataModel(DataModelBase):
def __init__(self) -> None:
super(DistributionDataModel, self).__init__()

def initialize(self) -> None:
self.reset()

def reset(self) -> None:
super(DistributionDataModel, self).reset()
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def initialize(self) -> None:
self._busy_vehicle_num[vehicle_type] = 0

# TODO: add vehicle patient setting if needed
self.data_model.initialize()

for sku_id in self.facility.products.keys():
self._unit_delay_order_penalty[sku_id] = self.facility.skus[sku_id].unit_delay_order_penalty
Expand Down
78 changes: 52 additions & 26 deletions tests/supply_chain/test_env_reset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,14 @@ def test_distribution_unit_reset(self) -> None:

# ##################################### Before reset #####################################

order_1 = Order(src_facility=supplier_3,
dest_facility=warehouse_1,
sku_id=SKU3_ID,
quantity=10,
vehicle_type="train",
creation_tick=env.tick,
expected_finish_tick=env.tick + 7, )
order_1 = Order(
src_facility=supplier_3,
dest_facility=warehouse_1,
sku_id=SKU3_ID,
quantity=10,
vehicle_type="train",
creation_tick=env.tick,
expected_finish_tick=env.tick + 7, )

# There are 2 "train" in total, and 1 left after scheduling this order.
distribution_unit.place_order(order_1)
Expand Down Expand Up @@ -243,15 +244,15 @@ def test_manufacture_unit_reset(self) -> None:
sku3_manufacture_index = manufacture_sku3_unit.data_model_index

storage_nodes = env.snapshot_list["storage"]
manufacture_nodes = env.snapshot_list["manufacture"]

manufacture_features = (
"id", "facility_id", "start_manufacture_quantity", "sku_id", "in_pipeline_quantity", "finished_quantity",
"product_unit_id",
)

# ############################### TICK: 0 ######################################
# ##################################### Before reset #####################################

# tick 0 passed, no product manufacturing.
env.step(None)

capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int)
Expand All @@ -272,17 +273,19 @@ def test_manufacture_unit_reset(self) -> None:
# all the id is greater than 0
self.assertGreater(manufacture_sku3_unit.id, 0)

# ######################################################################

# pass an action to start manufacturing for this tick.
action = ManufactureAction(manufacture_sku3_unit.id, 1)

expect_tick = 30

# Save the env.metric of each tick into env_metric_1
env_metric_1: Dict[int, dict] = defaultdict(dict)

# Store the information about the snapshot manufacture unit of each tick in states_1
states_1: Dict[int, list] = defaultdict(list)

random_tick: List[int] = []
manufacture_nodes = env.snapshot_list["manufacture"]

# The purpose is to randomly perform the order operation
for i in range(10):
random_tick.append(random.randint(1, 30))

Expand All @@ -293,13 +296,17 @@ def test_manufacture_unit_reset(self) -> None:
env_metric_1[i] = env.metrics
states_1[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int)

# ############################### Test whether reset updates the distribution unit completely ################
# ############################### Test whether reset updates the manufacture unit completely ################
env.reset()
env.step(None)

# snapshot should reset after env.reset().
states = manufacture_nodes[1:sku3_manufacture_index:manufacture_features].flatten().astype(np.int)
self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states))

storage_nodes = env.snapshot_list["storage"]
manufacture_nodes = env.snapshot_list["manufacture"]

capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int)
remaining_spaces = storage_nodes[env.frame_index:sku3_storage_index:"remaining_space"].flatten().astype(np.int)

Expand All @@ -319,9 +326,12 @@ def test_manufacture_unit_reset(self) -> None:
self.assertGreater(manufacture_sku3_unit.id, 0)

expect_tick = 30

# Save the env.metric of each tick into env_metric_2
env_metric_2: Dict[int, dict] = defaultdict(dict)

# Store the information about the snapshot manufacture unit of each tick in states_2
states_2: Dict[int, list] = defaultdict(list)
manufacture_nodes = env.snapshot_list["manufacture"]

for i in range(expect_tick):
env.step([action])
Expand All @@ -344,13 +354,13 @@ def test_seller_unit_dynamics_sampler(self):

env.step(None)
Store_001: FacilityBase = be.world._get_facility_by_name("Store_001")
seller_unit = Store_001.products[FOOD_1_ID].seller

seller_unit = Store_001.products[FOOD_1_ID].seller
seller_node_index = seller_unit.data_model_index

seller_nodes = env.snapshot_list["seller"]

features = ("sold", "demand", "total_sold", "total_demand", "backlog_ratio", "facility_id", "product_unit_id",)
features = ("sold", "demand", "total_sold", "id", "total_demand", "backlog_ratio", "facility_id", "product_unit_id",)
# ##################################### Before reset #####################################

self.assertEqual(20, seller_unit.sku_id)

Expand All @@ -363,23 +373,33 @@ def test_seller_unit_dynamics_sampler(self):
self.assertEqual(10, seller_unit._total_sold)

expect_tick = 12

# Save the env.metric of each tick into env_metric_1
env_metric_1: Dict[int, dict] = defaultdict(dict)

# Store the information about the snapshot seller unit of each tick in states_1
states_1: Dict[int, list] = defaultdict(list)
for i in range(expect_tick):
env.step(None)
env_metric_1[i] = env.metrics
states_1[i] = seller_nodes[i:seller_node_index:features].flatten().astype(np.int)

# ############################### Test whether reset updates the distribution unit completely ################
# ################# Test whether reset updates the seller unit completely ################
env.reset()
env.step(None)

# snapshot should reset after env.reset().
states = seller_nodes[1:seller_node_index:features].flatten().astype(np.int)
self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states))
self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states))

expect_tick = 12

# Save the env.metric of each tick into env_metric_2
env_metric_2: Dict[int, dict] = defaultdict(dict)

# Store the information about the snapshot seller unit of each tick in states_2
states_2: Dict[int, list] = defaultdict(list)

for i in range(expect_tick):
env.step(None)
env_metric_2[i] = env.metrics
Expand All @@ -398,19 +418,20 @@ def test_storage_unit_reset(self) -> None:
env.step(None)

supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3")

storage_unit: StorageUnit = supplier_3.storage
storage_node_index = storage_unit.data_model_index

storage_nodes = env.snapshot_list["storage"]

features = ("id", "facility_id",)

# ############################### Take more than existing ######################################

# which this setting, it will return false, as no enough product for ous
# ##################################### Before reset #####################################

expect_tick = 10

# Save the env.metric of each tick into env_metric_1
env_metric_1: Dict[int, dict] = defaultdict(dict)

# Store the information about the snapshot storage unit of each tick in states_1
states_1: Dict[int, list] = defaultdict(list)
for i in range(expect_tick):
env.step(None)
Expand All @@ -420,17 +441,22 @@ def test_storage_unit_reset(self) -> None:
states_1[i].append(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum())
states_1[i].append(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum())

# ############################### Test whether reset updates the distribution unit completely ################
# ############################### Test whether reset updates the storage unit completely ################
env.reset()
env.step(None)

# snapshot should reset after env.reset().
states = storage_nodes[1:storage_node_index:features].flatten().astype(np.int)
self.assertEqual([0, 0], list(states))

expect_tick = 10

# Save the env.metric of each tick into env_metric_2
env_metric_2: Dict[int, dict] = defaultdict(dict)

# Store the information about the snapshot storage unit of each tick in states_2
states_2: Dict[int, list] = defaultdict(list)

for i in range(expect_tick):
env.step(None)
env_metric_2[i] = env.metrics
Expand Down

0 comments on commit daecf0c

Please sign in to comment.