Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added reset teset in supply chain scenarios #550

Open
wants to merge 11 commits into
base: sc_refinement
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@ class DistributionDataModel(DataModelBase):
def __init__(self) -> None:
super(DistributionDataModel, self).__init__()

def initialize(self) -> None:
self.reset()

def reset(self) -> None:
super(DistributionDataModel, self).reset()
2 changes: 2 additions & 0 deletions maro/simulator/scenarios/supply_chain/units/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def initialize(self) -> None:

# TODO: add vehicle patient setting if needed

self.data_model.initialize()

for sku_id in self.facility.products.keys():
self._unit_delay_order_penalty[sku_id] = self.facility.skus[sku_id].unit_delay_order_penalty

Expand Down
12 changes: 12 additions & 0 deletions tests/data/supply_chain/case_04/test_case_04.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,23 @@ food_1,2021/1/2,43.1,33.39,20
food_1,2021/1/3,43.2,33.39,30
food_1,2021/1/4,43.3,33.39,40
food_1,2021/1/5,43.4,33.39,50
food_1,2021/1/6,43.4,33.39,60
food_1,2021/1/7,43.4,33.39,70
food_1,2021/1/8,43.4,33.39,80
food_1,2021/1/9,43.4,33.39,90
food_1,2021/1/10,43.4,33.39,100
food_1,2021/1/11,43.4,33.39,110
hobby_1,2021/1/1,28.32,21.79,100
hobby_1,2021/1/2,28.32,21.79,200
hobby_1,2021/1/3,28.32,21.79,300
hobby_1,2021/1/4,28.32,21.79,400
hobby_1,2021/1/5,28.32,21.79,500
hobby_1,2021/1/6,28.32,21.79,600
hobby_1,2021/1/7,28.32,21.79,700
hobby_1,2021/1/8,28.32,21.79,800
hobby_1,2021/1/9,28.32,21.79,900
hobby_1,2021/1/10,28.32,21.79,1000
hobby_1,2021/1/11,28.32,21.79,1100
household_1,2022/5/14,17.35974446,14.4,146
household_1,2022/5/15,17.29592816,14.4,164
household_1,2022/5/16,17.43129521,14.4,127
30 changes: 30 additions & 0 deletions tests/data/supply_chain/case_05/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,30 @@ world:
unit_delay_order_penalty: 20
unit_order_cost: 0

- name: "Supplier_SKU1"
definition_ref: "SupplierFacility"
skus:
sku1:
init_stock: 20
has_manufacture: True
max_manufacture_rate: 50
manufacture_leading_time: 1
unit_product_cost: 1
price: 10
unit_delay_order_penalty: 10
has_consumer: True
sku3:
init_stock: 80
unit_product_cost: 1
price: 10
unit_delay_order_penalty: 10
children:
storage: *small_storage
distribution: *normal_distribution
config:
unit_delay_order_penalty: 20
unit_order_cost: 0

- name: "Warehouse_001"
definition_ref: "WarehouseFacility"
skus:
Expand All @@ -154,6 +178,7 @@ world:
sub_storage_id: 1
storage_upper_bound: 40
price: 100
has_consumer: True
sku2:
init_stock: 12
sub_storage_id: 1
Expand Down Expand Up @@ -257,6 +282,11 @@ world:
cost: 0.6

Warehouse_001:
sku1:
"Supplier_SKU1":
"train":
vlt: 3
cost: 1
sku3:
"Supplier_SKU3":
"train":
Expand Down
119 changes: 112 additions & 7 deletions tests/supply_chain/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,93 @@
# Licensed under the MIT license.

import os
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np

from maro.simulator import Env


def build_env(case_name: str, durations: int):
def build_env(case_name: str, durations: int) -> Env:
case_folder = os.path.join("tests", "data", "supply_chain", case_name)
return Env(scenario="supply_chain", topology=case_folder, durations=durations)

env = Env(scenario="supply_chain", topology=case_folder, durations=durations)

return env


def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int):
def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int) -> Dict[int, int]:
sku_id_list = env.snapshot_list["storage"][frame_index:node_index:"sku_id_list"].flatten().astype(np.int)
product_quantity = env.snapshot_list["storage"][frame_index:node_index:"product_quantity"].flatten().astype(np.int)

return {sku_id: quantity for sku_id, quantity in zip(sku_id_list, product_quantity)}
return dict(zip(sku_id_list, product_quantity))


def snapshot_query(env: Env, i: int) -> Tuple[
Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list]
]:
consumer_nodes = env.snapshot_list["consumer"]
storage_nodes = env.snapshot_list["storage"]
seller_nodes = env.snapshot_list["seller"]
manufacture_nodes = env.snapshot_list["manufacture"]
distribution_nodes = env.snapshot_list["distribution"]

states_consumer: Dict[int, list] = {}
states_storage: Dict[int, list] = {}
states_seller: Dict[int, list] = {}
states_manufacture: Dict[int, list] = {}
states_distribution: Dict[int, list] = {}

env_metric = env.metrics

for idx in range(len(consumer_nodes)):
states_consumer[idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.float)

for idx in range(len(storage_nodes)):
states_storage[idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.float))
states_storage[idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int))
states_storage[idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.float))
states_storage[idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.float))

for idx in range(len(manufacture_nodes)):
states_manufacture[idx] = (
manufacture_nodes[i:idx:manufacture_features]
.flatten()
.astype(
np.float,
)
)

for idx in range(len(distribution_nodes)):
states_distribution[idx] = (
distribution_nodes[i:idx:distribution_features]
.flatten()
.astype(
np.float,
)
)

for idx in range(len(seller_nodes)):
states_seller[idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.float)

return env_metric, states_consumer, states_storage, states_seller, states_manufacture, states_distribution


def test_env_reset_snapshot_query(
env: Env,
action_1: object,
action_2: object,
expect_tick: int,
random_tick: list = None,
) -> List[tuple]:
snapshots: List[tuple] = [] # List of (env_metric, states_consumer, ..., states_distribution)
for i in range(expect_tick):
snapshots.append(snapshot_query(env, i))
env.step(action_1)

if random_tick is not None:
if i in random_tick:
env.step(action_2)

return list(zip(*snapshots))


SKU1_ID = 1
Expand All @@ -29,3 +97,40 @@ def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int):
SKU4_ID = 4
FOOD_1_ID = 20
HOBBY_1_ID = 30

consumer_features = (
"id",
"facility_id",
"sku_id",
"order_base_cost",
"purchased",
"received",
"order_product_cost",
"latest_consumptions",
"in_transit_quantity",
)

storage_features = ("id", "facility_id")

seller_features = (
"sold",
"demand",
"total_sold",
"id",
"total_demand",
"backlog_ratio",
"facility_id",
"product_unit_id",
)

manufacture_features = (
"id",
"facility_id",
"start_manufacture_quantity",
"sku_id",
"in_pipeline_quantity",
"finished_quantity",
"product_unit_id",
)

distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity")
Loading