From 08744df7b2b351f8ffa35d48b81a2a6c19bb0be0 Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Tue, 21 Jun 2022 11:25:58 +0800 Subject: [PATCH] Modify test according to comments Modify test according to comments --- .../supply_chain/units/distribution.py | 1 + tests/data/supply_chain/case_01/config.yml | 2 + tests/data/supply_chain/case_05/config.yml | 25 +- tests/supply_chain/test_action_reset.py | 593 ++++++++++++++++++ tests/supply_chain/test_env_reset.py | 71 ++- tests/supply_chain/test_state_only.py | 4 +- 6 files changed, 672 insertions(+), 24 deletions(-) create mode 100644 tests/supply_chain/test_action_reset.py diff --git a/maro/simulator/scenarios/supply_chain/units/distribution.py b/maro/simulator/scenarios/supply_chain/units/distribution.py index 2405d127a..29ea8a263 100644 --- a/maro/simulator/scenarios/supply_chain/units/distribution.py +++ b/maro/simulator/scenarios/supply_chain/units/distribution.py @@ -100,6 +100,7 @@ def initialize(self) -> None: self._busy_vehicle_num[vehicle_type] = 0 # TODO: add vehicle patient setting if needed + self.data_model.initialize() for sku_id in self.facility.products.keys(): diff --git a/tests/data/supply_chain/case_01/config.yml b/tests/data/supply_chain/case_01/config.yml index 7730960e4..0d21ac8c5 100644 --- a/tests/data/supply_chain/case_01/config.yml +++ b/tests/data/supply_chain/case_01/config.yml @@ -88,6 +88,7 @@ world: sku3: init_stock: 80 has_manufacture: True + has_consumer: True max_manufacture_rate: 50 manufacture_leading_time: 1 unit_product_cost: 1 @@ -105,6 +106,7 @@ world: init_stock: 96 has_manufacture: True max_manufacture_rate: 50 + has_consumer: True manufacture_leading_time: 1 unit_product_cost: 1 price: 100 diff --git a/tests/data/supply_chain/case_05/config.yml b/tests/data/supply_chain/case_05/config.yml index 93f2fdf51..93c663234 100644 --- a/tests/data/supply_chain/case_05/config.yml +++ b/tests/data/supply_chain/case_05/config.yml @@ -146,6 +146,25 @@ world: unit_delay_order_penalty: 20 unit_order_cost: 0 + - name: "Supplier_SKU1" + definition_ref: "SupplierFacility" + skus: + sku3: + init_stock: 80 + has_manufacture: True + max_manufacture_rate: 50 + manufacture_leading_time: 1 + unit_product_cost: 1 + price: 10 + unit_delay_order_penalty: 10 + has_consumer: True + children: + storage: *small_storage + distribution: *normal_distribution + config: + unit_delay_order_penalty: 20 + unit_order_cost: 0 + - name: "Warehouse_001" definition_ref: "WarehouseFacility" skus: @@ -232,7 +251,7 @@ world: storage: *single_storage config: unit_order_cost: 200 - file_path: "tests/data/supply_chain/case_05/test_case_05.csv" + file_path: "tests/data/supply_chain/case_04/test_case_04.csv" topology: @@ -258,6 +277,10 @@ world: Warehouse_001: sku3: + "Supplier_SKU1": + "train": + vlt: 3 + cost: 1 "Supplier_SKU3": "train": vlt: 3 diff --git a/tests/supply_chain/test_action_reset.py b/tests/supply_chain/test_action_reset.py new file mode 100644 index 000000000..d87c6b914 --- /dev/null +++ b/tests/supply_chain/test_action_reset.py @@ -0,0 +1,593 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import random +import unittest +from collections import defaultdict +from typing import Dict, List + +import numpy as np + +from maro.simulator.scenarios.supply_chain import ( + ConsumerAction, + ConsumerUnit, + FacilityBase, + ManufactureAction, + ManufactureUnit, + StorageUnit, +) +from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine +from maro.simulator.scenarios.supply_chain.order import Order + +from tests.supply_chain.common import SKU1_ID, SKU3_ID, build_env, get_product_dict_from_storage + + +class MyTestCase(unittest.TestCase): + """ + . consumer unit test + . distribution unit test + . manufacture unit test + . seller unit test + . storage unit test + """ + + def test_env_reset_with_none_action(self) -> None: + """test_env_reset_with_none_action""" + env = build_env("case_05", 500) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + warehouse_1 = be.world._get_facility_by_name("Warehouse_001") + Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") + consumer_unit: ConsumerUnit = supplier_1.products[SKU3_ID].consumer + storage_unit: StorageUnit = supplier_1.storage + seller_unit = Store_001.products[SKU3_ID].seller + manufacture_unit = supplier_1.products[SKU3_ID].manufacture + distribution_unit = supplier_1.distribution + + consumer_nodes = env.snapshot_list["consumer"] + storage_nodes = env.snapshot_list["storage"] + seller_nodes = env.snapshot_list["seller"] + manufacture_nodes = env.snapshot_list["manufacture"] + distribution_nodes = env.snapshot_list["distribution"] + + consumer_node_index = consumer_unit.data_model_index + storage_node_index = storage_unit.data_model_index + seller_node_index = seller_unit.data_model_index + manufacture_node_index = manufacture_unit.data_model_index + distribution_node_index = distribution_unit.data_model_index + + consumer_features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) + + storage_features = ("id", "facility_id") + + seller_features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", + ) + + manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", + ) + + distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") + + # ##################################### Before reset ##################################### + + expect_tick = 10 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + random_tick: List[int] = [] + + # The purpose is to randomly perform the order operation + for i in range(10): + random_tick.append(random.randint(1, 30)) + + # Store the information about the snapshot of each tick in states_1_x + states_1_consumer: Dict[int, list] = defaultdict(list) + states_1_storage: Dict[int, list] = defaultdict(list) + states_1_seller: Dict[int, list] = defaultdict(list) + states_1_manufacture: Dict[int, list] = defaultdict(list) + states_1_distribution: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step(None) + if i in random_tick: + order = Order( + src_facility=supplier_1, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, + ) + distribution_unit.place_order(order) + distribution_unit.try_schedule_orders(env.tick) + env_metric_1[i] = env.metrics + states_1_consumer[i] = consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int) + states_1_manufacture[i] = ( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ) + ) + env_metric_1[i] = env.metrics + states_1_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) + states_1_storage[i].append( + storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum(), + ) + states_1_storage[i].append( + storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum(), + ) + states_1_storage[i].append( + storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum(), + ) + states_1_seller[i] = seller_nodes[i:seller_node_index:seller_features].flatten().astype(np.int) + states_1_manufacture[i] = ( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ) + ) + states_1_distribution[i] = ( + distribution_nodes[i:distribution_node_index:distribution_features] + .flatten() + .astype( + np.int, + ) + ) + + # ############################### Test whether reset updates the storage unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset(). + consumer_states = consumer_nodes[1:consumer_node_index:consumer_features].flatten().astype(np.int) + storage_states = storage_nodes[1:storage_node_index:storage_features].flatten().astype(np.int) + seller_states = seller_nodes[1:seller_node_index:seller_features].flatten().astype(np.int) + manufacture_states = manufacture_nodes[1:manufacture_node_index:manufacture_features].flatten().astype(np.int) + distribution_states = ( + distribution_nodes[1:distribution_node_index:distribution_features] + .flatten() + .astype( + np.int, + ) + ) + + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(consumer_states)) + self.assertEqual([0, 0], list(storage_states)) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(seller_states)) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(manufacture_states)) + self.assertEqual([0, 0, 0, 0], list(distribution_states)) + + expect_tick = 10 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot storage unit of each tick in states_2 + + states_2_consumer: Dict[int, list] = defaultdict(list) + states_2_storage: Dict[int, list] = defaultdict(list) + states_2_seller: Dict[int, list] = defaultdict(list) + states_2_manufacture: Dict[int, list] = defaultdict(list) + states_2_distribution: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step(None) + if i in random_tick: + order = Order( + src_facility=supplier_1, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, + ) + distribution_unit.place_order(order) + distribution_unit.try_schedule_orders(env.tick) + env_metric_2[i] = env.metrics + states_2_consumer[i] = consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int) + states_2_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) + states_2_storage[i].append( + storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum(), + ) + states_2_storage[i].append( + storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum(), + ) + states_2_storage[i].append( + storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum(), + ) + states_2_seller[i] = seller_nodes[i:seller_node_index:seller_features].flatten().astype(np.int) + states_2_manufacture[i] = ( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ) + ) + states_2_distribution[i] = ( + distribution_nodes[i:distribution_node_index:distribution_features].flatten().astype(np.int) + ) + + for i in range(expect_tick): + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_ManufactureAction_only(self) -> None: + """test env reset with ManufactureAction only""" + env = build_env("case_01", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + sku3_storage_index = supplier_3.storage.data_model_index + manufacture_sku3_unit = supplier_3.products[SKU3_ID].manufacture + sku3_manufacture_index = manufacture_sku3_unit.data_model_index + + storage_nodes = env.snapshot_list["storage"] + manufacture_nodes = env.snapshot_list["manufacture"] + manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", + ) + # ##################################### Before reset ##################################### + + env.step(None) + + capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + ) + + # there should be 80 units been taken at the beginning according to the config file. + # so remaining space should be 20 + self.assertEqual(20, remaining_spaces.sum()) + # capacity is 100 by config + self.assertEqual(100, capacities.sum()) + + product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + + # The product quantity should be same as configuration at beginning. + # 80 sku3 + self.assertEqual(80, product_dict[SKU3_ID]) + + # all the id is greater than 0 + self.assertGreater(manufacture_sku3_unit.id, 0) + + action = ManufactureAction(manufacture_sku3_unit.id, 1) + + expect_tick = 30 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot manufacture unit of each tick in states_1 + states_1: Dict[int, list] = defaultdict(list) + + random_tick: List[int] = [] + + # The purpose is to randomly perform the order operation + for i in range(10): + random_tick.append(random.randint(1, 30)) + + for i in range(expect_tick): + env.step([action]) + if i in random_tick: + env.step([ManufactureAction(manufacture_sku3_unit.id, 1)]) + env_metric_1[i] = env.metrics + states_1[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + + # ############################### Test whether reset updates the manufacture unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset(). + states = manufacture_nodes[1:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) + + storage_nodes = env.snapshot_list["storage"] + manufacture_nodes = env.snapshot_list["manufacture"] + + capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + ) + + # there should be 80 units been taken at the beginning according to the config file. + # so remaining space should be 20 + self.assertEqual(20, remaining_spaces.sum()) + # capacity is 100 by config + self.assertEqual(100, capacities.sum()) + + product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + + # The product quantity should be same as configuration at beginning. + # 80 sku3 + self.assertEqual(80, product_dict[SKU3_ID]) + + # all the id is greater than 0 + self.assertGreater(manufacture_sku3_unit.id, 0) + + expect_tick = 30 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot manufacture unit of each tick in states_2 + states_2: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step([action]) + if i in random_tick: + env.step([ManufactureAction(manufacture_sku3_unit.id, 1)]) + env_metric_2[i] = env.metrics + states_2[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + + expect_tick = 30 + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_ConsumerAction_only(self) -> None: + """ "test env reset with ConsumerAction only""" + env = build_env("case_01", 500) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + sku3_consumer_unit = supplier_1.products[SKU3_ID].consumer + + consumer_node_index = sku3_consumer_unit.data_model_index + + features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) + + # ##################################### Before reset ##################################### + consumer_nodes = env.snapshot_list["consumer"] + action = ConsumerAction(sku3_consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot consumer unit of each tick in states_1 + states_1: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step([action]) + env_metric_1[i] = env.metrics + states_1[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) + + # ############### Test whether reset updates the consumer unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset() + states = consumer_nodes[1:consumer_node_index:features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) + + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot consumer unit of each tick in states_2 + states_2: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): + env.step([action]) + env_metric_2[i] = env.metrics + states_2[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) + + expect_tick = 100 + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: + """test env reset with both ManufactureAction and ConsumerAction""" + env = build_env("case_01", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + consumer_unit: ConsumerUnit = supplier_1.products[SKU3_ID].consumer + manufacture_unit: ManufactureUnit = supplier_1.products[SKU1_ID].manufacture + storage_unit: StorageUnit = supplier_1.storage + + consumer_node_index = consumer_unit.data_model_index + manufacture_node_index = manufacture_unit.data_model_index + storage_node_index = storage_unit.data_model_index + + consumer_features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) + + manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", + ) + storage_features = ("id", "facility_id") + + consumer_nodes = env.snapshot_list["consumer"] + manufacture_nodes = env.snapshot_list["manufacture"] + storage_nodes = env.snapshot_list["storage"] + + # ##################################### Before reset ##################################### + action_consumer = ConsumerAction(consumer_unit.id, SKU3_ID, supplier_3.id, 20, "train") + action_manufacture = ManufactureAction(manufacture_unit.id, 5) + + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + random_tick: List[int] = [] + + # The purpose is to randomly perform the order operation + for i in range(30): + random_tick.append(random.randint(0, 90)) + + # Store the information about the snapshot unit of each tick in states_1 + states_1_consumer: Dict[int, list] = defaultdict(list) + states_1_manufacture: Dict[int, list] = defaultdict(list) + states_1_storage: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + + if i in random_tick: + env.step([action_manufacture]) + i += 1 + states_1_manufacture[i] = list( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ), + ) + env_metric_1[i] = env.metrics + continue + + env.step([action_consumer]) + env_metric_1[i] = env.metrics + states_1_consumer[i] = list( + consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int), + ) + + states_1_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) + states_1_storage[i].append( + list(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int)), + ) + states_1_storage[i].append( + list(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int)), + ) + + # ############### Test whether reset updates the consumer unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset() + consumer_states = consumer_nodes[1:consumer_node_index:consumer_features].flatten().astype(np.int) + manufacture_states = manufacture_nodes[1:manufacture_node_index:manufacture_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(consumer_states)) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(manufacture_states)) + + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot consumer unit of each tick in states_2 + states_2_consumer: Dict[int, list] = defaultdict(list) + states_2_manufacture: Dict[int, list] = defaultdict(list) + states_2_storage: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + + if i in random_tick: + env.step([action_manufacture]) + i += 1 + states_2_manufacture[i] = list( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ), + ) + env_metric_2[i] = env.metrics + continue + + env.step([action_consumer]) + env_metric_2[i] = env.metrics + states_2_consumer[i] = list( + consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int), + ) + + states_2_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) + states_2_storage[i].append( + list( + storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int), + ), + ) + states_2_storage[i].append( + list(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int)), + ) + + expect_tick = 100 + for i in range(expect_tick): + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py index 1cbffa1f5..2a5ab982d 100644 --- a/tests/supply_chain/test_env_reset.py +++ b/tests/supply_chain/test_env_reset.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. -# Licensed under the MIT license +# Licensed under the MIT license. + import random import unittest from collections import defaultdict @@ -7,21 +8,21 @@ import numpy as np -from maro.simulator.scenarios.supply_chain import FacilityBase, ConsumerAction, ManufactureAction, StorageUnit +from maro.simulator.scenarios.supply_chain import ConsumerAction, FacilityBase, ManufactureAction, StorageUnit from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine from maro.simulator.scenarios.supply_chain.order import Order -from tests.supply_chain.common import build_env, SKU3_ID, FOOD_1_ID, get_product_dict_from_storage +from tests.supply_chain.common import FOOD_1_ID, SKU3_ID, build_env, get_product_dict_from_storage class MyTestCase(unittest.TestCase): """ - . consumer unit test - . distribution unit test - . manufacture unit test - . seller unit test - . storage unit test - """ + . consumer unit test + . distribution unit test + . manufacture unit test + . seller unit test + . storage unit test + """ def test_consumer_unit_reset(self) -> None: """Test whether reset updates the consumer unit completely""" @@ -37,8 +38,17 @@ def test_consumer_unit_reset(self) -> None: consumer_node_index = sku3_consumer_unit.data_model_index - features = ("id", "facility_id", "sku_id", "order_base_cost", "purchased", "received", "order_product_cost", - "latest_consumptions", "in_transit_quantity") + features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) # ##################################### Before reset ##################################### consumer_nodes = env.snapshot_list["consumer"] @@ -107,7 +117,8 @@ def test_distribution_unit_reset(self) -> None: quantity=10, vehicle_type="train", creation_tick=env.tick, - expected_finish_tick=env.tick + 7, ) + expected_finish_tick=env.tick + 7, + ) # There are 2 "train" in total, and 1 left after scheduling this order. distribution_unit.place_order(order_1) @@ -247,7 +258,12 @@ def test_manufacture_unit_reset(self) -> None: manufacture_nodes = env.snapshot_list["manufacture"] manufacture_features = ( - "id", "facility_id", "start_manufacture_quantity", "sku_id", "in_pipeline_quantity", "finished_quantity", + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", "product_unit_id", ) @@ -255,8 +271,10 @@ def test_manufacture_unit_reset(self) -> None: env.step(None) - capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int) - remaining_spaces = storage_nodes[env.frame_index:sku3_storage_index:"remaining_space"].flatten().astype(np.int) + capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + ) # there should be 80 units been taken at the beginning according to the config file. # so remaining space should be 20 @@ -307,8 +325,10 @@ def test_manufacture_unit_reset(self) -> None: storage_nodes = env.snapshot_list["storage"] manufacture_nodes = env.snapshot_list["manufacture"] - capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int) - remaining_spaces = storage_nodes[env.frame_index:sku3_storage_index:"remaining_space"].flatten().astype(np.int) + capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + ) # there should be 80 units been taken at the beginning according to the config file. # so remaining space should be 20 @@ -347,7 +367,7 @@ def test_manufacture_unit_reset(self) -> None: def test_seller_unit_dynamics_sampler(self): """Tested the store_001 Interaction between seller unit and dynamics csv data. - The data file of this test is test_case_ 04.csv""" + The data file of this test is test_case_ 04.csv""" env = build_env("case_04", 600) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) @@ -359,7 +379,16 @@ def test_seller_unit_dynamics_sampler(self): seller_node_index = seller_unit.data_model_index seller_nodes = env.snapshot_list["seller"] - features = ("sold", "demand", "total_sold", "id", "total_demand", "backlog_ratio", "facility_id", "product_unit_id",) + features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", + ) # ##################################### Before reset ##################################### self.assertEqual(20, seller_unit.sku_id) @@ -422,7 +451,7 @@ def test_storage_unit_reset(self) -> None: storage_unit: StorageUnit = supplier_3.storage storage_node_index = storage_unit.data_model_index storage_nodes = env.snapshot_list["storage"] - features = ("id", "facility_id",) + features = ("id", "facility_id") # ##################################### Before reset ##################################### @@ -470,5 +499,5 @@ def test_storage_unit_reset(self) -> None: self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/supply_chain/test_state_only.py b/tests/supply_chain/test_state_only.py index 0dfa90d50..93aa95231 100644 --- a/tests/supply_chain/test_state_only.py +++ b/tests/supply_chain/test_state_only.py @@ -185,7 +185,7 @@ def test_distribution_state_only_bigger_vlt(self) -> None: warehouse_1 = be.world._get_facility_by_name("Warehouse_001") retailer_1: FacilityBase = be.world._get_facility_by_name("Retailer_001") - warehouse_1_id, retailer_1_id = 6, 13 + warehouse_1_id, retailer_1_id = 12, 19 warehouse_1_distribution_unit = warehouse_1.distribution self.assertEqual(0, len(warehouse_1_distribution_unit._order_queues["train"])) @@ -375,7 +375,7 @@ def test_distribution_state_only(self) -> None: distribution_unit.place_order(order) self.assertEqual(1, len(distribution_unit._order_queues["train"])) self.assertEqual(20, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - supplier_3_id, warehouse_1_id, retailer_1_id = 1, 6, 13 + supplier_3_id, warehouse_1_id, retailer_1_id = 1, 12, 19 env.step(None)