Skip to content

Commit

Permalink
Deploying to gh-pages from main @ d285e66 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
benjijamorris committed Oct 3, 2024
1 parent 438f693 commit 12331fe
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
15 changes: 14 additions & 1 deletion _modules/cyto_dl/datamodules/dataframe/dataframe_datamodule.html
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,10 @@ <h1>Source code for cyto_dl.datamodules.dataframe.dataframe_datamodule</h1><div
<span class="bp">self</span><span class="o">.</span><span class="n">subsample</span> <span class="o">=</span> <span class="n">subsample</span> <span class="ow">or</span> <span class="p">{}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">refresh_subsample</span> <span class="o">=</span> <span class="n">refresh_subsample</span>

<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">dataloader_kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;batch_size&quot;</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># init size is used to check if the batch size has changed (used for Automatic batch size finder)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>

<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">subsample</span><span class="o">.</span><span class="n">keys</span><span class="p">()):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">subsample</span><span class="p">[</span><span class="n">get_canonical_split_name</span><span class="p">(</span><span class="n">key</span><span class="p">)]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">subsample</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>

Expand All @@ -603,16 +607,25 @@ <h1>Source code for cyto_dl.datamodules.dataframe.dataframe_datamodule</h1><div
<div class="viewcode-block" id="DataframeDatamodule.make_dataloader"><a class="viewcode-back" href="../../../../cyto_dl.datamodules.dataframe.dataframe_datamodule.html#cyto_dl.datamodules.dataframe.dataframe_datamodule.DataframeDatamodule.make_dataloader">[docs]</a> <span class="k">def</span> <span class="nf">make_dataloader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split</span><span class="p">):</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">dataloader_kwargs</span><span class="p">)</span>
<span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;shuffle&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;shuffle&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span> <span class="ow">and</span> <span class="n">split</span> <span class="o">==</span> <span class="s2">&quot;train&quot;</span>
<span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;batch_size&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>

<span class="n">subset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_dataset</span><span class="p">(</span><span class="n">split</span><span class="p">)</span>
<span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">subset</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>

<div class="viewcode-block" id="DataframeDatamodule.get_dataloader"><a class="viewcode-back" href="../../../../cyto_dl.datamodules.dataframe.dataframe_datamodule.html#cyto_dl.datamodules.dataframe.dataframe_datamodule.DataframeDatamodule.get_dataloader">[docs]</a> <span class="k">def</span> <span class="nf">get_dataloader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split</span><span class="p">):</span>
<span class="n">sample_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">subsample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">split</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

<span class="k">if</span> <span class="p">(</span><span class="n">split</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloaders</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">sample_size</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">refresh_subsample</span><span class="p">):</span>
<span class="k">if</span> <span class="p">(</span>
<span class="p">(</span><span class="n">split</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloaders</span><span class="p">)</span>
<span class="ow">or</span> <span class="p">(</span><span class="n">sample_size</span> <span class="o">!=</span> <span class="o">-</span><span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">refresh_subsample</span><span class="p">)</span>
<span class="c1"># check if batch size has changed (used for Automatic batch size finder)</span>
<span class="ow">or</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_init_size</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">):</span>
<span class="c1"># if we want to use a subsample per epoch, we need to remake the</span>
<span class="c1"># dataloader, to refresh the sample</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dataloaders</span><span class="p">[</span><span class="n">split</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_dataloader</span><span class="p">(</span><span class="n">split</span><span class="p">)</span>
<span class="c1"># reset the init size to the current batch size so dataloader isn&#39;t recreated every epoch</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>

<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataloaders</span><span class="p">[</span><span class="n">split</span><span class="p">]</span></div>

Expand Down
11 changes: 11 additions & 0 deletions _modules/cyto_dl/train.html
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ <h1>Source code for cyto_dl.train</h1><div class="highlight"><pre>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">lightning</span> <span class="kn">import</span> <span class="n">Callback</span><span class="p">,</span> <span class="n">LightningDataModule</span><span class="p">,</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">Trainer</span>
<span class="kn">from</span> <span class="nn">lightning.pytorch.loggers.logger</span> <span class="kn">import</span> <span class="n">Logger</span>
<span class="kn">from</span> <span class="nn">lightning.pytorch.tuner</span> <span class="kn">import</span> <span class="n">Tuner</span>
<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">DictConfig</span><span class="p">,</span> <span class="n">OmegaConf</span>

<span class="kn">from</span> <span class="nn">cyto_dl</span> <span class="kn">import</span> <span class="n">utils</span>
Expand Down Expand Up @@ -499,6 +500,12 @@ <h1>Source code for cyto_dl.train</h1><div class="highlight"><pre>
<span class="n">utils</span><span class="o">.</span><span class="n">remove_aux_key</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>

<span class="n">log</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Instantiating data &lt;</span><span class="si">{</span><span class="n">cfg</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;_target_&#39;</span><span class="p">,</span><span class="w"> </span><span class="n">cfg</span><span class="o">.</span><span class="n">data</span><span class="p">)</span><span class="si">}</span><span class="s2">&gt;&quot;</span><span class="p">)</span>

<span class="n">use_batch_tuner</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">cfg</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;batch_size&quot;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&quot;AUTO&quot;</span><span class="p">:</span>
<span class="n">use_batch_tuner</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">cfg</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span>

<span class="n">data</span> <span class="o">=</span> <span class="n">utils</span><span class="o">.</span><span class="n">create_dataloader</span><span class="p">(</span><span class="n">cfg</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">LightningDataModule</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">MutableMapping</span><span class="p">)</span> <span class="ow">or</span> <span class="s2">&quot;train_dataloaders&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">data</span><span class="p">:</span>
Expand Down Expand Up @@ -535,6 +542,10 @@ <h1>Source code for cyto_dl.train</h1><div class="highlight"><pre>
<span class="n">log</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Logging hyperparameters!&quot;</span><span class="p">)</span>
<span class="n">utils</span><span class="o">.</span><span class="n">log_hyperparameters</span><span class="p">(</span><span class="n">object_dict</span><span class="p">)</span>

<span class="k">if</span> <span class="n">use_batch_tuner</span><span class="p">:</span>
<span class="n">tuner</span> <span class="o">=</span> <span class="n">Tuner</span><span class="p">(</span><span class="n">trainer</span><span class="o">=</span><span class="n">trainer</span><span class="p">)</span>
<span class="n">tuner</span><span class="o">.</span><span class="n">scale_batch_size</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">data</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;power&quot;</span><span class="p">)</span>

<span class="k">if</span> <span class="n">cfg</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;train&quot;</span><span class="p">):</span>
<span class="n">log</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Starting training!&quot;</span><span class="p">)</span>
<span class="n">model</span><span class="p">,</span> <span class="n">load_params</span> <span class="o">=</span> <span class="n">utils</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">cfg</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;checkpoint&quot;</span><span class="p">))</span>
Expand Down

0 comments on commit 12331fe

Please sign in to comment.