simpler version
dominikandreasseitz committed Jul 4, 2024
1 parent 99b7ef6 commit 81ac1c5
Showing 2 changed files with 16 additions and 57 deletions.
qadence_embeddings/embedding.py (60 changes: 10 additions & 50 deletions)
@@ -48,81 +48,41 @@ def __init__(
         )
         self.var_to_call: dict[str, ConcretizedCallable] = var_to_call
         self._dtype: DTypeLike = None
-        self.fparams_assigned: bool = False
 
-    def flush_fparams(self) -> None:
-        """Flush all stored fparams and set them to None."""
-        self.fparams = {key: None for key in self.fparams.keys()}
-        self.fparams_assigned = False
-
-    def assign_fparams(
-        self, inputs: dict[str, ArrayLike | None], flush_current: bool = True
-    ) -> None:
-        """Mutate the `self.fparams` field to store inputs from the user."""
-        if self.fparams_assigned:
-            (
-                self.flush_fparams()
-                if flush_current
-                else logger.error(
-                    "Fparams are still assigned. Please flush them before re-embedding."
-                )
-            )
-        if not inputs.keys() == self.fparams.keys():
-            logger.error(
-                f"Please provide all fparams, Expected {self.fparams.keys()},\
-                received {inputs.keys()}."
-            )
-        self.fparams = inputs
-        self.fparams_assigned = True
-
-    def evaluate_param(
-        self, param_name: str, inputs: dict[str, ArrayLike]
-    ) -> ArrayLike:
-        """Returns the result of evaluation an expression in `var_to_call`."""
-        return self.var_to_call[param_name](inputs)
+    @property
+    def root_param_names(self) -> list[str]:
+        return list(self.vparams.keys()) + list(self.fparams.keys())
 
     def embed_all(
         self,
         inputs: dict[str, ArrayLike],
-        include_root_vars: bool = True,
-        store_inputs: bool = False,
     ) -> dict[str, ArrayLike]:
         """The standard embedding of all intermediate and leaf parameters.
         Include the root_params, i.e., the vparams and fparams original values
         to be reused in computations.
         """
-        if not include_root_vars:
-            logger.error(
-                "Warning: Original parameters are not included, only intermediates and leaves."
-            )
-        if store_inputs:
-            self.assign_fparams(inputs)
         for intermediate_or_leaf_var, engine_callable in self.var_to_call.items():
             # We mutate the original inputs dict and include intermediates and leaves.
             inputs[intermediate_or_leaf_var] = engine_callable(inputs)
         return inputs
 
-    def reembed_all(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]:
-        assert (
-            self.fparams_assigned
-        ), "To reembed, please store original fparam values by setting\
-            `include_root_vars = True` when calling `embed_all`"
-
+    def reembed_all(
+        self,
+        embedded_params: dict[str, ArrayLike],
+        new_root_params: dict[str, ArrayLike],
+    ) -> dict[str, ArrayLike]:
         # We filter out intermediates and leaves and leave only the original vparams and fparams +
         # the `inputs` dict which contains new <name:parameter value> pairs
         inputs = {
-            p: v
-            for p, v in inputs.items()
-            if p in self.vparams.keys() or p in self.fparams.keys()
+            p: v for p, v in embedded_params.items() if p in self.root_param_names
         }
-        return self.embed_all({**self.vparams, **self.fparams, **inputs})
+        return self.embed_all({**self.vparams, **inputs, **new_root_params})
 
     def __call__(self, inputs: dict[str, ArrayLike]) -> dict[str, ArrayLike]:
         """Functional version of legacy embedding: Return a new dictionary\
         with all embedded parameters."""
         return self.embed_all(inputs)
 
     @property
     def dtype(self) -> DTypeLike:
         return self._dtype
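For context, the simplified call pattern introduced above can be sketched as follows. This is a minimal illustration only: it assumes an Embedding instance `embedding` whose root parameters include "x" and "theta" and whose `var_to_call` defines a derived parameter "%0", mirroring the tests below; the construction of `embedding` is omitted.

import torch

# Assumed setup (not shown here): `embedding` is an Embedding with root
# parameters "x" and "theta" and a derived parameter "%0", as in
# tests/test_embedding.py below.
inputs = {"x": torch.rand(1), "theta": torch.rand(1)}

# embed_all evaluates every callable in var_to_call and returns the root
# parameters together with all intermediate/leaf values.
embedded = embedding.embed_all(inputs)

# Re-embedding is now stateless: pass the previously embedded dict plus the
# new root-parameter values instead of relying on fparams stored on the object.
new_roots = {"x": torch.rand(1)}
reembedded = embedding.reembed_all(embedded, new_roots)

# Derived values such as "%0" are recomputed from the new "x".
print(embedded["%0"], reembedded["%0"])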
tests/test_embedding.py (13 changes: 6 additions & 7 deletions)
@@ -28,7 +28,7 @@ def test_embedding() -> None:
"x": (torch.tensor(x) if engine_name == "torch" else x),
"theta": (torch.tensor(theta) if engine_name == "torch" else theta),
}
eval_0 = embedding.evaluate_param("%0", inputs)
eval_0 = embedding.var_to_call["%0"](inputs)
results.append(eval_0.item())
assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])

@@ -56,12 +56,11 @@ def test_reembedding() -> None:
"x": (torch.tensor(x) if engine_name == "torch" else x),
"theta": (torch.tensor(theta) if engine_name == "torch" else theta),
}
all_params = embedding.embed_all(
inputs, include_root_vars=True, store_inputs=True
)
reembedded_params = embedding.reembed_all(
{"x": (torch.tensor(x_rembed) if engine_name == "torch" else x_rembed)}
)
all_params = embedding.embed_all(inputs)
new_params = {
"x": (torch.tensor(x_rembed) if engine_name == "torch" else x_rembed)
}
reembedded_params = embedding.reembed_all(all_params, new_params)
results.append(all_params["%0"].item())
reembedded_results.append(reembedded_params["%0"].item())
assert np.allclose(results[0], results[1]) and np.allclose(results[0], results[2])
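One more consequence of the change, visible in the first test hunk: with `evaluate_param` removed, a single derived parameter is evaluated by indexing `var_to_call` directly. A minimal sketch, again assuming the same `embedding` and `inputs` as in the example above:

# Before this commit (helper now removed):
#     value = embedding.evaluate_param("%0", inputs)
# After: look up the ConcretizedCallable by name and call it on the dict
# of root-parameter values.
value = embedding.var_to_call["%0"](inputs)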
