Skip to content

Commit

Permalink
Merge branch 'main' into tagged-op-phase-by
Browse files Browse the repository at this point in the history
  • Loading branch information
richrines1 authored Nov 14, 2024
2 parents 601b33b + 30f5b48 commit c6a63b4
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 19 deletions.
21 changes: 21 additions & 0 deletions cirq-google/cirq_google/api/v2/program.proto
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,32 @@ message ArgMapping {
repeated ArgEntry entries = 1;
}

message FunctionInterpolation {
// The x_values must be sorted in ascending order.
// The x_values and y_values must be of the same length.
repeated float x_values = 1 [packed = true]; // The independent variable.
repeated float y_values = 2 [packed = true]; // The dependent variable.

// Currently only piecewise linear interpolation (i.e. np.interp) is supported.
// That's we connect (x[i], y[i]) to (x[i+1], y[i+1]))
}

message CustomArg {
oneof custom_arg {
FunctionInterpolation function_interpolation_data = 1;
}
}

message InternalGate{
string name = 1; // Gate name.
string module = 2; // Gate module.
int32 num_qubits = 3; // Number of qubits. Required during deserialization.
map<string, Arg> gate_args = 4; // Gate args.

// Custom args are arguments that require special processing during deserialization.
// The `key` is the argument in the internal class's constructor, the `value`
// is a representation from which an internal object can be constructed.
map<string, CustomArg> custom_args = 5;
}

