Skip to content

Commit

Permalink
remove sh=None, use default
Browse files Browse the repository at this point in the history
  • Loading branch information
blondegeek committed Dec 10, 2019
1 parent 9d983c5 commit 3f2ed85
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions simple_tasks_and_symmetry.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@
"markersize = 15\n",
"\n",
"def plot_task(ax, start, finish, title, marker=None):\n",
" ax.plot(torch.cat([finish[:, 0], finish[:, 0]]), \n",
" torch.cat([finish[:, 1], finish[:, 1]]), 'o-', markersize=markersize)\n",
" ax.plot(torch.cat([start[:, 0], start[:, 0]]), \n",
" torch.cat([start[:, 1], start[:, 1]]), 'o-', \n",
" markersize=markersize + 5 if marker else markersize, \n",
" marker=marker if marker else 'o')\n",
" ax.plot(torch.cat([finish[:, 0], finish[:, 0]]), \n",
" torch.cat([finish[:, 1], finish[:, 1]]), 'o-', markersize=markersize)\n",
" for i in range(N):\n",
" ax.arrow(start[i, 0], start[i, 1], \n",
" finish[i, 0] - start[i, 0], \n",
Expand Down Expand Up @@ -129,7 +129,7 @@
"outputs": [],
"source": [
"class Network(torch.nn.Module):\n",
" def __init__(self, Rs, n_layers=3, sh=None, max_radius=3.0, number_of_basis=3, radial_layers=3):\n",
" def __init__(self, Rs, n_layers=3, max_radius=3.0, number_of_basis=3, radial_layers=3):\n",
" super().__init__()\n",
" self.Rs = Rs\n",
" self.n_layers = n_layers\n",
Expand All @@ -145,7 +145,7 @@
" L=radial_layers, act=sp)\n",
"\n",
" \n",
" K = partial(Kernel, RadialModel=RadialModel, sh=sh)\n",
" K = partial(Kernel, RadialModel=RadialModel)\n",
" C = partial(Convolution, K)\n",
"\n",
" self.layers = torch.nn.ModuleList([\n",
Expand Down Expand Up @@ -516,7 +516,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.6.9"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 3f2ed85

Please sign in to comment.