From 30f5b48295eac196918b9e195d47bf8445756276 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Thu, 14 Nov 2024 07:26:23 -0800 Subject: [PATCH] Create a way to encode general functions (#6788) --- cirq-google/cirq_google/api/v2/program.proto | 21 +++++ cirq-google/cirq_google/api/v2/program_pb2.py | 42 ++++++---- .../cirq_google/api/v2/program_pb2.pyi | 74 +++++++++++++++++- cirq-google/cirq_google/ops/internal_gate.py | 77 ++++++++++++++++++- .../cirq_google/ops/internal_gate_test.py | 67 ++++++++++++++++ .../serialization/arg_func_langs.py | 5 ++ 6 files changed, 267 insertions(+), 19 deletions(-) diff --git a/cirq-google/cirq_google/api/v2/program.proto b/cirq-google/cirq_google/api/v2/program.proto index 580ba28c988..cbc66cc47f2 100644 --- a/cirq-google/cirq_google/api/v2/program.proto +++ b/cirq-google/cirq_google/api/v2/program.proto @@ -443,11 +443,32 @@ message ArgMapping { repeated ArgEntry entries = 1; } +message FunctionInterpolation { + // The x_values must be sorted in ascending order. + // The x_values and y_values must be of the same length. + repeated float x_values = 1 [packed = true]; // The independent variable. + repeated float y_values = 2 [packed = true]; // The dependent variable. + + // Currently only piecewise linear interpolation (i.e. np.interp) is supported. + // That's we connect (x[i], y[i]) to (x[i+1], y[i+1])) +} + +message CustomArg { + oneof custom_arg { + FunctionInterpolation function_interpolation_data = 1; + } +} + message InternalGate{ string name = 1; // Gate name. string module = 2; // Gate module. int32 num_qubits = 3; // Number of qubits. Required during deserialization. map gate_args = 4; // Gate args. + + // Custom args are arguments that require special processing during deserialization. + // The `key` is the argument in the internal class's constructor, the `value` + // is a representation from which an internal object can be constructed. + map custom_args = 5; } message CouplerPulseGate{ diff --git a/cirq-google/cirq_google/api/v2/program_pb2.py b/cirq-google/cirq_google/api/v2/program_pb2.py index 7c6c86dc2bf..c04d25d6f23 100644 --- a/cirq-google/cirq_google/api/v2/program_pb2.py +++ b/cirq-google/cirq_google/api/v2/program_pb2.py @@ -14,7 +14,7 @@ from tunits.proto import tunits_pb2 as tunits_dot_proto_dot_tunits__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n cirq_google/api/v2/program.proto\x12\x12\x63irq.google.api.v2\x1a\x19tunits/proto/tunits.proto\"\xd7\x01\n\x07Program\x12.\n\x08language\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.Language\x12.\n\x07\x63ircuit\x18\x02 \x01(\x0b\x32\x1b.cirq.google.api.v2.CircuitH\x00\x12\x30\n\x08schedule\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.ScheduleH\x00\x12/\n\tconstants\x18\x04 \x03(\x0b\x32\x1c.cirq.google.api.v2.ConstantB\t\n\x07program\"\x93\x01\n\x08\x43onstant\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x34\n\rcircuit_value\x18\x02 \x01(\x0b\x32\x1b.cirq.google.api.v2.CircuitH\x00\x12*\n\x05qubit\x18\x03 \x01(\x0b\x32\x19.cirq.google.api.v2.QubitH\x00\x42\r\n\x0b\x63onst_value\"\xd4\x01\n\x07\x43ircuit\x12K\n\x13scheduling_strategy\x18\x01 \x01(\x0e\x32..cirq.google.api.v2.Circuit.SchedulingStrategy\x12+\n\x07moments\x18\x02 \x03(\x0b\x32\x1a.cirq.google.api.v2.Moment\"O\n\x12SchedulingStrategy\x12#\n\x1fSCHEDULING_STRATEGY_UNSPECIFIED\x10\x00\x12\x14\n\x10MOMENT_BY_MOMENT\x10\x01\"}\n\x06Moment\x12\x31\n\noperations\x18\x01 \x03(\x0b\x32\x1d.cirq.google.api.v2.Operation\x12@\n\x12\x63ircuit_operations\x18\x02 \x03(\x0b\x32$.cirq.google.api.v2.CircuitOperation\"P\n\x08Schedule\x12\x44\n\x14scheduled_operations\x18\x03 \x03(\x0b\x32&.cirq.google.api.v2.ScheduledOperation\"`\n\x12ScheduledOperation\x12\x30\n\toperation\x18\x01 \x01(\x0b\x32\x1d.cirq.google.api.v2.Operation\x12\x18\n\x10start_time_picos\x18\x02 \x01(\x03\"?\n\x08Language\x12\x14\n\x08gate_set\x18\x01 \x01(\tB\x02\x18\x01\x12\x1d\n\x15\x61rg_function_language\x18\x02 \x01(\t\"k\n\x08\x46loatArg\x12\x15\n\x0b\x66loat_value\x18\x01 \x01(\x02H\x00\x12\x10\n\x06symbol\x18\x02 \x01(\tH\x00\x12/\n\x04\x66unc\x18\x03 \x01(\x0b\x32\x1f.cirq.google.api.v2.ArgFunctionH\x00\x42\x05\n\x03\x61rg\":\n\x08XPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\":\n\x08YPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"Q\n\x08ZPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12\x15\n\ris_physical_z\x18\x02 \x01(\x08\"v\n\x0ePhasedXPowGate\x12\x34\n\x0ephase_exponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12.\n\x08\x65xponent\x18\x02 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"\xad\x01\n\x0cPhasedXZGate\x12\x30\n\nx_exponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12\x30\n\nz_exponent\x18\x02 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12\x39\n\x13\x61xis_phase_exponent\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\";\n\tCZPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"\x7f\n\x08\x46SimGate\x12+\n\x05theta\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12)\n\x03phi\x18\x02 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12\x1b\n\x13translate_via_model\x18\x03 \x01(\x08\">\n\x0cISwapPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"e\n\x0fMeasurementGate\x12$\n\x03key\x18\x01 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg\x12,\n\x0binvert_mask\x18\x02 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg\"@\n\x08WaitGate\x12\x34\n\x0e\x64uration_nanos\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"\xce\t\n\tOperation\x12*\n\x04gate\x18\x01 \x01(\x0b\x32\x18.cirq.google.api.v2.GateB\x02\x18\x01\x12\x30\n\x08xpowgate\x18\x07 \x01(\x0b\x32\x1c.cirq.google.api.v2.XPowGateH\x00\x12\x30\n\x08ypowgate\x18\x08 \x01(\x0b\x32\x1c.cirq.google.api.v2.YPowGateH\x00\x12\x30\n\x08zpowgate\x18\t \x01(\x0b\x32\x1c.cirq.google.api.v2.ZPowGateH\x00\x12<\n\x0ephasedxpowgate\x18\n \x01(\x0b\x32\".cirq.google.api.v2.PhasedXPowGateH\x00\x12\x38\n\x0cphasedxzgate\x18\x0b \x01(\x0b\x32 .cirq.google.api.v2.PhasedXZGateH\x00\x12\x32\n\tczpowgate\x18\x0c \x01(\x0b\x32\x1d.cirq.google.api.v2.CZPowGateH\x00\x12\x30\n\x08\x66simgate\x18\r \x01(\x0b\x32\x1c.cirq.google.api.v2.FSimGateH\x00\x12\x38\n\x0ciswappowgate\x18\x0e \x01(\x0b\x32 .cirq.google.api.v2.ISwapPowGateH\x00\x12>\n\x0fmeasurementgate\x18\x0f \x01(\x0b\x32#.cirq.google.api.v2.MeasurementGateH\x00\x12\x30\n\x08waitgate\x18\x10 \x01(\x0b\x32\x1c.cirq.google.api.v2.WaitGateH\x00\x12\x38\n\x0cinternalgate\x18\x11 \x01(\x0b\x32 .cirq.google.api.v2.InternalGateH\x00\x12@\n\x10\x63ouplerpulsegate\x18\x12 \x01(\x0b\x32$.cirq.google.api.v2.CouplerPulseGateH\x00\x12\x38\n\x0cidentitygate\x18\x13 \x01(\x0b\x32 .cirq.google.api.v2.IdentityGateH\x00\x12\x30\n\x08hpowgate\x18\x14 \x01(\x0b\x32\x1c.cirq.google.api.v2.HPowGateH\x00\x12N\n\x17singlequbitcliffordgate\x18\x15 \x01(\x0b\x32+.cirq.google.api.v2.SingleQubitCliffordGateH\x00\x12\x39\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\'.cirq.google.api.v2.Operation.ArgsEntryB\x02\x18\x01\x12)\n\x06qubits\x18\x03 \x03(\x0b\x32\x19.cirq.google.api.v2.Qubit\x12\x1c\n\x14qubit_constant_index\x18\x06 \x03(\x05\x12\x15\n\x0btoken_value\x18\x04 \x01(\tH\x01\x12\x1e\n\x14token_constant_index\x18\x05 \x01(\x05H\x01\x12%\n\x04tags\x18\x16 \x03(\x0b\x32\x17.cirq.google.api.v2.Tag\x1a\x44\n\tArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg:\x02\x38\x01\x42\x0c\n\ngate_valueB\x07\n\x05token\"<\n\x16\x44ynamicalDecouplingTag\x12\x15\n\x08protocol\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x0b\n\t_protocol\"X\n\x03Tag\x12J\n\x14\x64ynamical_decoupling\x18\x01 \x01(\x0b\x32*.cirq.google.api.v2.DynamicalDecouplingTagH\x00\x42\x05\n\x03tag\"\x12\n\x04Gate\x12\n\n\x02id\x18\x01 \x01(\t\"\x13\n\x05Qubit\x12\n\n\x02id\x18\x02 \x01(\t\"\x9c\x01\n\x03\x41rg\x12\x31\n\targ_value\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.ArgValueH\x00\x12\x10\n\x06symbol\x18\x02 \x01(\tH\x00\x12/\n\x04\x66unc\x18\x03 \x01(\x0b\x32\x1f.cirq.google.api.v2.ArgFunctionH\x00\x12\x18\n\x0e\x63onstant_index\x18\x04 \x01(\x05H\x00\x42\x05\n\x03\x61rg\"\xf9\x02\n\x08\x41rgValue\x12\x15\n\x0b\x66loat_value\x18\x01 \x01(\x02H\x00\x12:\n\x0b\x62ool_values\x18\x02 \x01(\x0b\x32#.cirq.google.api.v2.RepeatedBooleanH\x00\x12\x16\n\x0cstring_value\x18\x03 \x01(\tH\x00\x12\x16\n\x0c\x64ouble_value\x18\x04 \x01(\x01H\x00\x12\x39\n\x0cint64_values\x18\x05 \x01(\x0b\x32!.cirq.google.api.v2.RepeatedInt64H\x00\x12;\n\rdouble_values\x18\x06 \x01(\x0b\x32\".cirq.google.api.v2.RepeatedDoubleH\x00\x12;\n\rstring_values\x18\x07 \x01(\x0b\x32\".cirq.google.api.v2.RepeatedStringH\x00\x12(\n\x0fvalue_with_unit\x18\x08 \x01(\x0b\x32\r.tunits.ValueH\x00\x42\x0b\n\targ_value\"\x1f\n\rRepeatedInt64\x12\x0e\n\x06values\x18\x01 \x03(\x03\" \n\x0eRepeatedDouble\x12\x0e\n\x06values\x18\x01 \x03(\x01\" \n\x0eRepeatedString\x12\x0e\n\x06values\x18\x01 \x03(\t\"!\n\x0fRepeatedBoolean\x12\x0e\n\x06values\x18\x01 \x03(\x08\"B\n\x0b\x41rgFunction\x12\x0c\n\x04type\x18\x01 \x01(\t\x12%\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x17.cirq.google.api.v2.Arg\"\xaf\x02\n\x10\x43ircuitOperation\x12\x1e\n\x16\x63ircuit_constant_index\x18\x01 \x01(\x05\x12M\n\x18repetition_specification\x18\x02 \x01(\x0b\x32+.cirq.google.api.v2.RepetitionSpecification\x12\x33\n\tqubit_map\x18\x03 \x01(\x0b\x32 .cirq.google.api.v2.QubitMapping\x12\x46\n\x13measurement_key_map\x18\x04 \x01(\x0b\x32).cirq.google.api.v2.MeasurementKeyMapping\x12/\n\x07\x61rg_map\x18\x05 \x01(\x0b\x32\x1e.cirq.google.api.v2.ArgMapping\"\xbc\x01\n\x17RepetitionSpecification\x12S\n\x0erepetition_ids\x18\x01 \x01(\x0b\x32\x39.cirq.google.api.v2.RepetitionSpecification.RepetitionIdsH\x00\x12\x1a\n\x10repetition_count\x18\x02 \x01(\x05H\x00\x1a\x1c\n\rRepetitionIds\x12\x0b\n\x03ids\x18\x01 \x03(\tB\x12\n\x10repetition_value\"\xac\x01\n\x0cQubitMapping\x12<\n\x07\x65ntries\x18\x01 \x03(\x0b\x32+.cirq.google.api.v2.QubitMapping.QubitEntry\x1a^\n\nQubitEntry\x12&\n\x03key\x18\x01 \x01(\x0b\x32\x19.cirq.google.api.v2.Qubit\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.cirq.google.api.v2.Qubit\"$\n\x0eMeasurementKey\x12\x12\n\nstring_key\x18\x01 \x01(\t\"\xe2\x01\n\x15MeasurementKeyMapping\x12N\n\x07\x65ntries\x18\x01 \x03(\x0b\x32=.cirq.google.api.v2.MeasurementKeyMapping.MeasurementKeyEntry\x1ay\n\x13MeasurementKeyEntry\x12/\n\x03key\x18\x01 \x01(\x0b\x32\".cirq.google.api.v2.MeasurementKey\x12\x31\n\x05value\x18\x02 \x01(\x0b\x32\".cirq.google.api.v2.MeasurementKey\"\xa0\x01\n\nArgMapping\x12\x38\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\'.cirq.google.api.v2.ArgMapping.ArgEntry\x1aX\n\x08\x41rgEntry\x12$\n\x03key\x18\x01 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg\"\xcd\x01\n\x0cInternalGate\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06module\x18\x02 \x01(\t\x12\x12\n\nnum_qubits\x18\x03 \x01(\x05\x12\x41\n\tgate_args\x18\x04 \x03(\x0b\x32..cirq.google.api.v2.InternalGate.GateArgsEntry\x1aH\n\rGateArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg:\x02\x38\x01\"\xd8\x03\n\x10\x43ouplerPulseGate\x12\x37\n\x0chold_time_ps\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x00\x88\x01\x01\x12\x37\n\x0crise_time_ps\x18\x02 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x01\x88\x01\x01\x12:\n\x0fpadding_time_ps\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x02\x88\x01\x01\x12\x37\n\x0c\x63oupling_mhz\x18\x04 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x03\x88\x01\x01\x12\x38\n\rq0_detune_mhz\x18\x05 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x04\x88\x01\x01\x12\x38\n\rq1_detune_mhz\x18\x06 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x05\x88\x01\x01\x42\x0f\n\r_hold_time_psB\x0f\n\r_rise_time_psB\x12\n\x10_padding_time_psB\x0f\n\r_coupling_mhzB\x10\n\x0e_q0_detune_mhzB\x10\n\x0e_q1_detune_mhz\"\x8b\x01\n\x0f\x43liffordTableau\x12\x17\n\nnum_qubits\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rinitial_state\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\n\n\x02rs\x18\x03 \x03(\x08\x12\n\n\x02xs\x18\x04 \x03(\x08\x12\n\n\x02zs\x18\x05 \x03(\x08\x42\r\n\x0b_num_qubitsB\x10\n\x0e_initial_state\"O\n\x17SingleQubitCliffordGate\x12\x34\n\x07tableau\x18\x01 \x01(\x0b\x32#.cirq.google.api.v2.CliffordTableau\"!\n\x0cIdentityGate\x12\x11\n\tqid_shape\x18\x01 \x03(\r\":\n\x08HPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgB/\n\x1d\x63om.google.cirq.google.api.v2B\x0cProgramProtoP\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n cirq_google/api/v2/program.proto\x12\x12\x63irq.google.api.v2\x1a\x19tunits/proto/tunits.proto\"\xd7\x01\n\x07Program\x12.\n\x08language\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.Language\x12.\n\x07\x63ircuit\x18\x02 \x01(\x0b\x32\x1b.cirq.google.api.v2.CircuitH\x00\x12\x30\n\x08schedule\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.ScheduleH\x00\x12/\n\tconstants\x18\x04 \x03(\x0b\x32\x1c.cirq.google.api.v2.ConstantB\t\n\x07program\"\x93\x01\n\x08\x43onstant\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x34\n\rcircuit_value\x18\x02 \x01(\x0b\x32\x1b.cirq.google.api.v2.CircuitH\x00\x12*\n\x05qubit\x18\x03 \x01(\x0b\x32\x19.cirq.google.api.v2.QubitH\x00\x42\r\n\x0b\x63onst_value\"\xd4\x01\n\x07\x43ircuit\x12K\n\x13scheduling_strategy\x18\x01 \x01(\x0e\x32..cirq.google.api.v2.Circuit.SchedulingStrategy\x12+\n\x07moments\x18\x02 \x03(\x0b\x32\x1a.cirq.google.api.v2.Moment\"O\n\x12SchedulingStrategy\x12#\n\x1fSCHEDULING_STRATEGY_UNSPECIFIED\x10\x00\x12\x14\n\x10MOMENT_BY_MOMENT\x10\x01\"}\n\x06Moment\x12\x31\n\noperations\x18\x01 \x03(\x0b\x32\x1d.cirq.google.api.v2.Operation\x12@\n\x12\x63ircuit_operations\x18\x02 \x03(\x0b\x32$.cirq.google.api.v2.CircuitOperation\"P\n\x08Schedule\x12\x44\n\x14scheduled_operations\x18\x03 \x03(\x0b\x32&.cirq.google.api.v2.ScheduledOperation\"`\n\x12ScheduledOperation\x12\x30\n\toperation\x18\x01 \x01(\x0b\x32\x1d.cirq.google.api.v2.Operation\x12\x18\n\x10start_time_picos\x18\x02 \x01(\x03\"?\n\x08Language\x12\x14\n\x08gate_set\x18\x01 \x01(\tB\x02\x18\x01\x12\x1d\n\x15\x61rg_function_language\x18\x02 \x01(\t\"k\n\x08\x46loatArg\x12\x15\n\x0b\x66loat_value\x18\x01 \x01(\x02H\x00\x12\x10\n\x06symbol\x18\x02 \x01(\tH\x00\x12/\n\x04\x66unc\x18\x03 \x01(\x0b\x32\x1f.cirq.google.api.v2.ArgFunctionH\x00\x42\x05\n\x03\x61rg\":\n\x08XPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\":\n\x08YPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"Q\n\x08ZPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12\x15\n\ris_physical_z\x18\x02 \x01(\x08\"v\n\x0ePhasedXPowGate\x12\x34\n\x0ephase_exponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12.\n\x08\x65xponent\x18\x02 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"\xad\x01\n\x0cPhasedXZGate\x12\x30\n\nx_exponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12\x30\n\nz_exponent\x18\x02 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12\x39\n\x13\x61xis_phase_exponent\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\";\n\tCZPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"\x7f\n\x08\x46SimGate\x12+\n\x05theta\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12)\n\x03phi\x18\x02 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\x12\x1b\n\x13translate_via_model\x18\x03 \x01(\x08\">\n\x0cISwapPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"e\n\x0fMeasurementGate\x12$\n\x03key\x18\x01 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg\x12,\n\x0binvert_mask\x18\x02 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg\"@\n\x08WaitGate\x12\x34\n\x0e\x64uration_nanos\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArg\"\xce\t\n\tOperation\x12*\n\x04gate\x18\x01 \x01(\x0b\x32\x18.cirq.google.api.v2.GateB\x02\x18\x01\x12\x30\n\x08xpowgate\x18\x07 \x01(\x0b\x32\x1c.cirq.google.api.v2.XPowGateH\x00\x12\x30\n\x08ypowgate\x18\x08 \x01(\x0b\x32\x1c.cirq.google.api.v2.YPowGateH\x00\x12\x30\n\x08zpowgate\x18\t \x01(\x0b\x32\x1c.cirq.google.api.v2.ZPowGateH\x00\x12<\n\x0ephasedxpowgate\x18\n \x01(\x0b\x32\".cirq.google.api.v2.PhasedXPowGateH\x00\x12\x38\n\x0cphasedxzgate\x18\x0b \x01(\x0b\x32 .cirq.google.api.v2.PhasedXZGateH\x00\x12\x32\n\tczpowgate\x18\x0c \x01(\x0b\x32\x1d.cirq.google.api.v2.CZPowGateH\x00\x12\x30\n\x08\x66simgate\x18\r \x01(\x0b\x32\x1c.cirq.google.api.v2.FSimGateH\x00\x12\x38\n\x0ciswappowgate\x18\x0e \x01(\x0b\x32 .cirq.google.api.v2.ISwapPowGateH\x00\x12>\n\x0fmeasurementgate\x18\x0f \x01(\x0b\x32#.cirq.google.api.v2.MeasurementGateH\x00\x12\x30\n\x08waitgate\x18\x10 \x01(\x0b\x32\x1c.cirq.google.api.v2.WaitGateH\x00\x12\x38\n\x0cinternalgate\x18\x11 \x01(\x0b\x32 .cirq.google.api.v2.InternalGateH\x00\x12@\n\x10\x63ouplerpulsegate\x18\x12 \x01(\x0b\x32$.cirq.google.api.v2.CouplerPulseGateH\x00\x12\x38\n\x0cidentitygate\x18\x13 \x01(\x0b\x32 .cirq.google.api.v2.IdentityGateH\x00\x12\x30\n\x08hpowgate\x18\x14 \x01(\x0b\x32\x1c.cirq.google.api.v2.HPowGateH\x00\x12N\n\x17singlequbitcliffordgate\x18\x15 \x01(\x0b\x32+.cirq.google.api.v2.SingleQubitCliffordGateH\x00\x12\x39\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\'.cirq.google.api.v2.Operation.ArgsEntryB\x02\x18\x01\x12)\n\x06qubits\x18\x03 \x03(\x0b\x32\x19.cirq.google.api.v2.Qubit\x12\x1c\n\x14qubit_constant_index\x18\x06 \x03(\x05\x12\x15\n\x0btoken_value\x18\x04 \x01(\tH\x01\x12\x1e\n\x14token_constant_index\x18\x05 \x01(\x05H\x01\x12%\n\x04tags\x18\x16 \x03(\x0b\x32\x17.cirq.google.api.v2.Tag\x1a\x44\n\tArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg:\x02\x38\x01\x42\x0c\n\ngate_valueB\x07\n\x05token\"<\n\x16\x44ynamicalDecouplingTag\x12\x15\n\x08protocol\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x0b\n\t_protocol\"X\n\x03Tag\x12J\n\x14\x64ynamical_decoupling\x18\x01 \x01(\x0b\x32*.cirq.google.api.v2.DynamicalDecouplingTagH\x00\x42\x05\n\x03tag\"\x12\n\x04Gate\x12\n\n\x02id\x18\x01 \x01(\t\"\x13\n\x05Qubit\x12\n\n\x02id\x18\x02 \x01(\t\"\x9c\x01\n\x03\x41rg\x12\x31\n\targ_value\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.ArgValueH\x00\x12\x10\n\x06symbol\x18\x02 \x01(\tH\x00\x12/\n\x04\x66unc\x18\x03 \x01(\x0b\x32\x1f.cirq.google.api.v2.ArgFunctionH\x00\x12\x18\n\x0e\x63onstant_index\x18\x04 \x01(\x05H\x00\x42\x05\n\x03\x61rg\"\xf9\x02\n\x08\x41rgValue\x12\x15\n\x0b\x66loat_value\x18\x01 \x01(\x02H\x00\x12:\n\x0b\x62ool_values\x18\x02 \x01(\x0b\x32#.cirq.google.api.v2.RepeatedBooleanH\x00\x12\x16\n\x0cstring_value\x18\x03 \x01(\tH\x00\x12\x16\n\x0c\x64ouble_value\x18\x04 \x01(\x01H\x00\x12\x39\n\x0cint64_values\x18\x05 \x01(\x0b\x32!.cirq.google.api.v2.RepeatedInt64H\x00\x12;\n\rdouble_values\x18\x06 \x01(\x0b\x32\".cirq.google.api.v2.RepeatedDoubleH\x00\x12;\n\rstring_values\x18\x07 \x01(\x0b\x32\".cirq.google.api.v2.RepeatedStringH\x00\x12(\n\x0fvalue_with_unit\x18\x08 \x01(\x0b\x32\r.tunits.ValueH\x00\x42\x0b\n\targ_value\"\x1f\n\rRepeatedInt64\x12\x0e\n\x06values\x18\x01 \x03(\x03\" \n\x0eRepeatedDouble\x12\x0e\n\x06values\x18\x01 \x03(\x01\" \n\x0eRepeatedString\x12\x0e\n\x06values\x18\x01 \x03(\t\"!\n\x0fRepeatedBoolean\x12\x0e\n\x06values\x18\x01 \x03(\x08\"B\n\x0b\x41rgFunction\x12\x0c\n\x04type\x18\x01 \x01(\t\x12%\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x17.cirq.google.api.v2.Arg\"\xaf\x02\n\x10\x43ircuitOperation\x12\x1e\n\x16\x63ircuit_constant_index\x18\x01 \x01(\x05\x12M\n\x18repetition_specification\x18\x02 \x01(\x0b\x32+.cirq.google.api.v2.RepetitionSpecification\x12\x33\n\tqubit_map\x18\x03 \x01(\x0b\x32 .cirq.google.api.v2.QubitMapping\x12\x46\n\x13measurement_key_map\x18\x04 \x01(\x0b\x32).cirq.google.api.v2.MeasurementKeyMapping\x12/\n\x07\x61rg_map\x18\x05 \x01(\x0b\x32\x1e.cirq.google.api.v2.ArgMapping\"\xbc\x01\n\x17RepetitionSpecification\x12S\n\x0erepetition_ids\x18\x01 \x01(\x0b\x32\x39.cirq.google.api.v2.RepetitionSpecification.RepetitionIdsH\x00\x12\x1a\n\x10repetition_count\x18\x02 \x01(\x05H\x00\x1a\x1c\n\rRepetitionIds\x12\x0b\n\x03ids\x18\x01 \x03(\tB\x12\n\x10repetition_value\"\xac\x01\n\x0cQubitMapping\x12<\n\x07\x65ntries\x18\x01 \x03(\x0b\x32+.cirq.google.api.v2.QubitMapping.QubitEntry\x1a^\n\nQubitEntry\x12&\n\x03key\x18\x01 \x01(\x0b\x32\x19.cirq.google.api.v2.Qubit\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.cirq.google.api.v2.Qubit\"$\n\x0eMeasurementKey\x12\x12\n\nstring_key\x18\x01 \x01(\t\"\xe2\x01\n\x15MeasurementKeyMapping\x12N\n\x07\x65ntries\x18\x01 \x03(\x0b\x32=.cirq.google.api.v2.MeasurementKeyMapping.MeasurementKeyEntry\x1ay\n\x13MeasurementKeyEntry\x12/\n\x03key\x18\x01 \x01(\x0b\x32\".cirq.google.api.v2.MeasurementKey\x12\x31\n\x05value\x18\x02 \x01(\x0b\x32\".cirq.google.api.v2.MeasurementKey\"\xa0\x01\n\nArgMapping\x12\x38\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\'.cirq.google.api.v2.ArgMapping.ArgEntry\x1aX\n\x08\x41rgEntry\x12$\n\x03key\x18\x01 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg\"C\n\x15\x46unctionInterpolation\x12\x14\n\x08x_values\x18\x01 \x03(\x02\x42\x02\x10\x01\x12\x14\n\x08y_values\x18\x02 \x03(\x02\x42\x02\x10\x01\"k\n\tCustomArg\x12P\n\x1b\x66unction_interpolation_data\x18\x01 \x01(\x0b\x32).cirq.google.api.v2.FunctionInterpolationH\x00\x42\x0c\n\ncustom_arg\"\xe6\x02\n\x0cInternalGate\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06module\x18\x02 \x01(\t\x12\x12\n\nnum_qubits\x18\x03 \x01(\x05\x12\x41\n\tgate_args\x18\x04 \x03(\x0b\x32..cirq.google.api.v2.InternalGate.GateArgsEntry\x12\x45\n\x0b\x63ustom_args\x18\x05 \x03(\x0b\x32\x30.cirq.google.api.v2.InternalGate.CustomArgsEntry\x1aH\n\rGateArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.cirq.google.api.v2.Arg:\x02\x38\x01\x1aP\n\x0f\x43ustomArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.cirq.google.api.v2.CustomArg:\x02\x38\x01\"\xd8\x03\n\x10\x43ouplerPulseGate\x12\x37\n\x0chold_time_ps\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x00\x88\x01\x01\x12\x37\n\x0crise_time_ps\x18\x02 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x01\x88\x01\x01\x12:\n\x0fpadding_time_ps\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x02\x88\x01\x01\x12\x37\n\x0c\x63oupling_mhz\x18\x04 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x03\x88\x01\x01\x12\x38\n\rq0_detune_mhz\x18\x05 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x04\x88\x01\x01\x12\x38\n\rq1_detune_mhz\x18\x06 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgH\x05\x88\x01\x01\x42\x0f\n\r_hold_time_psB\x0f\n\r_rise_time_psB\x12\n\x10_padding_time_psB\x0f\n\r_coupling_mhzB\x10\n\x0e_q0_detune_mhzB\x10\n\x0e_q1_detune_mhz\"\x8b\x01\n\x0f\x43liffordTableau\x12\x17\n\nnum_qubits\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rinitial_state\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\n\n\x02rs\x18\x03 \x03(\x08\x12\n\n\x02xs\x18\x04 \x03(\x08\x12\n\n\x02zs\x18\x05 \x03(\x08\x42\r\n\x0b_num_qubitsB\x10\n\x0e_initial_state\"O\n\x17SingleQubitCliffordGate\x12\x34\n\x07tableau\x18\x01 \x01(\x0b\x32#.cirq.google.api.v2.CliffordTableau\"!\n\x0cIdentityGate\x12\x11\n\tqid_shape\x18\x01 \x03(\r\":\n\x08HPowGate\x12.\n\x08\x65xponent\x18\x01 \x01(\x0b\x32\x1c.cirq.google.api.v2.FloatArgB/\n\x1d\x63om.google.cirq.google.api.v2B\x0cProgramProtoP\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -30,8 +30,14 @@ _OPERATION.fields_by_name['gate']._serialized_options = b'\030\001' _OPERATION.fields_by_name['args']._options = None _OPERATION.fields_by_name['args']._serialized_options = b'\030\001' + _FUNCTIONINTERPOLATION.fields_by_name['x_values']._options = None + _FUNCTIONINTERPOLATION.fields_by_name['x_values']._serialized_options = b'\020\001' + _FUNCTIONINTERPOLATION.fields_by_name['y_values']._options = None + _FUNCTIONINTERPOLATION.fields_by_name['y_values']._serialized_options = b'\020\001' _INTERNALGATE_GATEARGSENTRY._options = None _INTERNALGATE_GATEARGSENTRY._serialized_options = b'8\001' + _INTERNALGATE_CUSTOMARGSENTRY._options = None + _INTERNALGATE_CUSTOMARGSENTRY._serialized_options = b'8\001' _globals['_PROGRAM']._serialized_start=84 _globals['_PROGRAM']._serialized_end=299 _globals['_CONSTANT']._serialized_start=302 @@ -116,18 +122,24 @@ _globals['_ARGMAPPING']._serialized_end=5338 _globals['_ARGMAPPING_ARGENTRY']._serialized_start=5250 _globals['_ARGMAPPING_ARGENTRY']._serialized_end=5338 - _globals['_INTERNALGATE']._serialized_start=5341 - _globals['_INTERNALGATE']._serialized_end=5546 - _globals['_INTERNALGATE_GATEARGSENTRY']._serialized_start=5474 - _globals['_INTERNALGATE_GATEARGSENTRY']._serialized_end=5546 - _globals['_COUPLERPULSEGATE']._serialized_start=5549 - _globals['_COUPLERPULSEGATE']._serialized_end=6021 - _globals['_CLIFFORDTABLEAU']._serialized_start=6024 - _globals['_CLIFFORDTABLEAU']._serialized_end=6163 - _globals['_SINGLEQUBITCLIFFORDGATE']._serialized_start=6165 - _globals['_SINGLEQUBITCLIFFORDGATE']._serialized_end=6244 - _globals['_IDENTITYGATE']._serialized_start=6246 - _globals['_IDENTITYGATE']._serialized_end=6279 - _globals['_HPOWGATE']._serialized_start=6281 - _globals['_HPOWGATE']._serialized_end=6339 + _globals['_FUNCTIONINTERPOLATION']._serialized_start=5340 + _globals['_FUNCTIONINTERPOLATION']._serialized_end=5407 + _globals['_CUSTOMARG']._serialized_start=5409 + _globals['_CUSTOMARG']._serialized_end=5516 + _globals['_INTERNALGATE']._serialized_start=5519 + _globals['_INTERNALGATE']._serialized_end=5877 + _globals['_INTERNALGATE_GATEARGSENTRY']._serialized_start=5723 + _globals['_INTERNALGATE_GATEARGSENTRY']._serialized_end=5795 + _globals['_INTERNALGATE_CUSTOMARGSENTRY']._serialized_start=5797 + _globals['_INTERNALGATE_CUSTOMARGSENTRY']._serialized_end=5877 + _globals['_COUPLERPULSEGATE']._serialized_start=5880 + _globals['_COUPLERPULSEGATE']._serialized_end=6352 + _globals['_CLIFFORDTABLEAU']._serialized_start=6355 + _globals['_CLIFFORDTABLEAU']._serialized_end=6494 + _globals['_SINGLEQUBITCLIFFORDGATE']._serialized_start=6496 + _globals['_SINGLEQUBITCLIFFORDGATE']._serialized_end=6575 + _globals['_IDENTITYGATE']._serialized_start=6577 + _globals['_IDENTITYGATE']._serialized_end=6610 + _globals['_HPOWGATE']._serialized_start=6612 + _globals['_HPOWGATE']._serialized_end=6670 # @@protoc_insertion_point(module_scope) diff --git a/cirq-google/cirq_google/api/v2/program_pb2.pyi b/cirq-google/cirq_google/api/v2/program_pb2.pyi index 87067a23a21..5f064aee71b 100644 --- a/cirq-google/cirq_google/api/v2/program_pb2.pyi +++ b/cirq-google/cirq_google/api/v2/program_pb2.pyi @@ -1182,6 +1182,51 @@ class ArgMapping(google.protobuf.message.Message): global___ArgMapping = ArgMapping +@typing.final +class FunctionInterpolation(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + X_VALUES_FIELD_NUMBER: builtins.int + Y_VALUES_FIELD_NUMBER: builtins.int + @property + def x_values(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: + """The x_values must be sorted in ascending order. + The x_values and y_values must be of the same length. + The independent variable. + """ + + @property + def y_values(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: + """The dependent variable.""" + + def __init__( + self, + *, + x_values: collections.abc.Iterable[builtins.float] | None = ..., + y_values: collections.abc.Iterable[builtins.float] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["x_values", b"x_values", "y_values", b"y_values"]) -> None: ... + +global___FunctionInterpolation = FunctionInterpolation + +@typing.final +class CustomArg(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + FUNCTION_INTERPOLATION_DATA_FIELD_NUMBER: builtins.int + @property + def function_interpolation_data(self) -> global___FunctionInterpolation: ... + def __init__( + self, + *, + function_interpolation_data: global___FunctionInterpolation | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["custom_arg", b"custom_arg", "function_interpolation_data", b"function_interpolation_data"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["custom_arg", b"custom_arg", "function_interpolation_data", b"function_interpolation_data"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["custom_arg", b"custom_arg"]) -> typing.Literal["function_interpolation_data"] | None: ... + +global___CustomArg = CustomArg + @typing.final class InternalGate(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -1204,10 +1249,29 @@ class InternalGate(google.protobuf.message.Message): def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + @typing.final + class CustomArgsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + @property + def value(self) -> global___CustomArg: ... + def __init__( + self, + *, + key: builtins.str = ..., + value: global___CustomArg | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + NAME_FIELD_NUMBER: builtins.int MODULE_FIELD_NUMBER: builtins.int NUM_QUBITS_FIELD_NUMBER: builtins.int GATE_ARGS_FIELD_NUMBER: builtins.int + CUSTOM_ARGS_FIELD_NUMBER: builtins.int name: builtins.str """Gate name.""" module: builtins.str @@ -1218,6 +1282,13 @@ class InternalGate(google.protobuf.message.Message): def gate_args(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___Arg]: """Gate args.""" + @property + def custom_args(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___CustomArg]: + """Custom args are arguments that require special processing during deserialization. + The `key` is the argument in the internal class's constructor, the `value` + is a representation from which an internal object can be constructed. + """ + def __init__( self, *, @@ -1225,8 +1296,9 @@ class InternalGate(google.protobuf.message.Message): module: builtins.str = ..., num_qubits: builtins.int = ..., gate_args: collections.abc.Mapping[builtins.str, global___Arg] | None = ..., + custom_args: collections.abc.Mapping[builtins.str, global___CustomArg] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["gate_args", b"gate_args", "module", b"module", "name", b"name", "num_qubits", b"num_qubits"]) -> None: ... + def ClearField(self, field_name: typing.Literal["custom_args", b"custom_args", "gate_args", b"gate_args", "module", b"module", "name", b"name", "num_qubits", b"num_qubits"]) -> None: ... global___InternalGate = InternalGate diff --git a/cirq-google/cirq_google/ops/internal_gate.py b/cirq-google/cirq_google/ops/internal_gate.py index 258ae2dda9d..021d5090045 100644 --- a/cirq-google/cirq_google/ops/internal_gate.py +++ b/cirq-google/cirq_google/ops/internal_gate.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Mapping + +import numpy as np + from cirq import ops, value +from cirq_google.api.v2 import program_pb2 @value.value_equality @@ -25,39 +30,57 @@ class InternalGate(ops.Gate): constructor stored in `self.gate_args`. """ - def __init__(self, gate_name: str, gate_module: str, num_qubits: int = 1, **kwargs): + def __init__( + self, + gate_name: str, + gate_module: str, + num_qubits: int = 1, + custom_args: Optional[Mapping[str, program_pb2.CustomArg]] = None, + **kwargs, + ): """Instatiates an InternalGate. Arguments: gate_name: Gate class name. gate_module: The module of the gate. num_qubits: Number of qubits that the gate acts on. + custom_args: A mapping from argument name to `CustomArg`. + This is to support argument that require special processing. **kwargs: The named arguments to be passed to the gate constructor. """ self.gate_module = gate_module self.gate_name = gate_name self._num_qubits = num_qubits self.gate_args = kwargs + self.custom_args = custom_args or {} def _num_qubits_(self) -> int: return self._num_qubits def __str__(self): - gate_args = ', '.join(f'{k}={v}' for k, v in self.gate_args.items()) + gate_args = ', '.join(f'{k}={v}' for k, v in (self.gate_args | self.custom_args).items()) return f'{self.gate_module}.{self.gate_name}({gate_args})' def __repr__(self) -> str: gate_args = ', '.join(f'{k}={repr(v)}' for k, v in self.gate_args.items()) if gate_args != '': gate_args = ', ' + gate_args + + custom_args = '' + if self.custom_args: + custom_args = f", custom_args={self.custom_args}" + return ( f"cirq_google.InternalGate(gate_name='{self.gate_name}', " f"gate_module='{self.gate_module}', " f"num_qubits={self._num_qubits}" + f"{custom_args}" f"{gate_args})" ) def _json_dict_(self) -> Dict[str, Any]: + if self.custom_args: + raise ValueError('InternalGate with custom args are not json serializable') return dict( gate_name=self.gate_name, gate_module=self.gate_module, @@ -78,3 +101,51 @@ def _value_equality_values_(self): self._num_qubits, frozenset(self.gate_args.items()) if hashable else self.gate_args, ) + + +def function_points_to_proto( + x: Union[Sequence[float], np.ndarray], + y: Union[Sequence[float], np.ndarray], + msg: Optional[program_pb2.CustomArg] = None, +) -> program_pb2.CustomArg: + """Return CustomArg that expresses a function through its x and y values. + + Args: + x: Sequence of values of the free variable. + For 1D functions, this input is assumed to be given in increasing order. + y: Sequence of values of the dependent variable. + Where y[i] = func(x[i]) where `func` is the function being encoded. + msg: Optional CustomArg to serialize to. + If not provided a CustomArg is created. + + Returns: + A CustomArg encoding the function. + + Raises: + ValueError: If + - `x` is 1D and not sorted in increasing order. + - `x` and `y` don't have the same number of points. + - `y` is multidimensional. + - `x` is multidimensional. + """ + + x = np.asarray(x) + y = np.asarray(y) + + if len(x.shape) != 1: + raise ValueError('The free variable must be one dimensional') + + if len(x.shape) == 1 and not np.all(np.diff(x) > 0): + raise ValueError('The free variable must be sorted in increasing order') + + if len(y.shape) != 1: + raise ValueError('The dependent variable must be one dimensional') + + if x.shape[0] != y.shape[0]: + raise ValueError('Mismatch between number of points in x and y') + + if msg is None: + msg = program_pb2.CustomArg() + msg.function_interpolation_data.x_values[:] = x + msg.function_interpolation_data.y_values[:] = y + return msg diff --git a/cirq-google/cirq_google/ops/internal_gate_test.py b/cirq-google/cirq_google/ops/internal_gate_test.py index ce14ae294a3..0768a70aeee 100644 --- a/cirq-google/cirq_google/ops/internal_gate_test.py +++ b/cirq-google/cirq_google/ops/internal_gate_test.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + import cirq import cirq_google import pytest +from cirq_google.ops import internal_gate +from cirq_google.serialization import arg_func_langs def test_internal_gate(): @@ -67,3 +71,66 @@ def test_internal_gate_with_hashable_args_is_hashable(): ) with pytest.raises(TypeError, match="unhashable"): _ = hash(unhashable) + + +def test_internal_gate_with_custom_function_repr(): + x = np.linspace(-1, 1, 10) + y = x**2 + encoded_func = internal_gate.function_points_to_proto(x=x, y=y) + + gate = internal_gate.InternalGate( + gate_name='GateWithFunction', + gate_module='test', + num_qubits=2, + custom_args={'func': encoded_func}, + ) + + assert repr(gate) == ( + "cirq_google.InternalGate(gate_name='GateWithFunction', " + f"gate_module='test', num_qubits=2, custom_args={gate.custom_args})" + ) + + assert str(gate) == (f"test.GateWithFunction(func={encoded_func})") + + with pytest.raises(ValueError): + _ = cirq.to_json(gate) + + +def test_internal_gate_with_custom_function_round_trip(): + original_func = lambda x: x**2 + x = np.linspace(-1, 1, 10) + y = original_func(x) + encoded_func = internal_gate.function_points_to_proto(x=x, y=y) + + gate = internal_gate.InternalGate( + gate_name='GateWithFunction', + gate_module='test', + num_qubits=2, + custom_args={'func': encoded_func}, + ) + + msg = arg_func_langs.internal_gate_arg_to_proto(gate) + + new_gate = arg_func_langs.internal_gate_from_proto(msg, arg_func_langs.MOST_PERMISSIVE_LANGUAGE) + + func_proto = new_gate.custom_args['func'].function_interpolation_data + + np.testing.assert_allclose(x, func_proto.x_values) + np.testing.assert_allclose(y, func_proto.y_values) + + +def test_function_points_to_proto_invalid_args_raise(): + x = np.linspace(-1, 1, 10) + y = x + 1 + + with pytest.raises(ValueError, match='The free variable must be one dimensional'): + _ = internal_gate.function_points_to_proto(np.zeros((10, 2)), y) + + with pytest.raises(ValueError, match='sorted in increasing order'): + _ = internal_gate.function_points_to_proto(x[::-1], y) + + with pytest.raises(ValueError, match='Mismatch between number of points in x and y'): + _ = internal_gate.function_points_to_proto(x, np.linspace(-1, 1, 40)) + + with pytest.raises(ValueError, match='The dependent variable must be one dimensional'): + _ = internal_gate.function_points_to_proto(x, np.zeros((10, 2))) diff --git a/cirq-google/cirq_google/serialization/arg_func_langs.py b/cirq-google/cirq_google/serialization/arg_func_langs.py index cc386553124..56cee8cab70 100644 --- a/cirq-google/cirq_google/serialization/arg_func_langs.py +++ b/cirq-google/cirq_google/serialization/arg_func_langs.py @@ -432,9 +432,13 @@ def internal_gate_arg_to_proto( msg.name = value.gate_name msg.module = value.gate_module msg.num_qubits = value.num_qubits() + for k, v in value.gate_args.items(): arg_to_proto(value=v, out=msg.gate_args[k]) + for ck, cv in value.custom_args.items(): + msg.custom_args[ck].MergeFrom(cv) + return msg @@ -461,6 +465,7 @@ def internal_gate_from_proto( gate_name=str(msg.name), gate_module=str(msg.module), num_qubits=int(msg.num_qubits), + custom_args=msg.custom_args, **gate_args, )