message CouplerPulseGate{
Expand Down
42 changes: 27 additions & 15 deletions cirq-google/cirq_google/api/v2/program_pb2.py

Large diffs are not rendered by default.

74 changes: 73 additions & 1 deletion cirq-google/cirq_google/api/v2/program_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

77 changes: 74 additions & 3 deletions cirq-google/cirq_google/ops/internal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict
from typing import Any, Dict, Optional, Sequence, Union
from collections.abc import Mapping

import numpy as np

from cirq import ops, value
from cirq_google.api.v2 import program_pb2


@value.value_equality
Expand All @@ -25,39 +30,57 @@ class InternalGate(ops.Gate):
constructor stored in `self.gate_args`.
"""

def __init__(self, gate_name: str, gate_module: str, num_qubits: int = 1, **kwargs):
def __init__(
self,
gate_name: str,
gate_module: str,
num_qubits: int = 1,
custom_args: Optional[Mapping[str, program_pb2.CustomArg]] = None,
**kwargs,
):
"""Instatiates an InternalGate.
Arguments:
gate_name: Gate class name.
gate_module: The module of the gate.
num_qubits: Number of qubits that the gate acts on.
custom_args: A mapping from argument name to `CustomArg`.
This is to support argument that require special processing.
**kwargs: The named arguments to be passed to the gate constructor.
"""
self.gate_module = gate_module
self.gate_name = gate_name
self._num_qubits = num_qubits
self.gate_args = kwargs
self.custom_args = custom_args or {}

def _num_qubits_(self) -> int:
return self._num_qubits

def __str__(self):
gate_args = ', '.join(f'{k}={v}' for k, v in self.gate_args.items())
gate_args = ', '.join(f'{k}={v}' for k, v in (self.gate_args | self.custom_args).items())
return f'{self.gate_module}.{self.gate_name}({gate_args})'

def __repr__(self) -> str:
gate_args = ', '.join(f'{k}={repr(v)}' for k, v in self.gate_args.items())
if gate_args != '':
gate_args = ', ' + gate_args

custom_args = ''
if self.custom_args:
custom_args = f", custom_args={self.custom_args}"

return (
f"cirq_google.InternalGate(gate_name='{self.gate_name}', "
f"gate_module='{self.gate_module}', "
f"num_qubits={self._num_qubits}"
f"{custom_args}"
f"{gate_args})"
)

def _json_dict_(self) -> Dict[str, Any]:
if self.custom_args:
raise ValueError('InternalGate with custom args are not json serializable')
return dict(
gate_name=self.gate_name,
gate_module=self.gate_module,
Expand All @@ -78,3 +101,51 @@ def _value_equality_values_(self):
self._num_qubits,
frozenset(self.gate_args.items()) if hashable else self.gate_args,
)


def function_points_to_proto(
x: Union[Sequence[float], np.ndarray],
y: Union[Sequence[float], np.ndarray],
msg: Optional[program_pb2.CustomArg] = None,
) -> program_pb2.CustomArg:
"""Return CustomArg that expresses a function through its x and y values.
Args:
x: Sequence of values of the free variable.
For 1D functions, this input is assumed to be given in increasing order.
y: Sequence of values of the dependent variable.
Where y[i] = func(x[i]) where `func` is the function being encoded.
msg: Optional CustomArg to serialize to.
If not provided a CustomArg is created.
Returns:
A CustomArg encoding the function.
Raises:
ValueError: If
- `x` is 1D and not sorted in increasing order.
- `x` and `y` don't have the same number of points.
- `y` is multidimensional.
- `x` is multidimensional.
"""

x = np.asarray(x)
y = np.asarray(y)

if len(x.shape) != 1:
raise ValueError('The free variable must be one dimensional')

if len(x.shape) == 1 and not np.all(np.diff(x) > 0):
raise ValueError('The free variable must be sorted in increasing order')

if len(y.shape) != 1:
raise ValueError('The dependent variable must be one dimensional')

if x.shape[0] != y.shape[0]:
raise ValueError('Mismatch between number of points in x and y')

if msg is None:
msg = program_pb2.CustomArg()
msg.function_interpolation_data.x_values[:] = x
msg.function_interpolation_data.y_values[:] = y
return msg
67 changes: 67 additions & 0 deletions cirq-google/cirq_google/ops/internal_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import cirq
import cirq_google
import pytest
from cirq_google.ops import internal_gate
from cirq_google.serialization import arg_func_langs


def test_internal_gate():
Expand Down Expand Up @@ -67,3 +71,66 @@ def test_internal_gate_with_hashable_args_is_hashable():
)
with pytest.raises(TypeError, match="unhashable"):
_ = hash(unhashable)


def test_internal_gate_with_custom_function_repr():
x = np.linspace(-1, 1, 10)
y = x**2
encoded_func = internal_gate.function_points_to_proto(x=x, y=y)

gate = internal_gate.InternalGate(
gate_name='GateWithFunction',
gate_module='test',
num_qubits=2,
custom_args={'func': encoded_func},
)

assert repr(gate) == (
"cirq_google.InternalGate(gate_name='GateWithFunction', "
f"gate_module='test', num_qubits=2, custom_args={gate.custom_args})"
)

assert str(gate) == (f"test.GateWithFunction(func={encoded_func})")

with pytest.raises(ValueError):
_ = cirq.to_json(gate)


def test_internal_gate_with_custom_function_round_trip():
original_func = lambda x: x**2
x = np.linspace(-1, 1, 10)
y = original_func(x)
encoded_func = internal_gate.function_points_to_proto(x=x, y=y)

gate = internal_gate.InternalGate(
gate_name='GateWithFunction',
gate_module='test',
num_qubits=2,
custom_args={'func': encoded_func},
)

msg = arg_func_langs.internal_gate_arg_to_proto(gate)

new_gate = arg_func_langs.internal_gate_from_proto(msg, arg_func_langs.MOST_PERMISSIVE_LANGUAGE)

func_proto = new_gate.custom_args['func'].function_interpolation_data

np.testing.assert_allclose(x, func_proto.x_values)
np.testing.assert_allclose(y, func_proto.y_values)


def test_function_points_to_proto_invalid_args_raise():
x = np.linspace(-1, 1, 10)
y = x + 1

with pytest.raises(ValueError, match='The free variable must be one dimensional'):
_ = internal_gate.function_points_to_proto(np.zeros((10, 2)), y)

with pytest.raises(ValueError, match='sorted in increasing order'):
_ = internal_gate.function_points_to_proto(x[::-1], y)

with pytest.raises(ValueError, match='Mismatch between number of points in x and y'):
_ = internal_gate.function_points_to_proto(x, np.linspace(-1, 1, 40))

with pytest.raises(ValueError, match='The dependent variable must be one dimensional'):
_ = internal_gate.function_points_to_proto(x, np.zeros((10, 2)))
5 changes: 5 additions & 0 deletions cirq-google/cirq_google/serialization/arg_func_langs.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,13 @@ def internal_gate_arg_to_proto(
msg.name = value.gate_name
msg.module = value.gate_module
msg.num_qubits = value.num_qubits()

for k, v in value.gate_args.items():
arg_to_proto(value=v, out=msg.gate_args[k])

for ck, cv in value.custom_args.items():
msg.custom_args[ck].MergeFrom(cv)

return msg


Expand All @@ -461,6 +465,7 @@ def internal_gate_from_proto(
gate_name=str(msg.name),
gate_module=str(msg.module),
num_qubits=int(msg.num_qubits),
custom_args=msg.custom_args,
**gate_args,
)

Expand Down

0 comments on commit c6a63b4

Please sign in to comment.