-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added notebook docs for running experiments
Showing
3 changed files
with
601 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,516 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b2de4ad0-3447-4930-8913-08a2258720c6", | ||
"metadata": {}, | ||
"source": [ | ||
"# Example Notebook\n", | ||
"\n", | ||
"### Using the 01_generate_single_synth_parameter_data.yaml experiment\n", | ||
"\n", | ||
"This notebook is meant to explain how the objects in this class work, and are configurable in a notebook setting. \n", | ||
"\n", | ||
"Notebooks are a replacement for the `Experiment` class, as we will be handling our experiments in the notebook setting rather than using a .py file\n", | ||
"\n", | ||
"First, let's import all of the stuff we need" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 40, | ||
"id": "549cbd99-100c-4ad9-8ff6-cef7309c0210", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Python Lib Packages\n", | ||
"import os\n", | ||
"from pathlib import Path\n", | ||
"import sys\n", | ||
"\n", | ||
"# Pypi imported Modules\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"from omegaconf import DictConfig, OmegaConf\n", | ||
"import pandas as pd\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"\n", | ||
"# Putting the dMC module in the Python Path\n", | ||
"current_dir = Path.cwd()\n", | ||
"dmc_dev_path = current_dir.parents[0]\n", | ||
"sys.path.append(str(dmc_dev_path))\n", | ||
"\n", | ||
"# Synthetic Parameter distributions and MLP Networks\n", | ||
"from dMC.nn.power_distribution import Power\n", | ||
"from dMC.nn.single_parameters import SingleParameters\n", | ||
"from dMC.nn.inverse_linear import InverseLinear\n", | ||
"from dMC.nn.parameter_list import ParameterList\n", | ||
"from dMC.nn.mlp import MLP\n", | ||
"from dMC.nn import Initialization\n", | ||
"\n", | ||
"# Physics model\n", | ||
"from dMC.physics.explicit_mc import ExplicitMC\n", | ||
"\n", | ||
"# Experiment\n", | ||
"from dMC.experiments.generate_synthetic import GenerateSynthetic\n", | ||
"from dMC.experiments.writer import Writer\n", | ||
"\n", | ||
"# Utils functions\n", | ||
"from dMC.configuration import _set_device\n", | ||
"from dMC.__main__ import _set_seed\n", | ||
"\n", | ||
"# Required to generate data\n", | ||
"from dMC.data.datasets.nhd_srb import NHDSRB\n", | ||
"from dMC.data.observations.usgs import USGS\n", | ||
"from dMC.data.dates import Dates\n", | ||
"from dMC.data.normalize.min_max import MinMax\n", | ||
"from dMC.data import DataLoader\n", | ||
"\n", | ||
"# For evaluation\n", | ||
"from dMC.experiments.metrics import Metrics" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "de499c58-04a0-4b21-91ba-0aec2702d78c", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setting up the Config\n", | ||
"\n", | ||
"Let's import the config files from our `dMC` directory:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 23, | ||
"id": "ac501919-2de5-4b1a-86a4-3063fd060b9f", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'defaults': [{'config': '03_train_usgs_period_1a'}, {'hydra_settings': 'settings'}, '_self_'], 'cwd': '/data/tkb5476/projects/dMC-Juniata-hydroDL2', 'data_dir': '/data/tkb5476/projects/dMC-Juniata-hydroDL2/flat_files/dMC-Juniata-hydroDL2/dx-2000dis1_non_merge', 'name': '03_train_usgs_period_1a', 'device': 'cpu', 'config': {'service_locator': {'experiment': 'generate_synthetic.GenerateSynthetic', 'data': 'nhd_srb.NHDSRB', 'observations': 'usgs.USGS', 'physics': 'explicit_mc.ExplicitMC', 'neural_network': 'single_parameters.SingleParameters'}, 'data': {'processed_dir': '${cwd}/flat_files', 'end_node': 4809, 'time': {'start': '02/01/2001 00:00:00', 'end': '09/18/2010 23:00:00', 'steps': 1344, 'tau': 9, 'batch_size': '${config.data.time.steps}'}, 'observations': {'loss_nodes': [1053, 1280, 2662, 2689, 2799, 4780, 4801, 4809], 'dir': '${data_dir}/inflow_interpolated/', 'file_name': '???'}, 'save_paths': {'edges': '${config.data.processed_dir}/${config.data.end_node}_edges.csv', 'nodes': '${config.data.processed_dir}/${config.data.end_node}_nodes.csv', 'areas': '${config.data.processed_dir}/${config.data.end_node}_areas.npy', 'q_prime': '${config.data.processed_dir}/${config.data.end_node}_tau${config.data.time.tau}_{}_{}_q_prime.csv', 'network': '${config.data.processed_dir}/${config.data.end_node}_network_matrix.csv', 'gage_locations': '${config.data.processed_dir}/gages_II_locations.csv', 'q_prime_sum': '${config.data.processed_dir}/${config.data.end_node}_tau${config.data.time.tau}_q_prime_sum.npy'}, 'csv': {'edges': '${data_dir}/graphs/edges_NaNFix.csv', 'nodes': '${data_dir}/graphs/node.csv', 'q_prime': '${data_dir}/graphs/srb_post_process.csv', 'mass_transfer_matrix': '${data_dir}/graphs/TM.csv'}}, 'experiment': {'learning_rate': 0.01, 'epochs': 100, 'warmup': 72, 'lb': [0.01, 0.0], 'ub': [0.3, 3.0], 'factor': 100, 'name': '${name}', 'save_path': '${cwd}/runs/01_synthetic_data/', 'output_cols': '${config.data.observations.loss_nodes}', 'tensorboard_dir': '${cwd}/logs/srb/${name}/${now:%Y-%m-%d}/'}, 'model': {'noise': 0.005, 'train_q': True, 'seed': 0, 'mlp': {'initialization': 'xavier_normal', 'fan': 'fan_in', 'gain': 0.7, 'hidden_size': 6, 'input_size': 8, 'output_size': 2}, 'length': {'idx': 8}, 'slope': {'idx': 2, 'min': 0.0001, 'max': 0.3}, 'velocity': {'min': 0.3, 'max': 15}, 'q_prime': {'min': 0}, 'variables': {'n': 0.03, 'p': 21.0, 'q': 0.5, 't': 3600.0, 'x': 0.3}, 'transformations': {'n': [0.01, 0.3], 'q_spatial': [0, 3]}, 'save_paths': {'areas': '${config.data.save_paths.areas}'}, 'is_base': True}}}" | ||
] | ||
}, | ||
"execution_count": 23, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"cfg = OmegaConf.load(dmc_dev_path / \"dMC/conf/global_settings.yaml\")\n", | ||
"experiment_settings = OmegaConf.load(dmc_dev_path / \"dMC/conf/config/01_generate_single_synth_parameter_data.yaml\")\n", | ||
"cfg.config = experiment_settings\n", | ||
"cfg" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 52, | ||
"id": "ccfeacc2-dbd6-48e4-95e0-5af753a48c1b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Applying our global settings, and specifying an output dir as \n", | ||
"_set_device(cfg)\n", | ||
"_set_seed(cfg)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9d3c0be6-1cab-4b46-b781-5af2e9332561", | ||
"metadata": {}, | ||
"source": [ | ||
"# Building Objects\n", | ||
"\n", | ||
"Below we'll do the \"behind the scenes\" work of building our Dataloader, Model, and Experiment so that we can just use those objects here" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "83dd1105-2543-45f3-a306-43a2a0449ffc", | ||
"metadata": {}, | ||
"source": [ | ||
"## Dataloader:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 41, | ||
"id": "6fef710d-566e-449b-abf8-5cd94ce89111", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"<torch.utils.data.dataloader.DataLoader at 0x7f6c8dfa6ad0>" | ||
] | ||
}, | ||
"execution_count": 41, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"cfg_data = cfg.config.data\n", | ||
"\n", | ||
"dates = Dates(cfg_data) # Dates Object\n", | ||
"normalize = MinMax(cfg_data) # Normalization Object\n", | ||
"data = NHDSRB(cfg_data, dates=dates, normalize=normalize) # Dataset Object\n", | ||
"obs = USGS(cfg_data, dates, normalize) # Observations Object\n", | ||
"\n", | ||
"# Getting the data\n", | ||
"hydrofabric = data.get_data()\n", | ||
"observations = obs.get_data().transpose(0, 1)\n", | ||
"\n", | ||
"dataloader = DataLoader(data, obs)(cfg_data)\n", | ||
"dataloader" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "18c33757-0892-4c69-a32c-983a2ff1170c", | ||
"metadata": {}, | ||
"source": [ | ||
"## Model:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 47, | ||
"id": "154f92d6-4f1b-4f17-9ef4-f85ccbbc7c82", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"ExplicitMC(\n", | ||
" (neural_network): SingleParameters()\n", | ||
")" | ||
] | ||
}, | ||
"execution_count": 47, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"cfg_model = cfg.config.model\n", | ||
"\n", | ||
"neural_network = SingleParameters(cfg=cfg_model).to(cfg_model.device)\n", | ||
"physics_model = ExplicitMC(cfg=cfg_model, neural_network=neural_network)\n", | ||
"physics_model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 48, | ||
"id": "0ffde6ac-ae29-41af-b77b-f6fa96a385ac", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Parameter containing:\n", | ||
"tensor(0.0300, requires_grad=True)" | ||
] | ||
}, | ||
"execution_count": 48, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"physics_model.neural_network.n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b7567c1c-c96d-487b-b769-fca5dd9a8c54", | ||
"metadata": {}, | ||
"source": [ | ||
"## Experiment:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 37, | ||
"id": "eb4efb99-df66-4f98-ac01-a86838698a4e", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"<dMC.experiments.generate_synthetic.GenerateSynthetic at 0x7f6c8dc42a10>" | ||
] | ||
}, | ||
"execution_count": 37, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"cfg_experiment = cfg.config.experiment\n", | ||
"# writer = Writer(cfg_experiment)\n", | ||
"experiment = GenerateSynthetic(cfg=cfg_experiment, writer=None)\n", | ||
"experiment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2dbf954e-4246-4483-807b-20c472cb7c79", | ||
"metadata": {}, | ||
"source": [ | ||
"# Running the experiment\n", | ||
"\n", | ||
"Similar to the dependency injection framework in the code, you can run the experiment like below" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 49, | ||
"id": "6faa7b78-5f66-40e4-a6b6-6cf9801f0315", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "1351a637b50d4eca89c02c49fbf73fd5", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Epoch 0: Explicit Muskingum Cunge Routing: 0%| | 0/1343 [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"experiment.run(dataloader, physics_model)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 51, | ||
"id": "cfd0506a-c8dc-4563-8340-5e15e325d74b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"PosixPath('/data/tkb5476/projects/dMC-Juniata-hydroDL2/runs/01_synthetic_data')" | ||
] | ||
}, | ||
"execution_count": 51, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"experiment.save_path" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 55, | ||
"id": "e573eb1b-e848-4806-9865-61014c16f40f", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<div>\n", | ||
"<style scoped>\n", | ||
" .dataframe tbody tr th:only-of-type {\n", | ||
" vertical-align: middle;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe tbody tr th {\n", | ||
" vertical-align: top;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe thead th {\n", | ||
" text-align: right;\n", | ||
" }\n", | ||
"</style>\n", | ||
"<table border=\"1\" class=\"dataframe\">\n", | ||
" <thead>\n", | ||
" <tr style=\"text-align: right;\">\n", | ||
" <th></th>\n", | ||
" <th>Unnamed: 0</th>\n", | ||
" <th>1053</th>\n", | ||
" <th>1280</th>\n", | ||
" <th>2662</th>\n", | ||
" <th>2689</th>\n", | ||
" <th>2799</th>\n", | ||
" <th>4780</th>\n", | ||
" <th>4801</th>\n", | ||
" <th>4809</th>\n", | ||
" </tr>\n", | ||
" </thead>\n", | ||
" <tbody>\n", | ||
" <tr>\n", | ||
" <th>0</th>\n", | ||
" <td>0</td>\n", | ||
" <td>0.072551</td>\n", | ||
" <td>0.303997</td>\n", | ||
" <td>0.035397</td>\n", | ||
" <td>0.017339</td>\n", | ||
" <td>0.273754</td>\n", | ||
" <td>0.097784</td>\n", | ||
" <td>0.170211</td>\n", | ||
" <td>0.272679</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>1</th>\n", | ||
" <td>1</td>\n", | ||
" <td>0.345283</td>\n", | ||
" <td>1.010940</td>\n", | ||
" <td>0.064402</td>\n", | ||
" <td>0.048944</td>\n", | ||
" <td>0.513326</td>\n", | ||
" <td>0.258317</td>\n", | ||
" <td>0.222125</td>\n", | ||
" <td>0.407868</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>2</th>\n", | ||
" <td>2</td>\n", | ||
" <td>0.755540</td>\n", | ||
" <td>1.308555</td>\n", | ||
" <td>0.099312</td>\n", | ||
" <td>0.127346</td>\n", | ||
" <td>0.782226</td>\n", | ||
" <td>0.438819</td>\n", | ||
" <td>0.284994</td>\n", | ||
" <td>0.669957</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>3</th>\n", | ||
" <td>3</td>\n", | ||
" <td>1.101563</td>\n", | ||
" <td>1.415994</td>\n", | ||
" <td>0.155389</td>\n", | ||
" <td>0.183362</td>\n", | ||
" <td>1.098740</td>\n", | ||
" <td>0.659021</td>\n", | ||
" <td>0.400763</td>\n", | ||
" <td>0.938000</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>4</th>\n", | ||
" <td>4</td>\n", | ||
" <td>1.485405</td>\n", | ||
" <td>1.499759</td>\n", | ||
" <td>0.235011</td>\n", | ||
" <td>0.221742</td>\n", | ||
" <td>1.618124</td>\n", | ||
" <td>0.993112</td>\n", | ||
" <td>0.632635</td>\n", | ||
" <td>1.268810</td>\n", | ||
" </tr>\n", | ||
" </tbody>\n", | ||
"</table>\n", | ||
"</div>" | ||
], | ||
"text/plain": [ | ||
" Unnamed: 0 1053 1280 2662 2689 2799 4780 \\\n", | ||
"0 0 0.072551 0.303997 0.035397 0.017339 0.273754 0.097784 \n", | ||
"1 1 0.345283 1.010940 0.064402 0.048944 0.513326 0.258317 \n", | ||
"2 2 0.755540 1.308555 0.099312 0.127346 0.782226 0.438819 \n", | ||
"3 3 1.101563 1.415994 0.155389 0.183362 1.098740 0.659021 \n", | ||
"4 4 1.485405 1.499759 0.235011 0.221742 1.618124 0.993112 \n", | ||
"\n", | ||
" 4801 4809 \n", | ||
"0 0.170211 0.272679 \n", | ||
"1 0.222125 0.407868 \n", | ||
"2 0.284994 0.669957 \n", | ||
"3 0.400763 0.938000 \n", | ||
"4 0.632635 1.268810 " | ||
] | ||
}, | ||
"execution_count": 55, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# Our synthetic discharge outputs. The Rows represent time, the Cols are the edge associated with the discharge\n", | ||
"df = pd.read_csv(experiment.save_path / \"01_generate_single_synth_parameter_data.csv\")\n", | ||
"df.head()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f2765b53-e61b-4a73-9b49-9c5df047d7a3", | ||
"metadata": {}, | ||
"source": [ | ||
"# What now?\n", | ||
"\n", | ||
"Feel free to check out the other experiments. All of the objects that they use are included in their `service_locator` config entry" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "fd3c82a8-6b0a-46d5-9edf-153f39573711", | ||
"metadata": {}, | ||
"source": [ | ||
"### 01: Single Parameter Experiments\n", | ||
"To run these, you should use the following configs:\n", | ||
"- `01_generate_single_synth_parameter_data.yaml`\n", | ||
"- `01_train_against_single_synthetic.yaml`\n", | ||
"\n", | ||
"### 02: Synthetic Parameter Distribution Recovery\n", | ||
"\n", | ||
"There are many synthetic parameter experiments. Run the following configs to recreate them\n", | ||
"\n", | ||
"#### Synthetic Constants\n", | ||
"- `02_generate_mlp_param_list.yaml`\n", | ||
"- `02_train_mlp_param_list.yaml`\n", | ||
"\n", | ||
"#### Synthetic Power Law A\n", | ||
"- `02_generate_mlp_power_a.yaml`\n", | ||
"- `02_train_mlp_power_a.yaml`\n", | ||
"\n", | ||
"#### Synthetic Power Law B\n", | ||
"- `02_train_mlp_power_b.yaml`\n", | ||
"- `02_generate_mlp_power_b.yaml`\n", | ||
"\n", | ||
"### 03: Train against USGS data:\n", | ||
"You can run the following cfgs to train models against USGS data\n", | ||
"- `03_train_usgs_period_1a.yaml`\n", | ||
"- `03_train_usgs_period_1b.yaml`\n", | ||
"- `03_train_usgs_period_2a.yaml`\n", | ||
"- `03_train_usgs_period_2b.yaml`\n", | ||
"- `03_train_usgs_period_3a.yaml`\n", | ||
"- `03_train_usgs_period_3b.yaml`\n", | ||
"- `03_train_usgs_period_4a.yaml`\n", | ||
"- `03_train_usgs_period_4b.yaml`" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
hydra-core==1.3.2 | ||
injector==0.21.0 | ||
jupyterlab | ||
matplotlib==3.8.2 | ||
numpy==1.26.2 | ||
omegaconf==2.3.0 | ||
pandas==2.1.3 | ||
pillow==9.3.0 | ||
scikit-learn==1.3.2 | ||
tensorboard==2.16.2 | ||
torch==2.1.1+cpu | ||
torchvision==0.16.1+cpu | ||
tqdm==4.66.1 |