Skip to content

Commit

Permalink
fix: use correct version of orbax (#337)
Browse files (browse the repository at this point in the history)
  • Loading branch information
borisdayma authored Aug 22, 2023
1 parent fd818d2 commit 61d1031
Show file tree
Hide file tree
Showing 5 changed files with 561 additions and 564 deletions.
1 change: 0 additions & 1 deletion app/gradio/app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,7 +22,6 @@ def infer(prompt):
with gr.Group():
with gr.Box():
with gr.Row().style(mobile_collapse=False, equal_height=True):

text = gr.Textbox(
label="Enter your prompt", show_label=False, max_lines=1
).style(
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,6 +31,7 @@ install_requires =
pillow
jax==0.3.25
flax==0.6.3
orbax==0.0.23
wandb

[options.extras_require]
Expand Down
5 changes: 1 addition & 4 deletions src/dalle_mini/model/modeling.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -77,6 +77,7 @@ def _smelu(x: Any) -> Any:

ACT2FN.update({"smelu": smelu()})


# deepnet initialization
def deepnet_init(init_std, gain=1):
init = jax.nn.initializers.normal(init_std)
Expand Down Expand Up @@ -498,7 +499,6 @@ class GLU(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:

if self.config.use_deepnet_scaling:
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
self.config
Expand Down Expand Up @@ -567,7 +567,6 @@ class FFN(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:

if self.config.use_deepnet_scaling:
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
self.config
Expand Down Expand Up @@ -634,7 +633,6 @@ def __call__(
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:

if self.config.use_scan:
hidden_states = hidden_states[0]

Expand Down Expand Up @@ -742,7 +740,6 @@ def __call__(
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:

if self.config.use_scan:
hidden_states = hidden_states[0]

Expand Down
Loading

0 comments on commit 61d1031

Please sign in to comment.