Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ClExpr #176

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
Changelog
~~~~~~~~~

Unreleased
----------

* Updated pytket version requirement to 1.34.
* Now supporting ``ClExpr`` operations (the new version of tket's ``ClassicalExpBox``).

0.10.0 (October 2024)
---------------------

Expand Down
126 changes: 126 additions & 0 deletions pytket/extensions/cutensornet/structured_state/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
SetBitsOp,
CopyBitsOp,
RangePredicateOp,
ClExprOp,
ClassicalExpBox,
LogicExp,
BitWiseOp,
RegWiseOp,
)
from pytket._tket.circuit import ClExpr, ClOp, ClBitVar, ClRegVar


ExtendedLogicExp = Union[LogicExp, Bit, BitRegister, int]
Expand Down Expand Up @@ -56,6 +58,40 @@ def apply_classical_command(
# Check that the value is in the range
bits_dict[res_bit] = val >= op.lower and val <= op.upper

elif isinstance(op, ClExprOp):
# Convert bit_posn to dictionary of `ClBitVar` index to its value
bitvar_val = {
var_id: int(bits_dict[args[bit_pos]])
for var_id, bit_pos in op.expr.bit_posn.items()
}
# Convert reg_posn to dictionary of `ClRegVar` index to its value
regvar_val = {
var_id: from_little_endian(
[bits_dict[args[bit_pos]] for bit_pos in reg_pos_list]
)
for var_id, reg_pos_list in op.expr.reg_posn.items()
}
# Identify number of bits on each register
regvar_size = {
var_id: len(reg_pos_list)
for var_id, reg_pos_list in op.expr.reg_posn.items()
}
# Identify number of bits in output register
output_size = len(op.expr.output_posn)
result = evaluate_clexpr(
op.expr.expr, bitvar_val, regvar_val, regvar_size, output_size
)

# The result is an int in little-endian encoding. We update the
# output register accordingly.
for bit_pos in op.expr.output_posn:
bits_dict[args[bit_pos]] = (result % 2) == 1
result = result >> 1
# If there has been overflow in the operations, error out.
# This can be detected if `result != 0`
if result != 0:
raise ValueError("Evaluation of the ClExpr resulted in overflow.")

elif isinstance(op, ClassicalExpBox):
the_exp = op.get_exp()
result = evaluate_logic_exp(the_exp, bits_dict)
Expand All @@ -74,6 +110,95 @@ def apply_classical_command(
raise NotImplementedError(f"Commands of type {op.type} are not supported.")


def evaluate_clexpr(
expr: ClExpr,
bitvar_val: dict[int, int],
regvar_val: dict[int, int],
regvar_size: dict[int, int],
output_size: int,
) -> int:
"""Recursive evaluation of a ClExpr."""

# Evaluate arguments to operation
args_val = []
for arg in expr.args:
if isinstance(arg, int):
value = arg
elif isinstance(arg, ClBitVar):
value = bitvar_val[arg.index]
elif isinstance(arg, ClRegVar):
value = regvar_val[arg.index]
elif isinstance(arg, ClExpr):
value = evaluate_clexpr(
arg, bitvar_val, regvar_val, regvar_size, output_size
)
else:
raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.")

args_val.append(value)

# Apply the operation at the root of this ClExpr
if expr.op in [ClOp.BitAnd, ClOp.RegAnd]:
result = args_val[0] & args_val[1]
elif expr.op in [ClOp.BitOr, ClOp.RegOr]:
result = args_val[0] | args_val[1]
elif expr.op in [ClOp.BitXor, ClOp.RegXor]:
result = args_val[0] ^ args_val[1]
elif expr.op in [ClOp.BitEq, ClOp.RegEq]:
result = int(args_val[0] == args_val[1])
elif expr.op in [ClOp.BitNeq, ClOp.RegNeq]:
result = int(args_val[0] != args_val[1])
elif expr.op == ClOp.RegGeq:
result = int(args_val[0] >= args_val[1])
elif expr.op == ClOp.RegGt:
result = int(args_val[0] > args_val[1])
elif expr.op == ClOp.RegLeq:
result = int(args_val[0] <= args_val[1])
elif expr.op == ClOp.RegLt:
result = int(args_val[0] < args_val[1])
elif expr.op == ClOp.BitNot:
result = 1 - args_val[0]
elif expr.op == ClOp.RegNot: # Bit-wise NOT (flip all bits)
n_bits = regvar_size[expr.args[0].index] # type: ignore
result = (2**n_bits - 1) ^ args_val[0] # XOR with all 1s bitstring
elif expr.op in [ClOp.BitZero, ClOp.RegZero]:
result = 0
elif expr.op == ClOp.BitOne:
result = 1
elif expr.op == ClOp.RegOne: # All 1s bitstring
n_bits = output_size
result = 2**n_bits - 1
elif expr.op == ClOp.RegAdd:
result = args_val[0] + args_val[1]
elif expr.op == ClOp.RegSub:
if args_val[0] < args_val[1]:
raise NotImplementedError(
"Currently not supporting ClOp.RegSub where the outcome is negative."
)
result = args_val[0] - args_val[1]
elif expr.op == ClOp.RegMul:
result = args_val[0] * args_val[1]
elif expr.op == ClOp.RegDiv: # floor(a / b)
result = args_val[0] // args_val[1]
elif expr.op == ClOp.RegPow:
result = int(args_val[0] ** args_val[1])
elif expr.op == ClOp.RegLsh:
result = args_val[0] << args_val[1]
elif expr.op == ClOp.RegRsh:
result = args_val[0] >> args_val[1]
# elif expr.op == ClOp.RegNeg:
# result = -args_val[0]
else:
# TODO: Not supporting RegNeg because I do not know if we have agreed how to
# specify signed ints.
raise NotImplementedError(
f"Evaluation of {expr.op} not supported in ClExpr ",
"by pytket-cutensornet.",
)

return result


def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int:
"""Recursive evaluation of a LogicExp."""

Expand Down Expand Up @@ -132,4 +257,5 @@ def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int
def from_little_endian(bitstring: list[bool]) -> int:
"""Obtain the integer from the little-endian encoded bitstring (i.e. bitstring
[False, True] is interpreted as the integer 2)."""
# TODO: Assumes unisigned integer. What are the specs for signed integers?
return sum(1 << i for i, b in enumerate(bitstring) if b)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
license="Apache 2",
packages=find_namespace_packages(include=["pytket.*"]),
include_package_data=True,
install_requires=["pytket >= 1.33.0", "networkx >= 2.8.8"],
install_requires=["pytket >= 1.34.0", "networkx >= 2.8.8"],
classifiers=[
"Environment :: Console",
"Programming Language :: Python :: 3.10",
Expand Down
Loading
Loading