Skip to content

Commit

Permalink
Deploying to gh-pages from main @ 592177b 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
benjijamorris committed Jul 2, 2024
1 parent 3a9e7db commit fd2a520
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 47 deletions.
79 changes: 45 additions & 34 deletions _modules/cyto_dl/nn/vits/blocks/patchify.html

Large diffs are not rendered by default.

28 changes: 23 additions & 5 deletions _modules/cyto_dl/nn/vits/cross_mae.html
Original file line number Diff line number Diff line change
Expand Up @@ -505,18 +505,29 @@ <h1>Source code for cyto_dl.nn.vits.cross_mae</h1><div class="highlight"><pre>

<div class="viewcode-block" id="CrossMAE_Decoder.forward"><a class="viewcode-back" href="../../../../cyto_dl.nn.vits.cross_mae.html#cyto_dl.nn.vits.cross_mae.CrossMAE_Decoder.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">forward_indexes</span><span class="p">,</span> <span class="n">backward_indexes</span><span class="p">):</span>
<span class="c1"># HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers</span>
<span class="n">features</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># features can be n t b c (if intermediate feature weighter used) or t b c if not</span>
<span class="n">features</span> <span class="o">=</span> <span class="n">features</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">4</span> <span class="k">else</span> <span class="n">features</span>
<span class="n">T</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">shape</span>
<span class="c1"># we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now</span>
<span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_norm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">features</span><span class="p">))</span>

<span class="c1"># add cls token</span>
<span class="n">backward_indexes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">backward_indexes</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">backward_indexes</span><span class="p">),</span> <span class="n">backward_indexes</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="mi">1</span><span class="p">,</span> <span class="n">backward_indexes</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">backward_indexes</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span>
<span class="p">),</span>
<span class="n">backward_indexes</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">forward_indexes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">forward_indexes</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">forward_indexes</span><span class="p">),</span> <span class="n">forward_indexes</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="mi">1</span><span class="p">,</span> <span class="n">forward_indexes</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">forward_indexes</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span>
<span class="p">),</span>
<span class="n">forward_indexes</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># fill in masked regions</span>
Expand Down Expand Up @@ -553,13 +564,20 @@ <h1>Source code for cyto_dl.nn.vits.cross_mae</h1><div class="highlight"><pre>

<span class="c1"># add back in visible/encoded tokens that we don&#39;t calculate loss on</span>
<span class="n">patches</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">T</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">patches</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">patches</span><span class="p">),</span> <span class="n">patches</span><span class="p">],</span>
<span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">T</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">patches</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span>
<span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">patches</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">patches</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="p">),</span>
<span class="n">patches</span><span class="p">,</span>
<span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">patches</span> <span class="o">=</span> <span class="n">take_indexes</span><span class="p">(</span><span class="n">patches</span><span class="p">,</span> <span class="n">backward_indexes</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># patches to image</span>
<span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch2img</span><span class="p">(</span><span class="n">patches</span><span class="p">)</span>

<span class="k">return</span> <span class="n">img</span></div></div>
</pre></div>
</article>
Expand Down
7 changes: 6 additions & 1 deletion _modules/cyto_dl/nn/vits/mae.html
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,12 @@ <h1>Source code for cyto_dl.nn.vits.mae</h1><div class="highlight"><pre>
<span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_norm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">features</span><span class="p">))</span>

<span class="n">backward_indexes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">backward_indexes</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">backward_indexes</span><span class="p">),</span> <span class="n">backward_indexes</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="mi">1</span><span class="p">,</span> <span class="n">backward_indexes</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">backward_indexes</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span>
<span class="p">),</span>
<span class="n">backward_indexes</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># fill in masked regions</span>
Expand Down
4 changes: 1 addition & 3 deletions _modules/cyto_dl/nn/vits/utils.html
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,7 @@ <h1>Source code for cyto_dl.nn.vits.utils</h1><div class="highlight"><pre>


<div class="viewcode-block" id="take_indexes"><a class="viewcode-back" href="../../../../cyto_dl.nn.vits.utils.html#cyto_dl.nn.vits.utils.take_indexes">[docs]</a><span class="k">def</span> <span class="nf">take_indexes</span><span class="p">(</span><span class="n">sequences</span><span class="p">,</span> <span class="n">indexes</span><span class="p">):</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span>
<span class="n">sequences</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">repeat</span><span class="p">(</span><span class="n">indexes</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">sequences</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="s2">&quot;t b -&gt; t b c&quot;</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">sequences</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="p">)</span></div>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="n">sequences</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">repeat</span><span class="p">(</span><span class="n">indexes</span><span class="p">,</span> <span class="s2">&quot;t b -&gt; t b c&quot;</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">sequences</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span></div>
</pre></div>
</article>
</div>
Expand Down
6 changes: 3 additions & 3 deletions cyto_dl.nn.vits.blocks.patchify.html
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>patch_size</strong> (<em>List[int]</em>) – Size of each patch</p></li>
<li><p><strong>patch_size</strong> (<em>List[int]</em>) – Size of each patch in pix (ZYX order for 3D, YX order for 2D)</p></li>
<li><p><strong>emb_dim</strong> (<em>int</em>) – Dimension of encoder</p></li>
<li><p><strong>n_patches</strong> (<em>List[int]</em>) – Number of patches in each spatial dimension</p></li>
<li><p><strong>n_patches</strong> (<em>List[int]</em>) – Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D)</p></li>
<li><p><strong>spatial_dims</strong> (<em>int</em>) – Number of spatial dimensions</p></li>
<li><p><strong>context_pixels</strong> (<em>List[int]</em>) – Number of extra pixels around each patch to include in convolutional embedding to encoder dimension.</p></li>
<li><p><strong>input_channels</strong> (<em>int</em>) – Number of input channels</p></li>
Expand All @@ -448,7 +448,7 @@

<dl class="py function">
<dt class="sig sig-object py" id="cyto_dl.nn.vits.blocks.patchify.random_indexes">
<span class="sig-prename descclassname"><span class="pre">cyto_dl.nn.vits.blocks.patchify.</span></span><span class="sig-name descname"><span class="pre">random_indexes</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">size</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cyto_dl/nn/vits/blocks/patchify.html#random_indexes"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#cyto_dl.nn.vits.blocks.patchify.random_indexes" title="Permalink to this definition">#</a></dt>
<span class="sig-prename descclassname"><span class="pre">cyto_dl.nn.vits.blocks.patchify.</span></span><span class="sig-name descname"><span class="pre">random_indexes</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">size</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/cyto_dl/nn/vits/blocks/patchify.html#random_indexes"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#cyto_dl.nn.vits.blocks.patchify.random_indexes" title="Permalink to this definition">#</a></dt>
<dd></dd></dl>

</section>
Expand Down
2 changes: 1 addition & 1 deletion searchindex.js

Large diffs are not rendered by default.

0 comments on commit fd2a520

Please sign in to comment.