Skip to content

Commit

Permalink
Fix stochastic rounding and clipping
Browse files Browse the repository at this point in the history
  • Loading branch information
Viren6 committed Apr 7, 2024
1 parent 7f9ce2d commit 7f0bfb0
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions server/fishtest/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1505,8 +1505,8 @@ def purge_run(self, run, p=0.001, res=7.0, iters=1):
self.buffer(run, True)
return message

def spsa_param_clip(self, param, increment):
return min(max(param["theta"] + increment, param["min"]), param["max"])
def spsa_param_clip(self, param, increment, r):
return min(max(param["theta"] + increment + r, param["min"]), param["max"])

# Store SPSA parameters for each worker
spsa_params = {}
Expand Down Expand Up @@ -1556,15 +1556,15 @@ def generate_spsa(self, run):
r = random.uniform(0, 1)
flip = 1 if random.getrandbits(1) else -1
# Stochastic rounding and probability for float N.p: (N, 1-p); (N+1, p)
w_value = math.floor(self.spsa_param_clip(param, c * flip) + r)
b_value = math.floor(self.spsa_param_clip(param, -c * flip) + r)
w_value = math.floor(self.spsa_param_clip(param, c * flip, r))
b_value = math.floor(self.spsa_param_clip(param, -c * flip, r))
result["w_params"].append(
{
"name": param["name"],
"value": w_value,
"R": param["a"] / (spsa["A"] + iter_local) ** spsa["alpha"] / c**2,
#Set c to the real delta after stochastic rounding is applied
"c": abs(w_value - b_value),
# Set c to the real delta after stochastic rounding is applied
"c": abs(w_value - b_value) / 2,
"flip": flip,
}
)
Expand Down Expand Up @@ -1600,7 +1600,7 @@ def update_spsa(self, worker, run, spsa_results):
R = w_params[idx]["R"]
c = w_params[idx]["c"]
flip = w_params[idx]["flip"]
param["theta"] = self.spsa_param_clip(param, R * c * result * flip)
param["theta"] = self.spsa_param_clip(param, R * c * result * flip, 0)
if grow_summary:
summary.append({"theta": param["theta"], "R": R, "c": c})

Expand Down

0 comments on commit 7f0bfb0

Please sign in to comment.