diff --git a/maro/simulator/scenarios/supply_chain/datamodels/distribution.py b/maro/simulator/scenarios/supply_chain/datamodels/distribution.py index 051c458d3..c2f3d45c2 100644 --- a/maro/simulator/scenarios/supply_chain/datamodels/distribution.py +++ b/maro/simulator/scenarios/supply_chain/datamodels/distribution.py @@ -19,5 +19,8 @@ class DistributionDataModel(DataModelBase): def __init__(self) -> None: super(DistributionDataModel, self).__init__() + def initialize(self) -> None: + self.reset() + def reset(self) -> None: super(DistributionDataModel, self).reset() diff --git a/maro/simulator/scenarios/supply_chain/units/distribution.py b/maro/simulator/scenarios/supply_chain/units/distribution.py index 053a54ecf..29ea8a263 100644 --- a/maro/simulator/scenarios/supply_chain/units/distribution.py +++ b/maro/simulator/scenarios/supply_chain/units/distribution.py @@ -101,6 +101,8 @@ def initialize(self) -> None: # TODO: add vehicle patient setting if needed + self.data_model.initialize() + for sku_id in self.facility.products.keys(): self._unit_delay_order_penalty[sku_id] = self.facility.skus[sku_id].unit_delay_order_penalty diff --git a/tests/data/supply_chain/case_04/test_case_04.csv b/tests/data/supply_chain/case_04/test_case_04.csv index 8f77f53c7..acbc2b922 100644 --- a/tests/data/supply_chain/case_04/test_case_04.csv +++ b/tests/data/supply_chain/case_04/test_case_04.csv @@ -4,11 +4,23 @@ food_1,2021/1/2,43.1,33.39,20 food_1,2021/1/3,43.2,33.39,30 food_1,2021/1/4,43.3,33.39,40 food_1,2021/1/5,43.4,33.39,50 +food_1,2021/1/6,43.4,33.39,60 +food_1,2021/1/7,43.4,33.39,70 +food_1,2021/1/8,43.4,33.39,80 +food_1,2021/1/9,43.4,33.39,90 +food_1,2021/1/10,43.4,33.39,100 +food_1,2021/1/11,43.4,33.39,110 hobby_1,2021/1/1,28.32,21.79,100 hobby_1,2021/1/2,28.32,21.79,200 hobby_1,2021/1/3,28.32,21.79,300 hobby_1,2021/1/4,28.32,21.79,400 hobby_1,2021/1/5,28.32,21.79,500 +hobby_1,2021/1/6,28.32,21.79,600 +hobby_1,2021/1/7,28.32,21.79,700 +hobby_1,2021/1/8,28.32,21.79,800 +hobby_1,2021/1/9,28.32,21.79,900 +hobby_1,2021/1/10,28.32,21.79,1000 +hobby_1,2021/1/11,28.32,21.79,1100 household_1,2022/5/14,17.35974446,14.4,146 household_1,2022/5/15,17.29592816,14.4,164 household_1,2022/5/16,17.43129521,14.4,127 diff --git a/tests/data/supply_chain/case_05/config.yml b/tests/data/supply_chain/case_05/config.yml index 93f2fdf51..08089864e 100644 --- a/tests/data/supply_chain/case_05/config.yml +++ b/tests/data/supply_chain/case_05/config.yml @@ -146,6 +146,30 @@ world: unit_delay_order_penalty: 20 unit_order_cost: 0 + - name: "Supplier_SKU1" + definition_ref: "SupplierFacility" + skus: + sku1: + init_stock: 20 + has_manufacture: True + max_manufacture_rate: 50 + manufacture_leading_time: 1 + unit_product_cost: 1 + price: 10 + unit_delay_order_penalty: 10 + has_consumer: True + sku3: + init_stock: 80 + unit_product_cost: 1 + price: 10 + unit_delay_order_penalty: 10 + children: + storage: *small_storage + distribution: *normal_distribution + config: + unit_delay_order_penalty: 20 + unit_order_cost: 0 + - name: "Warehouse_001" definition_ref: "WarehouseFacility" skus: @@ -154,6 +178,7 @@ world: sub_storage_id: 1 storage_upper_bound: 40 price: 100 + has_consumer: True sku2: init_stock: 12 sub_storage_id: 1 @@ -257,6 +282,11 @@ world: cost: 0.6 Warehouse_001: + sku1: + "Supplier_SKU1": + "train": + vlt: 3 + cost: 1 sku3: "Supplier_SKU3": "train": diff --git a/tests/supply_chain/common.py b/tests/supply_chain/common.py index a3ec84689..3be8507fc 100644 --- a/tests/supply_chain/common.py +++ b/tests/supply_chain/common.py @@ -2,25 +2,93 @@ # Licensed under the MIT license. import os +from collections import defaultdict +from typing import Dict, List, Tuple import numpy as np from maro.simulator import Env -def build_env(case_name: str, durations: int): +def build_env(case_name: str, durations: int) -> Env: case_folder = os.path.join("tests", "data", "supply_chain", case_name) + return Env(scenario="supply_chain", topology=case_folder, durations=durations) - env = Env(scenario="supply_chain", topology=case_folder, durations=durations) - return env - - -def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int): +def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int) -> Dict[int, int]: sku_id_list = env.snapshot_list["storage"][frame_index:node_index:"sku_id_list"].flatten().astype(np.int) product_quantity = env.snapshot_list["storage"][frame_index:node_index:"product_quantity"].flatten().astype(np.int) - return {sku_id: quantity for sku_id, quantity in zip(sku_id_list, product_quantity)} + return dict(zip(sku_id_list, product_quantity)) + + +def snapshot_query(env: Env, i: int) -> Tuple[ + Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list] +]: + consumer_nodes = env.snapshot_list["consumer"] + storage_nodes = env.snapshot_list["storage"] + seller_nodes = env.snapshot_list["seller"] + manufacture_nodes = env.snapshot_list["manufacture"] + distribution_nodes = env.snapshot_list["distribution"] + + states_consumer: Dict[int, list] = {} + states_storage: Dict[int, list] = {} + states_seller: Dict[int, list] = {} + states_manufacture: Dict[int, list] = {} + states_distribution: Dict[int, list] = {} + + env_metric = env.metrics + + for idx in range(len(consumer_nodes)): + states_consumer[idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.float) + + for idx in range(len(storage_nodes)): + states_storage[idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.float)) + states_storage[idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_storage[idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.float)) + states_storage[idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.float)) + + for idx in range(len(manufacture_nodes)): + states_manufacture[idx] = ( + manufacture_nodes[i:idx:manufacture_features] + .flatten() + .astype( + np.float, + ) + ) + + for idx in range(len(distribution_nodes)): + states_distribution[idx] = ( + distribution_nodes[i:idx:distribution_features] + .flatten() + .astype( + np.float, + ) + ) + + for idx in range(len(seller_nodes)): + states_seller[idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.float) + + return env_metric, states_consumer, states_storage, states_seller, states_manufacture, states_distribution + + +def test_env_reset_snapshot_query( + env: Env, + action_1: object, + action_2: object, + expect_tick: int, + random_tick: list = None, +) -> List[tuple]: + snapshots: List[tuple] = [] # List of (env_metric, states_consumer, ..., states_distribution) + for i in range(expect_tick): + snapshots.append(snapshot_query(env, i)) + env.step(action_1) + + if random_tick is not None: + if i in random_tick: + env.step(action_2) + + return list(zip(*snapshots)) SKU1_ID = 1 @@ -29,3 +97,40 @@ def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int): SKU4_ID = 4 FOOD_1_ID = 20 HOBBY_1_ID = 30 + +consumer_features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", +) + +storage_features = ("id", "facility_id") + +seller_features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", +) + +manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", +) + +distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") diff --git a/tests/supply_chain/test_supply_chain_consumer_unit.py b/tests/supply_chain/test_consumer_unit.py similarity index 100% rename from tests/supply_chain/test_supply_chain_consumer_unit.py rename to tests/supply_chain/test_consumer_unit.py diff --git a/tests/supply_chain/test_supply_chain_distribution_unit.py b/tests/supply_chain/test_distribution_unit.py similarity index 100% rename from tests/supply_chain/test_supply_chain_distribution_unit.py rename to tests/supply_chain/test_distribution_unit.py diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py new file mode 100644 index 000000000..4e6227d4c --- /dev/null +++ b/tests/supply_chain/test_env_reset.py @@ -0,0 +1,406 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import random +import unittest +from typing import List + +import numpy as np + +from maro.simulator.scenarios.supply_chain import ( + ConsumerAction, + ConsumerUnit, + FacilityBase, + ManufactureAction, + ManufactureUnit, + RetailerFacility, +) +from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine + +from tests.supply_chain.common import ( + SKU1_ID, + SKU3_ID, + build_env, + get_product_dict_from_storage, + snapshot_query, + test_env_reset_snapshot_query, +) + + +class MyTestCase(unittest.TestCase): + """ + . test env reset with none action + . with ManufactureAction only + . with ConsumerAction only + . with both ManufactureAction and ConsumerAction + """ + + def test_env_reset_with_none_action(self) -> None: + """test_env_reset_with_none_action""" + env = build_env("case_05", 500) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + # ##################################### Before reset ##################################### + + expect_tick = 10 + + # Save the env.metric of each tick into env_metric_1 + # Store the information about the snapshot unit of each tick in states_1_unit + ( + env_metric_1, + states_1_consumer, + states_1_storage, + states_1_seller, + states_1_manufacture, + states_1_distribution, + ) = test_env_reset_snapshot_query( + env=env, + action_1=None, + action_2=None, + expect_tick=expect_tick, + random_tick=None, + ) + + # ############################### Test whether reset updates the storage unit completely ################ + env.reset() + env.step(None) + + # Check snapshot initial state after env.reset() + ( + env_metric_initial, + states_consumer_initial, + states_storage_initial, + states_seller_initial, + states_manufacture_initial, + states_distribution_initial, + ) = snapshot_query(env, 0) + self.assertListEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertListEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertListEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertListEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertListEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertListEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + + # Save the env.metric of each tick into env_metric_2 + # Store the information about the snapshot unit of each tick in states_2_unit + ( + env_metric_2, + states_2_consumer, + states_2_storage, + states_2_seller, + states_2_manufacture, + states_2_distribution, + ) = test_env_reset_snapshot_query( + env=env, + action_1=None, + action_2=None, + expect_tick=expect_tick, + random_tick=None, + ) + + for i in range(expect_tick): + self.assertListEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertListEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertListEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertListEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertListEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertListEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_ManufactureAction_only(self) -> None: + """test env reset with ManufactureAction only""" + env = build_env("case_02", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + + storage_unit = supplier_3.storage + manufacture_unit = supplier_3.products[SKU3_ID].manufacture + storage_nodes = env.snapshot_list["storage"] + + # ##################################### Before reset ##################################### + + env.step(None) + + storage_node_index = storage_unit.data_model_index + capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : storage_node_index : "remaining_space"].flatten().astype(np.int) + ) + + # there should be 80 units been taken at the beginning according to the config file. + # so remaining space should be 20 + self.assertEqual(20, remaining_spaces.sum()) + # capacity is 100 by config + self.assertEqual(100, capacities.sum()) + + product_dict = get_product_dict_from_storage(env, env.frame_index, storage_node_index) + + # The product quantity should be same as configuration at beginning. + # 80 sku3 + self.assertEqual(80, product_dict[SKU3_ID]) + + expect_tick = 30 + + action_1 = ManufactureAction(manufacture_unit.id, 1) + action_2 = ManufactureAction(manufacture_unit.id, 0) + + random_tick: List[int] = [] + + # The purpose is to randomly perform the order operation + for i in range(10): + random_tick.append(random.randint(1, 30)) + + # Save the env.metric of each tick into env_metric_1 + # Store the information about the snapshot unit of each tick in states_1_unit + ( + env_metric_1, + states_1_consumer, + states_1_storage, + states_1_seller, + states_1_manufacture, + states_1_distribution, + ) = test_env_reset_snapshot_query( + env=env, + action_1=action_1, + action_2=action_2, + expect_tick=expect_tick, + random_tick=random_tick, + ) + + # ############################### Test whether reset updates the manufacture unit completely ################ + env.reset() + env.step(None) + + # Check snapshot initial state after env.reset() + ( + env_metric_initial, + states_consumer_initial, + states_storage_initial, + states_seller_initial, + states_manufacture_initial, + states_distribution_initial, + ) = snapshot_query(env, 0) + self.assertListEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertListEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertListEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertListEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertListEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertListEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + + capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : storage_node_index : "remaining_space"].flatten().astype(np.int) + ) + + # there should be 80 units been taken at the beginning according to the config file. + # so remaining space should be 20 + self.assertEqual(20, remaining_spaces.sum()) + # capacity is 100 by config + self.assertEqual(100, capacities.sum()) + + product_dict = get_product_dict_from_storage(env, env.frame_index, storage_node_index) + + # The product quantity should be same as configuration at beginning. + # 80 sku3 + self.assertEqual(80, product_dict[SKU3_ID]) + + # all the id is greater than 0 + self.assertGreater(manufacture_unit.id, 0) + + # Save the env.metric of each tick into env_metric_2 + # Store the information about the snapshot unit of each tick in states_2_unit + ( + env_metric_2, + states_2_consumer, + states_2_storage, + states_2_seller, + states_2_manufacture, + states_2_distribution, + ) = test_env_reset_snapshot_query( + env=env, + action_1=action_1, + action_2=action_2, + expect_tick=expect_tick, + random_tick=random_tick, + ) + + for i in range(expect_tick): + self.assertListEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertListEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertListEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertListEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertListEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertListEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_ConsumerAction_only(self) -> None: + """ "test env reset with ConsumerAction only""" + env = build_env("case_05", 500) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + warehouse_1 = be.world._get_facility_by_name("Warehouse_001") + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + consumer_unit = warehouse_1.products[SKU3_ID].consumer + + # ##################################### Before reset ##################################### + action = ConsumerAction(consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_1 + # Store the information about the snapshot unit of each tick in states_1_unit + ( + env_metric_1, + states_1_consumer, + states_1_storage, + states_1_seller, + states_1_manufacture, + states_1_distribution, + ) = test_env_reset_snapshot_query( + env=env, + action_1=action, + action_2=None, + expect_tick=expect_tick, + random_tick=None, + ) + + # ############### Test whether reset updates the consumer unit completely ################ + env.reset() + env.step(None) + + # Check snapshot initial state after env.reset() + ( + env_metric_initial, + states_consumer_initial, + states_storage_initial, + states_seller_initial, + states_manufacture_initial, + states_distribution_initial, + ) = snapshot_query(env, 0) + self.assertListEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertListEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertListEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertListEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertListEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertListEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + + # Save the env.metric of each tick into env_metric_2 + # Store the information about the snapshot unit of each tick in states_2_unit + ( + env_metric_2, + states_2_consumer, + states_2_storage, + states_2_seller, + states_2_manufacture, + states_2_distribution, + ) = test_env_reset_snapshot_query( + env=env, + action_1=action, + action_2=None, + expect_tick=expect_tick, + random_tick=None, + ) + + for i in range(expect_tick): + self.assertListEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertListEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertListEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertListEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertListEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertListEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: + """test env reset with both ManufactureAction and ConsumerAction""" + env = build_env("case_05", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + warehouse_1: RetailerFacility = be.world._get_facility_by_name("Warehouse_001") + consumer_unit: ConsumerUnit = warehouse_1.products[SKU1_ID].consumer + manufacture_unit: ManufactureUnit = supplier_1.products[SKU1_ID].manufacture + + # ##################################### Before reset ##################################### + action_consumer = ConsumerAction(consumer_unit.id, SKU1_ID, supplier_1.id, 5, "train") + action_manufacture = ManufactureAction(manufacture_unit.id, 1) + + expect_tick = 100 + + random_tick: List[int] = [] + + # The purpose is to randomly perform the order operation + for i in range(30): + random_tick.append(random.randint(0, 90)) + + # Save the env.metric of each tick into env_metric_1 + # Store the information about the snapshot unit of each tick in states_1_unit + ( + env_metric_1, + states_1_consumer, + states_1_storage, + states_1_seller, + states_1_manufacture, + states_1_distribution, + ) = test_env_reset_snapshot_query( + env=env, + action_1=action_consumer, + action_2=action_manufacture, + expect_tick=expect_tick, + random_tick=random_tick, + ) + + # ############### Test whether reset updates the consumer unit completely ################ + env.reset() + env.step(None) + + # Check snapshot initial state after env.reset() + ( + env_metric_initial, + states_consumer_initial, + states_storage_initial, + states_seller_initial, + states_manufacture_initial, + states_distribution_initial, + ) = snapshot_query(env, 0) + self.assertListEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertListEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertListEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertListEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertListEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertListEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + + # Save the env.metric of each tick into env_metric_2 + # Store the information about the snapshot unit of each tick in states_2_unit + ( + env_metric_2, + states_2_consumer, + states_2_storage, + states_2_seller, + states_2_manufacture, + states_2_distribution, + ) = test_env_reset_snapshot_query( + env=env, + action_1=action_consumer, + action_2=action_manufacture, + expect_tick=expect_tick, + random_tick=random_tick, + ) + + for i in range(expect_tick): + self.assertListEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertListEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertListEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertListEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertListEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertListEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/supply_chain/test_supply_chain_manufacture_unit.py b/tests/supply_chain/test_manufacture_unit.py similarity index 100% rename from tests/supply_chain/test_supply_chain_manufacture_unit.py rename to tests/supply_chain/test_manufacture_unit.py diff --git a/tests/supply_chain/test_supply_chain_readdata.py b/tests/supply_chain/test_read_data.py similarity index 100% rename from tests/supply_chain/test_supply_chain_readdata.py rename to tests/supply_chain/test_read_data.py diff --git a/tests/supply_chain/test_supply_chain_seller_unit.py b/tests/supply_chain/test_seller_unit.py similarity index 99% rename from tests/supply_chain/test_supply_chain_seller_unit.py rename to tests/supply_chain/test_seller_unit.py index f74e4a320..3a78e1a52 100644 --- a/tests/supply_chain/test_supply_chain_seller_unit.py +++ b/tests/supply_chain/test_seller_unit.py @@ -192,7 +192,7 @@ def test_seller_unit_dynamics_sampler(self): # NOTE: this simple seller unit return demands that same as current tick - # Tick 0 will have demand == 25.first row of data after preprocessing data. + # Tick 0 will have demand == 10.first row of data after preprocessing data. # from sample_preprocessed.csv self.assertEqual(10, seller_unit._sold) self.assertEqual(10, seller_unit._demand) diff --git a/tests/supply_chain/test_supply_chain_state_only.py b/tests/supply_chain/test_state_only.py similarity index 96% rename from tests/supply_chain/test_supply_chain_state_only.py rename to tests/supply_chain/test_state_only.py index 0dfa90d50..d6584e4a9 100644 --- a/tests/supply_chain/test_supply_chain_state_only.py +++ b/tests/supply_chain/test_state_only.py @@ -54,7 +54,6 @@ def test_distribution_state_only_small_vlt(self) -> None: warehouse_1 = be.world._get_facility_by_name("Warehouse_001") distribution_unit = supplier_3.distribution - warehouse_1.products[SKU3_ID].consumer env.step(None) # vlt is greater than len(pending_order_len), which will cause the pending order to increase @@ -185,15 +184,12 @@ def test_distribution_state_only_bigger_vlt(self) -> None: warehouse_1 = be.world._get_facility_by_name("Warehouse_001") retailer_1: FacilityBase = be.world._get_facility_by_name("Retailer_001") - warehouse_1_id, retailer_1_id = 6, 13 warehouse_1_distribution_unit = warehouse_1.distribution self.assertEqual(0, len(warehouse_1_distribution_unit._order_queues["train"])) env.step(None) - retailer_1.products[SKU2_ID].consumer - order_1 = Order(warehouse_1, retailer_1, SKU2_ID, 1, "train", env.tick, None) warehouse_1_distribution_unit.place_order(order_1) # The vlt configuration of this topology is 5. @@ -230,7 +226,7 @@ def test_distribution_state_only_bigger_vlt(self) -> None: list(env.metrics["products"][retailer_1.products[SKU2_ID].id]["pending_order_daily"]), ) - self.assertEqual(3, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU2_ID]) + self.assertEqual(3, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU2_ID]) # There are a total of two trains in the configuration, and they have all been dispatched now. self.assertEqual(0, len(warehouse_1_distribution_unit._order_queues["train"])) @@ -245,9 +241,9 @@ def test_distribution_state_only_bigger_vlt(self) -> None: list(env.metrics["products"][retailer_1.products[SKU2_ID].id]["pending_order_daily"]), ) - self.assertEqual(6, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU2_ID]) + self.assertEqual(6, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU2_ID]) - self.assertEqual(0, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(0, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU3_ID]) # After env.step runs, where tick is 5. order_1 arrives after env.step. # order_2 will arrive at tick=6.order_3 is expected to arrive at tick=8 under normal circumstances. @@ -267,7 +263,7 @@ def test_distribution_state_only_bigger_vlt(self) -> None: # When order_1 arrives at the next step, the in_transit_orders of retailer_1 should be the negative number # 1+2+3-1 of the arrival order of retailer_1. - self.assertEqual(5, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU2_ID]) + self.assertEqual(5, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU2_ID]) # After env.step runs, where tick is 7. order_2 arrives after env.step. # There are empty cars at this time, order_3 will arrive at tick = 11. @@ -275,7 +271,7 @@ def test_distribution_state_only_bigger_vlt(self) -> None: # When order_2 arrives at the next step, the in_transit_orders of retailer_1 should be the negative number # 1+2+3-1-2 of the arrival order of retailer_1. - self.assertEqual(3, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU2_ID]) + self.assertEqual(3, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU2_ID]) order_4 = Order(warehouse_1, retailer_1, SKU2_ID, 4, "train", env.tick, None) warehouse_1_distribution_unit.place_order(order_4) @@ -375,7 +371,6 @@ def test_distribution_state_only(self) -> None: distribution_unit.place_order(order) self.assertEqual(1, len(distribution_unit._order_queues["train"])) self.assertEqual(20, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - supplier_3_id, warehouse_1_id, retailer_1_id = 1, 6, 13 env.step(None) @@ -384,7 +379,7 @@ def test_distribution_state_only(self) -> None: [0, 0, 20, 0], list(env.metrics["products"][warehouse_1.products[SKU3_ID].id]["pending_order_daily"]), ) - self.assertEqual(20, env.metrics["facilities"][warehouse_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(20, env.metrics["facilities"][warehouse_1.id]["in_transit_orders"][SKU3_ID]) # add another order, it would be successfully scheduled. order = Order(supplier_3, warehouse_1, SKU3_ID, 25, "train", env.tick, None) @@ -398,7 +393,7 @@ def test_distribution_state_only(self) -> None: self.assertEqual(2, len(distribution_unit._order_queues["train"])) self.assertEqual(25 + 30, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - self.assertEqual(25 + 30, env.metrics["facilities"][supplier_3_id]["pending_order"][SKU3_ID]) + self.assertEqual(25 + 30, env.metrics["facilities"][supplier_3.id]["pending_order"][SKU3_ID]) self.assertEqual(25 + 30, distribution_unit.pending_product_quantity[SKU3_ID]) warehouse_1_distribution_unit = warehouse_1.distribution @@ -409,16 +404,16 @@ def test_distribution_state_only(self) -> None: order_3 = Order(warehouse_1, retailer_1, SKU3_ID, 5, "train", env.tick, None) warehouse_1_distribution_unit.place_order(order_3) - self.assertEqual(5 + 5, env.metrics["facilities"][warehouse_1_id]["pending_order"][SKU3_ID]) + self.assertEqual(5 + 5, env.metrics["facilities"][warehouse_1.id]["pending_order"][SKU3_ID]) self.assertEqual(5 + 5, warehouse_1_distribution_unit.pending_product_quantity[SKU3_ID]) order_4 = Order(warehouse_1, retailer_1, SKU3_ID, 5, "train", env.tick, None) warehouse_1_distribution_unit.place_order(order_4) - self.assertEqual(5 + 5 + 5, env.metrics["facilities"][warehouse_1_id]["pending_order"][SKU3_ID]) + self.assertEqual(5 + 5 + 5, env.metrics["facilities"][warehouse_1.id]["pending_order"][SKU3_ID]) self.assertEqual(5 + 5 + 5, warehouse_1_distribution_unit.pending_product_quantity[SKU3_ID]) # There is no place_order for the distribution of supplier_3, there should be no change - self.assertEqual(25 + 30, env.metrics["facilities"][supplier_3_id]["pending_order"][SKU3_ID]) + self.assertEqual(25 + 30, env.metrics["facilities"][supplier_3.id]["pending_order"][SKU3_ID]) self.assertEqual(25 + 30, distribution_unit.pending_product_quantity[SKU3_ID]) start_tick = env.tick @@ -429,7 +424,7 @@ def test_distribution_state_only(self) -> None: [0, 20, 25, 0], list(env.metrics["products"][warehouse_1.products[SKU3_ID].id]["pending_order_daily"]), ) - self.assertEqual(20 + 25 + 30, env.metrics["facilities"][warehouse_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(20 + 25 + 30, env.metrics["facilities"][warehouse_1.id]["in_transit_orders"][SKU3_ID]) while env.tick < expected_supplier_tick - 1: env.step(None) @@ -438,20 +433,20 @@ def test_distribution_state_only(self) -> None: [20, 25, 0, 0], list(env.metrics["products"][warehouse_1.products[SKU3_ID].id]["pending_order_daily"]), ) - self.assertEqual(20 + 25 + 30, env.metrics["facilities"][warehouse_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(20 + 25 + 30, env.metrics["facilities"][warehouse_1.id]["in_transit_orders"][SKU3_ID]) env.step(None) self.assertEqual( [25, 0, 0, 30], list(env.metrics["products"][warehouse_1.products[SKU3_ID].id]["pending_order_daily"]), ) - self.assertEqual(25 + 30, env.metrics["facilities"][warehouse_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(25 + 30, env.metrics["facilities"][warehouse_1.id]["in_transit_orders"][SKU3_ID]) # will arrive at the end of this tick, still on the way. assert env.tick == expected_supplier_tick self.assertEqual(0, len(distribution_unit._order_queues["train"])) self.assertEqual(0, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - self.assertEqual(5, env.metrics["facilities"][warehouse_1_id]["pending_order"][SKU3_ID]) + self.assertEqual(5, env.metrics["facilities"][warehouse_1.id]["pending_order"][SKU3_ID]) self.assertEqual(5, warehouse_1_distribution_unit.pending_product_quantity[SKU3_ID]) env.step(None) @@ -459,7 +454,7 @@ def test_distribution_state_only(self) -> None: self.assertEqual(0, len(distribution_unit._order_queues["train"])) self.assertEqual(0, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - self.assertEqual(5, env.metrics["facilities"][warehouse_1_id]["pending_order"][SKU3_ID]) + self.assertEqual(5, env.metrics["facilities"][warehouse_1.id]["pending_order"][SKU3_ID]) self.assertEqual(5, warehouse_1_distribution_unit.pending_product_quantity[SKU3_ID]) diff --git a/tests/supply_chain/test_supply_chain_storage_unit.py b/tests/supply_chain/test_storage_unit.py similarity index 100% rename from tests/supply_chain/test_supply_chain_storage_unit.py rename to tests/supply_chain/test_storage_unit.py diff --git a/tests/supply_chain/test_supply_chain_units_interaction.py b/tests/supply_chain/test_units_interaction.py similarity index 100% rename from tests/supply_chain/test_supply_chain_units_interaction.py rename to tests/supply_chain/test_units_interaction.py