Commit 05a0d90: small clean up

Farquhar13 committed Feb 23, 2024
1 parent 765c603 commit 05a0d90
Showing 2 changed files with 5 additions and 5 deletions.

scripts/run_and_save_v3.py (5 changes: 2 additions & 3 deletions)
@@ -2,7 +2,6 @@

 import ray
 from ray.rllib.algorithms.ddpg import DDPGConfig
-from ray.tune.registry import register_env
 from relaqs.environments.single_qubit_env import SingleQubitEnv
 from relaqs.environments.noisy_single_qubit_env import NoisySingleQubitEnv
 from relaqs.save_results import SaveResults
@@ -41,7 +40,6 @@ def run(env_class=SingleQubitEnv, n_training_iterations=1, save=True, plot=True)

     # ---------------------> Train Agent <-------------------------
     results = [alg.train() for _ in range(n_training_iterations)]
-    result = results[-1]
     # ---------------------> Save Results <-------------------------
     if save is True:
         env = alg.workers.local_worker().env
@@ -53,13 +51,14 @@ def run(env_class=SingleQubitEnv, n_training_iterations=1, save=True, plot=True)

     # ---------------------> Plot Data <-------------------------
     if plot is True:
         assert save is True, "If plot=True, then save must also be set to True"
+        print("episode length", alg._episode_history[0].episode_length)
         plot_data(save_dir, episode_length=alg._episode_history[0].episode_length)
         print("Plots Created")
     # --------------------------------------------------------------

 if __name__ == "__main__":
     env_class = NoisySingleQubitEnv
-    n_training_iterations = 50
+    n_training_iterations = 1
     save = True
     plot = True
     run(env_class, n_training_iterations, save, plot)

src/relaqs/environments/single_qubit_env.py (5 changes: 3 additions & 2 deletions)
@@ -73,6 +73,7 @@ def hamiltonian(self, delta, alpha, gamma_magnitude, gamma_phase):
         return (delta + alpha) * Z + gamma_magnitude * (np.cos(gamma_phase) * X + np.sin(gamma_phase) * Y)

     def reset(self, *, seed=None, options=None):
+        print("resetting")
         self.U = self.U_initial.copy()
         starting_observeration = self.get_observation()
         self.state = self.get_observation()
@@ -84,6 +85,7 @@ def reset(self, *, seed=None, options=None):
         self.prev_fidelity = 0
         info = {}
         self.episode_id += 1
+        print("episode id: ", self.episode_id)
         return starting_observeration, info

     def hamiltonian_update(self, alpha, gamma_magnitude, gamma_phase):
@@ -108,8 +110,6 @@ def is_episode_over(self, fidelity):
             truncated = True  # truncated when target fidelity reached
         elif (self.current_Haar_num >= self.num_Haar_basis) and (self.current_step_per_Haar >= self.steps_per_Haar):  # terminate when all Haar bases have been tested
             terminated = True
-        else:
-            terminated = False
         return truncated, terminated

     def Haar_update(self):
@@ -167,6 +167,7 @@ def step(self, action):
         self.Haar_update()

         truncated, terminated = self.is_episode_over(fidelity)
+        print("truncated: ", truncated, "terminated: ", terminated)

         info = {}
         return (self.state, reward, terminated, truncated, info)
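
Note on the is_episode_over hunk above: dropping the "else: terminated = False" branch only preserves behavior if both flags are initialized before the conditional. That initialization sits outside the visible hunk, so the following is a minimal sketch of the assumed resulting method; the fidelity condition in the first branch is hypothetical, since the diff does not show it:

def is_episode_over(self, fidelity):
    # Assumed initialization; not visible in the hunk above.
    truncated = False
    terminated = False
    if fidelity >= self.fidelity_target:  # hypothetical attribute name
        truncated = True  # truncated when target fidelity reached
    elif (self.current_Haar_num >= self.num_Haar_basis) and (self.current_step_per_Haar >= self.steps_per_Haar):
        terminated = True  # terminated when all Haar bases have been tested
    return truncated, terminated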

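For context, the (self.state, reward, terminated, truncated, info) return in step and the (observation, info) pair returned by reset follow the Gymnasium 5-tuple API. A minimal rollout sketch against this interface, assuming the environment can be constructed with default arguments and exposes a standard action_space (neither is confirmed by the diff):

from relaqs.environments.single_qubit_env import SingleQubitEnv

env = SingleQubitEnv()  # assumption: default construction is sufficient
obs, info = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # random actions, for illustration only
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated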