Skip to content

Commit

Permalink
Update transpile_pubs for unit tests (#1815)
Browse files Browse the repository at this point in the history
  • Loading branch information
kt474 authored Jul 24, 2024
1 parent 3d92355 commit 0837326
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion test/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class TestEstimatorV2(IBMTestCase):
def test_run_program_inputs(self, abs_pubs):
"""Verify program inputs are correct."""
backend = get_mocked_backend()
t_pubs = transpile_pubs(abs_pubs, backend)
t_pubs = transpile_pubs(abs_pubs, backend, "estimator")

inst = EstimatorV2(backend=backend)
inst.run(t_pubs)
Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def setUp(self) -> None:
def test_run_program_inputs(self, in_pubs):
"""Verify program inputs are correct."""
backend = get_mocked_backend()
t_pubs = transpile_pubs(in_pubs, backend)
t_pubs = transpile_pubs(in_pubs, backend, "sampler")
inst = SamplerV2(backend=backend)
inst.run(t_pubs)
input_params = backend.service.run.call_args.kwargs["inputs"]
Expand Down
11 changes: 8 additions & 3 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,17 +467,22 @@ def get_primitive_inputs(primitive, backend=None, num_sets=1):
raise ValueError(f"Invalid primitive type {type(primitive)}")


def transpile_pubs(in_pubs, backend):
def transpile_pubs(in_pubs, backend, program):
"""Return pubs with transformed circuits and observables."""
t_pubs = []
for pub in in_pubs:
t_circ = transpile(pub[0], backend=backend)
if len(pub) > 2:
if program == "estimator":
t_obs = remap_observables(pub[1], t_circ)
t_pub = [t_circ, t_obs]
for elem in pub[2:]:
t_pub.append(elem)
t_pubs.append(tuple(t_pub))
if program == "sampler":
if len(pub) == 2:
t_pub = [t_circ, pub[1]]
else:
t_pub = [t_circ]
t_pubs.append(tuple(t_pub))
return t_pubs


Expand Down

0 comments on commit 0837326

Please sign in to comment.