Skip to content

Commit

Permalink
Auto-commit updated notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Oct 2, 2024
1 parent 06d2eaa commit cc34547
Show file tree
Hide file tree
Showing 5 changed files with 692 additions and 126 deletions.
73 changes: 67 additions & 6 deletions notebooks/_shape_embed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@
"\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"def hashing_fn(args):\n",
" serialized_args = pickle.dumps(vars(args))\n",
" hash_object = hashlib.sha256(serialized_args)\n",
" hashed_string = base64.urlsafe_b64encode(hash_object.digest()).decode()\n",
" return hashed_string\n",
"\n",
"def scoring_df(X, y):\n",
" # Split the data into training and test sets\n",
Expand Down Expand Up @@ -90,7 +95,7 @@
" return pd.DataFrame(cv_results)\n",
"\n",
"\n",
"def shape_embed_process():\n",
"def shape_embed_process(clargs):\n",
" # Setting the font size\n",
"\n",
" # rc(\"text\", usetex=True)\n",
Expand All @@ -105,9 +110,11 @@
" )\n",
"\n",
" # matplotlib.use(\"TkAgg\")\n",
" interp_size = 128 * 2\n",
" interp_size = clargs.latent_space_size * 2\n",
" #interp_size = 128 * 2\n",
" max_epochs = 100\n",
" window_size = 128 * 2\n",
" window_size = clargs.latent_space_size * 2\n",
" #window_size = 128 * 2\n",
"\n",
" params = {\n",
" \"model\": \"resnet18_vqvae_legacy\",\n",
Expand All @@ -125,7 +132,7 @@
" }\n",
"\n",
" optimizer_params = {\n",
" \"opt\": \"LAMB\",\n",
" \"opt\": \"AdamW\",\n",
" \"lr\": 0.001,\n",
" \"weight_decay\": 0.0001,\n",
" \"momentum\": 0.9,\n",
Expand All @@ -150,7 +157,7 @@
" # dataset = \"bbbc010\"\n",
"\n",
" # train_data_path = f\"scripts/shapes/data/{dataset_path}\"\n",
" train_data_path = f\"data/{dataset_path}\"\n",
" train_data_path = f\"/nfs/research/uhlmann/afoix/{dataset_path}\"\n",
" metadata = lambda x: f\"results/{dataset_path}_{args.model}/{x}\"\n",
"\n",
" path = Path(metadata(\"\"))\n",
Expand Down Expand Up @@ -522,6 +529,18 @@
" trial_df.to_csv(metadata(\"trial_df.csv\"))\n",
" trial_df.groupby(\"trial\").mean().to_csv(metadata(\"trial_df_mean.csv\"))\n",
" trial_df.plot(kind=\"bar\")\n",
" \n",
" #mean_df = trial_df.groupby(\"trial\").mean()\n",
" #std_df = trial_df.groupby(\"trial\").std()\n",
" #wandb.log_table(mean_df)\n",
" #wandb.log_table(std_df) \n",
" \n",
" #Special metrics for f1 score for wandb\n",
" wandblogger.experiment.log({\"trial_df\": wandb.Table(dataframe=trial_df)})\n",
" mean_df = trial_df.groupby(\"trial\").mean()\n",
" std_df = trial_df.groupby(\"trial\").std()\n",
" wandblogger.experiment.log({\"Mean\": wandb.Table(dataframe=mean_df)})\n",
" wandblogger.experiment.log({\"Std\": wandb.Table(dataframe=std_df)})\n",
"\n",
" melted_df = trial_df.melt(id_vars=\"trial\", var_name=\"Metric\", value_name=\"Score\")\n",
" # fig, ax = plt.subplots(figsize=(width, height))\n",
Expand Down Expand Up @@ -553,8 +572,50 @@
" # tikzplotlib.save(metadata(f\"trials_barplot.tikz\"))\n",
"\n",
"\n",
"\n",
"\n",
"###############################################################################\n",
"\n",
"if __name__ == \"__main__\":\n",
" shape_embed_process()"
"\n",
" def auto_pos_int (x):\n",
" val = int(x,0)\n",
" if val <= 0:\n",
" raise argparse.ArgumentTypeError(\"argument must be a positive int. Got {:d}.\".format(val))\n",
" return val\n",
" \n",
" parser = argparse.ArgumentParser(description='Run the shape embed pipeline')\n",
" \n",
" models = [\n",
" \"resnet18_vae\"\n",
" , \"resnet50_vae\"\n",
" , \"resnet18_vae_bolt\"\n",
" , \"resnet50_vae_bolt\"\n",
" , \"resnet18_vqvae\"\n",
" , \"resnet50_vqvae\"\n",
" , \"resnet18_vqvae_legacy\"\n",
" , \"resnet50_vqvae_legacy\"\n",
" , \"resnet101_vqvae_legacy\"\n",
" , \"resnet110_vqvae_legacy\"\n",
" , \"resnet152_vqvae_legacy\"\n",
" , \"resnet18_vae_legacy\"\n",
" , \"resnet50_vae_legacy\"\n",
" ]\n",
" parser.add_argument(\n",
" '-m', '--model', choices=models, default=models[0], metavar='MODEL'\n",
" , help=f\"The MODEL to use, one of {models} (default {models[0]}).\")\n",
" parser.add_argument(\n",
" '-b', '--batch-size', default=int(4), metavar='BATCH_SIZE', type=auto_pos_int\n",
" , help=\"The BATCH_SIZE for the run, a positive integer (default 4)\")\n",
" parser.add_argument(\n",
" '-l', '--latent-space-size', default=int(128), metavar='LATENT_SPACE_SIZE', type=auto_pos_int\n",
" , help=\"The LATENT_SPACE_SIZE, a positive integer (default 128)\")\n",
" parser.add_argument('--clear-checkpoints', action='store_true'\n",
" , help='remove checkpoints')\n",
" #parser.add_argument('-v', '--verbose', action='count', default=0,\n",
" # help=\"Increase verbosity level by adding more \\\"v\\\".\")\n",
" \n",
" shape_embed_process(parser.parse_args())"
]
}
],
Expand Down
Loading

0 comments on commit cc34547

Please sign in to comment.