Skip to content

Commit

Permalink
Improvements in Layer and Sequential building error recovery
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 15, 2023
1 parent 233761d commit fa7bb67
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
30 changes: 10 additions & 20 deletions keras/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,15 +1205,15 @@ def _maybe_build(self, call_spec):

# Otherwise, attempt to build the layer by calling it on symbolic input.
if might_have_unbuilt_state(self):
if len(shapes_dict) == 1:
success = self._build_by_run_for_single_pos_arg(first_shape)
else:
success = self._build_by_run_for_kwargs(shapes_dict)
if not success:
try:
backend.compute_output_spec(
self.call, **call_spec.arguments_dict
)
except Exception as e:
if call_spec.eager:
# Will let the actual eager call do state-building
return
raise ValueError(
warnings.warn(
f"Layer '{self.name}' looks like it has unbuilt state, but "
"Keras is not able to trace the layer `call()` in order to "
"build it automatically. Possible causes:\n"
Expand All @@ -1225,22 +1225,11 @@ def _maybe_build(self, call_spec):
"to implement the `def build(self, input_shape)` method on "
"your layer. It should create all variables used by the "
"layer (e.g. by calling `layer.build()` on all its "
"children layers)."
"children layers).\n"
f"Exception encoutered: ''{e}''"
)

self.build(first_shape)

def _build_by_run(self, *args, **kwargs):
call_spec = CallSpec(self._call_signature, args, kwargs)
shapes_dict = get_shapes_dict(call_spec)
if len(shapes_dict) == 1:
success = self._build_by_run_for_single_pos_arg(
tuple(shapes_dict.values())[0]
)
else:
success = self._build_by_run_for_kwargs(shapes_dict)
return success

def _build_by_run_for_single_pos_arg(self, input_shape):
# Case: all inputs are in the first arg (possibly nested).
input_tensors = map_shape_structure(
Expand All @@ -1249,7 +1238,8 @@ def _build_by_run_for_single_pos_arg(self, input_shape):
try:
backend.compute_output_spec(self.call, input_tensors)
return True
except:
except Exception as e:
raise e
return False

def _build_by_run_for_kwargs(self, shapes_dict):
Expand Down
16 changes: 16 additions & 0 deletions keras/models/sequential.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import inspect

import tree

Expand Down Expand Up @@ -178,6 +179,21 @@ def build(self, input_shape=None):
# Can happen if shape inference is not implemented.
# TODO: consider reverting inbound nodes on layers processed.
return
except TypeError as e:
signature = inspect.signature(layer.call)
positional_args = [
param
for param in signature.parameters.values()
if param.default == inspect.Parameter.empty
]
if len(positional_args) != 1:
raise ValueError(
"Layers added to a Sequential model "
"can only have a single positional argument, "
f"the input tensor. Layer {layer.__class__.__name__} "
f"has multiple positional arguments: {positional_args}"
)
raise e
outputs = x
self._functional = Functional(inputs=inputs, outputs=outputs)
self.built = True
Expand Down
12 changes: 12 additions & 0 deletions keras/models/sequential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,15 @@ def test_bad_layer(self):
model = Sequential(name="seq")
with self.assertRaisesRegex(ValueError, "Only instances of"):
model.add({})

model = Sequential(name="seq")

class BadLayer(layers.Layer):
def call(self, inputs, training):
return inputs

model.add(BadLayer())
with self.assertRaisesRegex(
ValueError, "can only have a single positional"
):
model.build((None, 2))

0 comments on commit fa7bb67

Please sign in to comment.