Skip to content

Commit

Permalink
add reset + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
splch committed Aug 21, 2023
1 parent 0795c3e commit 6668f95
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
8 changes: 7 additions & 1 deletion qubit_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class QubitSimulator:
"""
A class that represents a quantum simulator.
A class that represents a qubit simulator.
"""

def __init__(self, num_qubits: int):
Expand Down Expand Up @@ -199,6 +199,12 @@ def run(
results = self.measure(shots, basis)
return dict(Counter(results))

def reset(self):
"""
Resets the simulator to its initial state.
"""
self.__init__(self.num_qubits)

def __str__(self) -> str:
"""
Returns a string representation of the circuit.
Expand Down
55 changes: 54 additions & 1 deletion tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,43 @@ def test_initial_state():
assert np.allclose(simulator.state_vector, [1, 0, 0, 0, 0, 0, 0, 0])


def test_initialization_complex_states():
simulator = QubitSimulator(2)
simulator.state_vector = [0.5, 0.5, 0.5, 0.5]
simulator.x(0)
assert np.allclose(simulator.state_vector, [0.5, 0.5, 0.5, 0.5])


def test_large_number_of_qubits():
num_qubits = 20
simulator = QubitSimulator(num_qubits)
assert len(simulator.state_vector) == 2**num_qubits


def test_circuit_reset():
simulator = QubitSimulator(1)
simulator.x(0)
simulator.reset()
assert np.allclose(simulator.state_vector, [1, 0])


def test_getsize():
simulator = QubitSimulator(3)
# Apply some gates to make the instance more complex
simulator.h(0)
simulator.u(1, np.pi / 4, np.pi / 4, np.pi / 2)
simulator.cx(1, 2)
simulator.u(0, np.pi / 4, np.pi / 4, np.pi / 2)
assert simulator.__getsize__() == 412


def test_getsize_relative():
simulator = QubitSimulator(2)
initial_size = simulator.__getsize__()
simulator.h(0)
simulator.cx(0, 1)
assert simulator.__getsize__() > initial_size


def test_zero_qubits():
simulator = QubitSimulator(0)
assert len(simulator.state_vector) == 1
Expand Down Expand Up @@ -62,6 +90,16 @@ def test_u_gate():
assert np.allclose(simulator.state_vector, expected_result)


@pytest.mark.parametrize(
"theta,phi,lambda_", [(0, 0, 0), (2 * np.pi, 2 * np.pi, 2 * np.pi)]
)
def test_u_gate_edge_cases(theta, phi, lambda_):
simulator = QubitSimulator(1)
simulator.u(0, theta, phi, lambda_)
# State vector should be |0⟩
assert np.allclose(simulator.state_vector, [1, 0])


def test_cx_gate():
simulator = QubitSimulator(2)
simulator.state_vector = [0, 0, 0, 1] # Set the initial state to |11⟩
Expand Down Expand Up @@ -127,6 +165,13 @@ def test_negative_shots(shots):
simulator.run(shots=shots) # Negative shots are invalid


def test_error_messages():
with pytest.raises(ValueError, match="Number of qubits must be non-negative."):
QubitSimulator(-1)
with pytest.raises(ValueError, match="Number of shots must be non-negative."):
QubitSimulator(1).measure(-1)


def test_measure_without_gates():
simulator = QubitSimulator(2)
results = simulator.run(shots=100)
Expand All @@ -144,6 +189,14 @@ def test_measure_custom_basis():
assert set(result) == {"0"}


def test_measure_custom_basis_valid():
simulator = QubitSimulator(1)
Z_basis = np.array([[1, 0], [0, -1]])
simulator.x(0)
result = simulator.measure(basis=Z_basis)
assert result == ["1"]


def test_invalid_basis_transformation():
simulator = QubitSimulator(1)
# Define an invalid basis transformation (not unitary)
Expand Down

0 comments on commit 6668f95

Please sign in to comment.