-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathdynamics.py
1397 lines (1241 loc) · 49.6 KB
/
dynamics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
tensorflow/dynamics.py
Tensorflow implementation of Dynamics object for training L2HMC sampler.
"""
from __future__ import absolute_import, annotations, division, print_function
from dataclasses import dataclass
from math import pi
import os
import re
from pathlib import Path
from typing import Any, Callable, Optional
from typing import Tuple
import logging
# from copy import deepcopy
import numpy as np
import tensorflow as tf
from tensorflow.python.framework.ops import IndexedSlices
from l2hmc import configs as cfgs
from l2hmc.network.tensorflow.network import dummy_network, NetworkFactory
from l2hmc.group.u1.tensorflow.group import U1Phase
from l2hmc.group.su3.tensorflow.group import SU3
from l2hmc.lattice.u1.tensorflow.lattice import LatticeU1
from l2hmc.lattice.su3.tensorflow.lattice import LatticeSU3
# Short aliases for frequently used TensorFlow / Keras types.
Tensor = tf.Tensor
Model = tf.keras.Model
TensorLike = tf.types.experimental.TensorLike
# Floating-point dtype taken from the Keras backend setting
# (``tf.keras.backend.floatx()``); the constants below follow it.
TF_FLOAT = tf.dtypes.as_dtype(tf.keras.backend.floatx())
PI = tf.constant(pi, dtype=TF_FLOAT)
TWO = tf.constant(2., dtype=TF_FLOAT)
TWO_PI = TWO * PI
# Module-level logger.
log = logging.getLogger(__name__)
# Type alias for a ``(TensorLike, dict)`` pair — presumably (output,
# metrics), matching what the transition methods return; confirm at callers.
DynamicsOutput = Tuple[TensorLike, dict]
def to_u1(x: Tensor) -> Tensor:
    """Wrap angles ``x`` into the half-open interval [-pi, pi).

    Shift by pi, reduce modulo 2*pi (result in [0, 2*pi)), shift back.
    """
    shifted = tf.add(x, PI)
    wrapped = shifted % TWO_PI
    return wrapped - PI
@dataclass
class State:
    """A sampler state: position ``x``, momentum ``v`` and ``beta``.

    All three fields must be ``tf.Tensor`` instances (checked in
    ``__post_init__``).
    """
    x: Tensor
    v: Tensor
    beta: Tensor

    def flatten(self):
        """Return a new State with ``x`` and ``v`` reshaped to (batch, -1).

        ``beta`` is cast to the dtype of the flattened ``x``. ``tf.cast`` is
        used on its own: the previous ``tf.constant(tf.cast(...))`` only
        accepts concrete (eager) values and raises a TypeError on symbolic
        tensors inside a ``tf.function``.
        """
        x = tf.reshape(self.x, (self.x.shape[0], -1))
        v = tf.reshape(self.v, (self.v.shape[0], -1))
        beta = tf.cast(self.beta, x.dtype)
        return State(x, v, beta)

    def __post_init__(self):
        # Reject non-Tensor fields early (e.g. Python floats, NumPy arrays).
        assert isinstance(self.x, Tensor)
        assert isinstance(self.v, Tensor)
        assert isinstance(self.beta, Tensor)

    def to_numpy(self):
        """Return the fields as a dict keyed by 'x', 'v' and 'beta'.

        Values are NumPy arrays when executing eagerly; otherwise the
        tensors themselves are returned unchanged.
        """
        if tf.executing_eagerly():
            return {
                'x': self.x.numpy(),  # type:ignore
                'v': self.v.numpy(),  # type:ignore
                'beta': self.beta.numpy(),  # type:ignore
            }
        return {
            'x': self.x,
            'v': self.v,
            'beta': self.beta,
        }
@dataclass
class MonteCarloStates:
    """States involved in one Monte Carlo update."""
    init: State  # state the trajectory started from
    proposed: State  # state produced by the transition kernel
    out: State  # state kept per chain after the accept/reject step
@dataclass
class MonteCarloProposal:
    """Initial and proposed states, without an accept/reject outcome.

    NOTE(review): not used in the visible part of this file — confirm it
    is still needed.
    """
    init: State  # state the trajectory started from
    proposed: State  # state produced by the transition kernel
def xy_repr(x: Tensor) -> Tensor:
    """Return ``[cos(x), sin(x)]`` stacked along a new trailing axis."""
    components = [tf.math.cos(x), tf.math.sin(x)]
    return tf.stack(components, axis=-1)
def sigmoid(x: Tensor | Any) -> Tensor:
    """Logistic function ``1 / (1 + exp(-x))``."""
    denominator = 1. + tf.exp(tf.negative(x))
    return 1. / denominator
class Dynamics(Model):
def __init__(
        self,
        potential_fn: Callable,
        config: cfgs.DynamicsConfig,
        network_factory: Optional[NetworkFactory] = None,
):
    """Initialization.

    Stores the config and potential, selects the group and lattice
    implementation, builds (or stubs out) the x/v networks, draws the
    per-step binary masks and creates the per-step step-size variables.

    Args:
        potential_fn: Potential function, stored as ``self.potential_fn``.
        config: Dynamics configuration (group, shapes, ``nleapfrog``,
            ``eps`` / ``eps_fixed`` and network options).
        network_factory: Optional factory used to build the networks. When
            ``None``, ``dummy_network`` is used for both ``xnet`` and
            ``vnet`` and ``_networks_built`` is False.

    Raises:
        ValueError: If ``config.group`` is not 'U1' or 'SU3'.
    """
    super(Dynamics, self).__init__()
    self.config = config
    self.group = config.group
    self.xdim = self.config.xdim
    self.xshape = self.config.xshape
    self.potential_fn = potential_fn
    self.nlf = self.config.nleapfrog
    # Pick the group object and matching lattice for the configured group.
    if self.config.group == 'U1':
        self.g = U1Phase()
        self.lattice = LatticeU1(self.config.nchains,
                                 self.config.latvolume)
    elif self.config.group == 'SU3':
        self.g = SU3()
        self.lattice = LatticeSU3(self.config.nchains,
                                  self.config.latvolume)
    else:
        raise ValueError('Unexpected value for `self.config.group`')
    assert isinstance(self.g, (U1Phase, SU3))
    self.network_factory = network_factory
    if network_factory is not None:
        self._networks_built = True
        # _build_networks reads self.nlf and self.g, which are set above.
        self.networks = self._build_networks(network_factory)
        self.xnet = self.networks['xnet']
        self.vnet = self.networks['vnet']
    else:
        # No factory: fall back to placeholder networks.
        self._networks_built = False
        self.xnet = dummy_network
        self.vnet = dummy_network
        self.networks = {
            'xnet': self.xnet,
            'vnet': self.vnet
        }
    self.masks = self._build_masks()
    # Per-leapfrog-step step sizes for the x- and v-updates. All start at
    # config.eps (no explicit dtype is passed, so it is inferred from that
    # value); they are trainable unless config.eps_fixed, and the non_neg
    # constraint projects them back to >= 0 after optimizer updates.
    self.xeps = []
    self.veps = []
    ekwargs = {
        'initial_value': self.config.eps,
        'trainable': (not self.config.eps_fixed),
        'constraint': tf.keras.constraints.non_neg(),
    }
    for lf in range(self.config.nleapfrog):
        xalpha = tf.Variable(name=f'xeps_lf{lf}', **ekwargs)
        valpha = tf.Variable(name=f'veps_lf{lf}', **ekwargs)
        self.xeps.append(xalpha)
        self.veps.append(valpha)
def get_models(self) -> dict:
    """Collect the networks into ``{'xnet': ..., 'vnet': ...}``.

    With separate networks, each entry is a dict keyed by leapfrog step.
    The x-network dict gets a '<step>/first' key for every step, plus
    '<step>/second' when split x-networks are enabled or '<step>'
    otherwise. With shared networks, the networks are returned directly;
    x-networks become a ``{'first', 'second'}`` dict when split.
    """
    cfg = self.config
    if not cfg.use_separate_networks:
        vnets = self._get_vnet(0)
        xnets = self._get_xnet(0, first=True)
        if cfg.use_split_xnets:
            xnets = {
                'first': xnets,
                'second': self._get_xnet(0, first=False),
            }
        return {'xnet': xnets, 'vnet': vnets}

    xnets = {}
    vnets = {}
    for step in range(cfg.nleapfrog):
        vnets[str(step)] = self._get_vnet(step)
        xnets[f'{step}/first'] = self._get_xnet(step, first=True)
        if cfg.use_split_xnets:
            xnets[f'{step}/second'] = self._get_xnet(step, first=False)
        else:
            xnets[str(step)] = self._get_xnet(step, first=True)
    return {'xnet': xnets, 'vnet': vnets}
def get_weights_dict(self) -> dict:
    """Return all network weights in one flat dict with 'model/'-prefixed keys.

    For each leapfrog step (or only step 0 when networks are shared), the
    v-network, the first x-network and, with split x-networks, the second
    x-network contribute their ``get_weights_dict()`` entries; later
    entries win on key collisions.

    Every branch merges into a fresh dict, so the dict returned by a
    network's ``get_weights_dict()`` is never mutated in place (the
    shared-network branch previously called ``.update`` on it directly).
    """
    steps = (
        range(self.config.nleapfrog) if self.config.use_separate_networks
        else range(1)
    )
    weights = {}
    for lf in steps:
        weights |= self._get_vnet(lf).get_weights_dict()
        weights |= self._get_xnet(lf, first=True).get_weights_dict()
        if self.config.use_split_xnets:
            weights |= self._get_xnet(lf, first=False).get_weights_dict()
    return {f'model/{k}': v for k, v in weights.items()}
def _build_networks(
self,
network_factory: NetworkFactory
) -> dict:
"""Build networks."""
split = self.config.use_split_xnets
n = self.nlf if self.config.use_separate_networks else 1
# return networks['xnet'], networks['vnet']
# return networks
return network_factory.build_networks(
n,
split,
group=self.g,
)
def call(
        self,
        inputs: tuple[Tensor, Tensor],
        training: bool = True
) -> tuple[Tensor, dict]:
    """Call Dynamics object.

    Args:
        inputs: Pair of inputs: (x, β) to use for generating new state x'.
        training (bool): Indicates training or evaluation of model

    Returns:
        Whatever the selected transition returns: ``apply_transition_fb``
        when ``config.merge_directions`` is set, else ``apply_transition``.
    """
    transition = (
        self.apply_transition_fb if self.config.merge_directions
        else self.apply_transition
    )
    return transition(inputs, training=training)
@staticmethod
def flatten(x: Tensor | IndexedSlices | Any) -> Tensor:
    """Reshape ``x`` to 2-D, keeping its leading (batch) dimension."""
    batch = x.shape[0]
    return tf.reshape(x, (batch, -1))
def random_state(self, beta: float = 1.) -> State:
    """Draw a random State with flattened ``x`` and ``v``.

    ``x`` comes from ``self.g.random`` and ``v`` from
    ``self.g.random_momentum`` (in that order), both at ``self.xshape``.
    ``beta`` becomes a ``TF_FLOAT`` constant.
    """
    shape = list(self.xshape)
    return State(
        x=self.flatten(self.g.random(shape)),
        v=self.flatten(self.g.random_momentum(shape)),
        beta=tf.constant(beta, dtype=TF_FLOAT),
    )
def test_reversibility(self) -> dict[str, Tensor]:
    """Test reversibility i.e. backward(forward(state)) = state

    Returns the element-wise absolute differences between the random
    starting state and the round-tripped state under 'dx' and 'dv'.
    """
    start = self.random_state(beta=1.)
    after_fwd, _ = self.transition_kernel(start, forward=True)
    roundtrip, _ = self.transition_kernel(after_fwd, forward=False)
    return {
        'dx': tf.abs(tf.subtract(start.x, roundtrip.x)),
        'dv': tf.abs(tf.subtract(start.v, roundtrip.v)),
    }
def apply_transition_hmc(
        self,
        inputs: tuple[Tensor, Tensor],
        eps: Optional[float] = None,
        nleapfrog: Optional[int] = None,
) -> tuple[Tensor, dict]:
    """Run a generic HMC proposal followed by a Metropolis accept/reject.

    Args:
        inputs: ``(x, beta)`` pair.
        eps: Step size; ``None`` defers to the kernel's default.
        nleapfrog: Number of leapfrog steps; ``None`` defers to the
            kernel's default.

    Returns:
        The flattened output ``x`` and the kernel's metrics dict, extended
        with 'acc_mask' and 'mc_states'.
    """
    data = self.generate_proposal_hmc(inputs, eps, nleapfrog=nleapfrog)
    # ``tf.cast`` instead of ``tf.constant``: the latter rejects symbolic
    # tensors inside ``tf.function``. The rejection mask is not needed.
    ma_, _ = self._get_accept_masks(data['metrics']['acc'])
    ma_ = tf.cast(ma_, TF_FLOAT)
    accept = tf.cast(ma_[:, None], bool)
    xinit = self.flatten(data['init'].x)
    vinit = self.flatten(data['init'].v)
    xprop = self.flatten(data['proposed'].x)
    vprop = self.flatten(data['proposed'].v)
    # Per chain: keep the proposal where accepted, the initial state otherwise.
    vout = tf.where(accept, vprop, vinit)
    xout = tf.where(accept, xprop, xinit)
    state_out = State(x=xout, v=vout, beta=data['init'].beta)
    mc_states = MonteCarloStates(init=data['init'],
                                 proposed=data['proposed'],
                                 out=state_out)
    data['metrics'].update({
        'acc_mask': ma_,
        'mc_states': mc_states,
    })
    return xout, data['metrics']
def apply_transition_fb(
        self,
        inputs: tuple[Tensor, Tensor],
        training: bool = True,
) -> tuple[Tensor, dict]:
    """Apply transition using single forward/backward update.

    Generates one merged forward/backward proposal, then per chain keeps
    either the proposal (accepted) or the initial state (rejected).

    Returns:
        The flattened output ``x`` and the metrics dict, where 'sumlogdet'
        is zeroed for rejected chains and 'acc_mask' / 'mc_states' are added.
    """
    data = self.generate_proposal_fb(inputs, training=training)
    metrics = data['metrics']
    init, proposed = data['init'], data['proposed']
    acc_mask, _ = self._get_accept_masks(metrics['acc'])
    keep_new = tf.cast(acc_mask[:, None], bool)

    def _select(new: Tensor, old: Tensor) -> Tensor:
        # Per-chain choice between the proposal and the dtype-matched start.
        return tf.where(keep_new, new, tf.cast(old, new.dtype))

    v_prop = self.flatten(proposed.v)
    x_prop = self.flatten(proposed.x)
    v_out = _select(v_prop, self.flatten(init.v))
    x_out = _select(x_prop, self.flatten(init.x))
    sld = metrics['sumlogdet']
    metrics.update({
        'acc_mask': acc_mask,
        'sumlogdet': tf.cast(acc_mask, sld.dtype) * sld,
        'mc_states': MonteCarloStates(
            init=init,
            proposed=proposed,
            out=State(x=x_out, v=v_out, beta=init.beta),
        ),
    })
    return x_out, metrics
def apply_transition(
        self,
        inputs: tuple[Tensor, Tensor],
        training: bool = True
) -> tuple[Tensor, dict]:
    """Apply transition using masks to combine forward/backward updates.

    A forward and a backward proposal are generated from the same
    ``(x, beta)``. A random per-chain direction mask picks one of them, and
    an accept mask drawn from the combined acceptance probability then
    picks between that proposal and the initial state.

    Args:
        inputs: ``(x, beta)`` pair.
        training: Passed through to ``generate_proposal``.

    Returns:
        ``x_out`` and a metrics dict. Per-trajectory metrics of the two
        directions are merged with the direction mask and multiplied by the
        accept mask; 'acc', 'acc_mask', 'sumlogdet' and 'mc_states' are
        then set explicitly.
    """
    x, beta = inputs
    # Each call draws its own random momentum, so the two directions start
    # from different v.
    fwd = self.generate_proposal(inputs, forward=True, training=training)
    bwd = self.generate_proposal(inputs, forward=False, training=training)
    # mf_ / mb_: complementary per-chain 0/1 masks (forward / backward);
    # the [:, None] versions broadcast over the flattened feature axis.
    mf_, mb_ = self._get_direction_masks(batch_size=x.shape[0])
    mf = mf_[:, None]
    mb = mb_[:, None]
    x_init = tf.where(tf.cast(mf, bool), fwd['init'].x, bwd['init'].x)
    v_init = tf.where(tf.cast(mf, bool), fwd['init'].v, bwd['init'].v)
    x_prop = tf.where(
        tf.cast(mf, bool),
        fwd['proposed'].x,
        bwd['proposed'].x
    )
    v_prop = tf.where(
        tf.cast(mf, bool),
        fwd['proposed'].v,
        bwd['proposed'].v
    )
    mfwd = fwd['metrics']
    mbwd = bwd['metrics']
    logdet_prop = tf.where(
        tf.cast(mf_, bool),
        mfwd['sumlogdet'],
        mbwd['sumlogdet']
    )
    # Acceptance probability of the direction actually chosen per chain.
    acc = mf_ * mfwd['acc'] + mb_ * mbwd['acc']
    ma_, _ = self._get_accept_masks(acc)
    ma = ma_[:, None]
    v_out = tf.where(
        tf.cast(ma, bool),
        v_prop,
        v_init
    )
    x_out = tf.where(
        tf.cast(ma, bool),
        x_prop,
        x_init,
    )
    # Rejected chains contribute zero log-determinant.
    sumlogdet = tf.where(
        tf.cast(ma_, bool),
        logdet_prop,
        tf.zeros_like(logdet_prop)
    )
    # ``x`` is used directly: both proposals start from this same x, so it
    # equals x_init.
    init = State(x=x, v=v_init, beta=beta)
    prop = State(x=x_prop, v=v_prop, beta=beta)
    out = State(x=x_out, v=v_out, beta=beta)
    mc_states = MonteCarloStates(init=init, proposed=prop, out=out)
    metrics = {}
    # Assumes both metric dicts share the same key order (both come from
    # transition_kernel). Rejected chains get 0 for every metric. The
    # except branch retries with column-shaped masks when the (batch,)
    # masks do not broadcast. NOTE(review): eager-mode shape errors may
    # surface as InvalidArgumentError rather than ValueError — confirm.
    for (key, vf), (_, vb) in zip(mfwd.items(), mbwd.items()):
        try:
            vfb = ma_ * (mf_ * vf + mb_ * vb)
        except ValueError:
            vfb = ma * (mf * vf + mb * vb)
        metrics[key] = vfb
    metrics.update({
        'acc': acc,
        'acc_mask': ma_,
        'sumlogdet': sumlogdet,
        'mc_states': mc_states,
    })
    return x_out, metrics
def generate_proposal_hmc(
        self,
        inputs: tuple[Tensor, Tensor],
        eps: Optional[float] = None,
        nleapfrog: Optional[int] = None,
) -> dict:
    """Draw fresh momentum and run the generic HMC kernel.

    Unlike the other ``generate_proposal*`` methods, the momentum is not
    flattened here.

    Returns:
        Dict with the 'init' and 'proposed' States and the kernel's 'metrics'.
    """
    x, beta = inputs
    assert isinstance(x, Tensor)
    momentum_shape = [x.shape[0], *self.xshape[1:]]
    start = State(x, self.g.random_momentum(momentum_shape), beta)
    end, metrics = self.transition_kernel_hmc(
        start, eps=eps, nleapfrog=nleapfrog
    )
    return {'init': start, 'proposed': end, 'metrics': metrics}
def generate_proposal_fb(
        self,
        inputs: tuple[Tensor, Tensor],
        training: bool = True,
) -> dict:
    """Generate proposal using single forward/backward update.

    Args:
        inputs: Tuple of (x, beta).
        training: Currently training model?

    Returns:
        Dict of 'init' and 'proposed' states, along with 'metrics'.
    """
    x, beta = inputs
    assert isinstance(x, Tensor)
    batch_shape = [x.shape[0], *self.xshape[1:]]
    momentum = self.flatten(self.g.random_momentum(batch_shape))
    start = State(x, momentum, beta)
    end, metrics = self.transition_kernel_fb(start, training=training)
    return {'init': start, 'proposed': end, 'metrics': metrics}
def generate_proposal(
        self,
        inputs: tuple[Tensor, Tensor],
        forward: bool,
        training: bool = True,
) -> dict:
    """Generate proposal using direction specified by 'forward'.

    A fresh flattened momentum is drawn for every call.

    Returns:
        Dict of 'init' and 'proposed' states, along with 'metrics'.
    """
    x, beta = inputs
    assert isinstance(x, Tensor)
    momentum = self.flatten(
        self.g.random_momentum([x.shape[0], *self.xshape[1:]])
    )
    start = State(x=x, v=momentum, beta=beta)
    end, metrics = self.transition_kernel(
        start, forward=forward, training=training
    )
    return {'init': start, 'proposed': end, 'metrics': metrics}
def get_metrics(
        self,
        state: State,
        logdet: Tensor,
        step: Optional[int] = None,
        extras: Optional[dict[str, Tensor]] = None,
) -> dict:
    """Collect diagnostics for ``state``.

    Returns a dict with 'energy' (``self.hamiltonian(state)``), 'logprob'
    (energy minus ``logdet``) and 'logdet', then any ``extras``. When
    ``step`` is given, that step's 'xeps' / 'veps' variables are added too.
    """
    energy = self.hamiltonian(state)
    out = {
        'energy': energy,
        'logprob': tf.subtract(energy, tf.cast(logdet, energy.dtype)),
        'logdet': logdet,
    }
    if extras is not None:
        out.update(extras)
    if step is None:
        return out
    out.update(xeps=self.xeps[step], veps=self.veps[step])
    return out
def update_history(
        self,
        metrics: dict,
        history: dict,
) -> dict:
    """Append each metric value to the list stored under its key.

    Missing keys start a new list. ``history`` is modified in place and
    also returned.
    """
    for key, val in metrics.items():
        history.setdefault(key, []).append(val)
    return history
def leapfrog_hmc(
        self,
        state: State,
        eps: float | tf.Tensor,
) -> State:
    """Perform standard HMC leapfrog update.

    The update is a half-step momentum kick, then a full-step position
    update via ``self.g.update_gauge``, then a second half-step kick using
    the force recomputed at the new position.

    Args:
        state: Current state; ``x`` is reshaped to the shape of ``v``.
        eps: Step size, as a Python float or a tensor.

    Returns:
        New State ``(x', v')`` with ``beta`` unchanged.
    """
    x = tf.reshape(state.x, state.v.shape)
    force = self.grad_potential(x, state.beta)  # f = dU / dx
    # ``tf.constant(eps, dtype=...)`` fails for tensors (symbolic ones in
    # tf.function, or ones of another dtype). Converting with a dtype hint
    # keeps Python floats at full precision in force.dtype, and the cast
    # handles tensors of any dtype.
    eps = tf.cast(
        tf.convert_to_tensor(eps, dtype_hint=force.dtype),
        force.dtype,
    )
    halfeps = 0.5 * eps
    v = state.v - halfeps * force
    x = self.g.update_gauge(x, eps * v)
    force = self.grad_potential(x, state.beta)  # force at the updated x
    v = v - halfeps * force
    return State(x=x, v=v, beta=state.beta)  # output: (x', v')
def transition_kernel_hmc(
        self,
        state: State,
        eps: Optional[float] = None,
        nleapfrog: Optional[int] = None,
) -> tuple[State, dict]:
    """Run the generic HMC transition kernel.

    Applies ``leapfrog_hmc`` repeatedly to a copy of ``state`` and computes
    the acceptance probability of the result.

    Args:
        state: Initial state.
        eps: Step size; defaults to ``config.eps_hmc``.
        nleapfrog: Number of leapfrog steps; defaults to
            ``2 * config.nleapfrog`` when ``config.merge_directions`` is
            set, else ``config.nleapfrog``.

    Returns:
        Tuple of the final state and a history dict that always contains
        'acc' and 'sumlogdet'. With ``config.verbose`` it also contains
        per-step metrics stacked along a new leading axis.
    """
    state_ = State(x=state.x, v=state.v, beta=state.beta)
    assert isinstance(state.x, Tensor)
    # Never updated below, so it stays zero for the whole trajectory.
    sumlogdet = tf.zeros((state.x.shape[0],), dtype=state.x.dtype)
    history = {}
    if self.config.verbose:
        history = self.update_history(
            self.get_metrics(state_, sumlogdet),
            history={}
        )
    eps = self.config.eps_hmc if eps is None else eps
    # With merged directions, default to twice the configured step count.
    nlf = (
        2 * self.config.nleapfrog if self.config.merge_directions
        else self.config.nleapfrog
    )
    # Holds by construction of ``nlf`` above (sanity check only).
    assert nlf <= 2 * self.config.nleapfrog
    nleapfrog = nlf if nleapfrog is None else nleapfrog
    for _ in range(nleapfrog):
        state_ = self.leapfrog_hmc(state_, eps=eps)
        if self.config.verbose:
            history = self.update_history(
                self.get_metrics(state_, sumlogdet),
                history=history,
            )
    acc = self.compute_accept_prob(state, state_, sumlogdet)
    history.update({'acc': acc, 'sumlogdet': sumlogdet})
    if self.config.verbose:
        # Turn per-step lists into stacked tensors (leading axis = step).
        for key, val in history.items():
            if isinstance(val, list) and isinstance(val[0], Tensor):
                history[key] = tf.stack(val)
    return state_, history
def transition_kernel_fb(
        self,
        state: State,
        training: bool = True,
) -> tuple[State, dict]:
    """Run the transition kernel using single forward/backward update.

    Runs ``config.nleapfrog`` forward leapfrog steps, negates the momentum,
    then runs ``config.nleapfrog`` backward leapfrog steps, accumulating the
    log-determinant of every step.

    Returns:
        tuple of output state, and history of metrics tracked during traj.
        'acc' and 'sumlogdet' are always present. With ``config.verbose``,
        per-step metrics (including the running sums 'sldf' / 'sldb' /
        'sld') are stacked along a new leading axis.
    """
    state_ = State(state.x, state.v, state.beta)
    assert isinstance(state.x, Tensor)
    # Uses tf.zeros' default dtype; each step's logdet is cast to it below.
    sumlogdet = tf.zeros(
        (state.x.shape[0],),
    )
    # Running log-determinant sums of the forward / backward halves
    # (only reported in verbose mode).
    sldf = tf.zeros_like(sumlogdet)
    sldb = tf.zeros_like(sumlogdet)
    history = {}
    if self.config.verbose:
        extras = {
            'sldf': sldf,
            'sldb': sldb,
            'sld': sumlogdet,
        }
        history = self.update_history(
            self.get_metrics(state_, sumlogdet, step=0, extras=extras),
            history=history,
        )
    # Forward
    for step in range(self.config.nleapfrog):
        state_, logdet = self._forward_lf(step, state_, training)
        logdet = tf.cast(logdet, sumlogdet.dtype)
        sumlogdet = sumlogdet + logdet
        if self.config.verbose:
            sldf = sldf + logdet
            extras = {
                'sldf': sldf,
                'sldb': sldb,
                'sld': sumlogdet,
            }
            metrics = self.get_metrics(
                state_,
                sumlogdet,
                step=step,
                extras=extras
            )
            history = self.update_history(metrics=metrics, history=history)
    # Flip momentum
    state_ = State(state_.x, tf.negative(state_.v), state_.beta)  # noqa
    # Backward
    for step in range(self.config.nleapfrog):
        state_, logdet = self._backward_lf(step, state_, training)
        logdet = tf.cast(logdet, sumlogdet.dtype)
        sumlogdet = sumlogdet + logdet
        if self.config.verbose:
            sldb = sldb + logdet
            # NOTE(review): 'sldf' is reported as zeros during the backward
            # half rather than the accumulated forward sum — confirm intent.
            extras = {
                'sldf': tf.zeros_like(sldb),
                'sldb': sldb,
                'sld': sumlogdet,
            }
            # Reverse step count to correctly order metrics
            metrics = self.get_metrics(
                state_,
                sumlogdet,
                step=(self.config.nleapfrog - step - 1),
                extras=extras
            )
            history = self.update_history(metrics=metrics, history=history)
    acc = self.compute_accept_prob(state, state_, sumlogdet)
    history.update({'acc': acc, 'sumlogdet': sumlogdet})
    if self.config.verbose:
        # Turn per-step lists into stacked tensors (leading axis = step).
        for key, val in history.items():
            if isinstance(val, list) and isinstance(val[0], Tensor):
                history[key] = tf.stack(val)
    return state_, history
def transition_kernel(
        self,
        state: State,
        forward: bool,
        training: bool = True,
) -> tuple[State, dict]:
    """Implements the directional transition kernel.

    Runs ``config.nleapfrog`` learned leapfrog steps in one direction
    (``_forward_lf`` or ``_backward_lf``), accumulating the per-step
    log-determinants, and computes the acceptance probability.

    Returns:
        tuple of output state, and history of metrics tracked during traj.
        'acc' and 'sumlogdet' are always present. With ``config.verbose``,
        per-step metrics are stacked along a new leading axis.
    """
    lf_fn = self._forward_lf if forward else self._backward_lf
    # Copy initial state into proposed state
    state_ = State(x=state.x, v=state.v, beta=state.beta)
    assert isinstance(state.x, Tensor)
    # Accumulate in the global float dtype rather than tf.zeros' float32
    # default, and cast each step's logdet to it (as transition_kernel_fb
    # does). Otherwise a float64 floatx gives a dtype mismatch in the
    # addition below and in apply_transition's mask arithmetic.
    sumlogdet = tf.zeros((state.x.shape[0],), dtype=TF_FLOAT)
    history = {}
    if self.config.verbose:
        metrics = self.get_metrics(state_, sumlogdet)
        history = self.update_history(metrics, history=history)
    for step in range(self.config.nleapfrog):
        state_, logdet = lf_fn(step, state_, training)
        sumlogdet = sumlogdet + tf.cast(logdet, sumlogdet.dtype)
        if self.config.verbose:
            metrics = self.get_metrics(state_, sumlogdet, step=step)
            history = self.update_history(metrics, history=history)
    acc = self.compute_accept_prob(state, state_, sumlogdet)
    history.update({'acc': acc, 'sumlogdet': sumlogdet})
    if self.config.verbose:
        # Turn per-step lists into stacked tensors (leading axis = step).
        for key, val in history.items():
            if isinstance(val, list) and isinstance(val[0], Tensor):
                history[key] = tf.stack(val)  # type: ignore
    return state_, history
def compute_accept_prob(
        self,
        state_init: State,
        state_prop: State,
        sumlogdet: Tensor,
) -> Tensor:
    """Compute the acceptance probability.

    The result is ``min(1, exp(H(init) - H(prop) + sumlogdet))`` per
    element, with non-finite values replaced by 0.
    """
    h_init = self.hamiltonian(state_init)
    h_prop = self.hamiltonian(state_prop)
    delta = tf.add(
        tf.subtract(h_init, h_prop),
        tf.cast(sumlogdet, h_init.dtype),
    )
    prob = tf.exp(tf.minimum(delta, tf.zeros_like(delta)))
    finite = tf.math.is_finite(prob)
    return tf.where(finite, prob, tf.zeros_like(prob))
@staticmethod
def _get_accept_masks(px: Tensor) -> tuple:
    """Turn acceptance probabilities into complementary 0/1 masks.

    Each entry is accepted when ``px`` exceeds a fresh uniform draw.

    Returns:
        ``(accept, reject)`` as ``TF_FLOAT`` tensors with ``reject = 1 - accept``.
    """
    threshold = tf.random.uniform(tf.shape(px), dtype=TF_FLOAT)
    accept = tf.cast(px > threshold, dtype=TF_FLOAT)
    reject = tf.ones_like(accept) - accept
    return accept, reject
def _get_direction_masks(self, batch_size) -> tuple:
    """Draw complementary per-chain forward/backward 0/1 masks.

    A chain goes forward when its uniform draw exceeds 0.5.

    Returns:
        ``(fwd, bwd)`` as ``TF_FLOAT`` tensors of shape ``(batch_size,)``.
    """
    draws = tf.random.uniform((batch_size,), dtype=TF_FLOAT)
    fwd = tf.cast(draws > 0.5, dtype=TF_FLOAT)
    return fwd, tf.ones_like(fwd) - fwd
def _get_mask(self, i: int) -> tuple[Tensor, Tensor]:
    """Return step ``i``'s x-update mask and its complement ``1 - mask``."""
    mask = self.masks[i]
    return mask, tf.ones_like(mask) - mask
def _build_masks(self):
    """Draw one random binary mask per leapfrog step.

    Each mask has shape ``(1, xdim)`` with exactly ``xdim // 2`` ones at
    positions chosen by a random permutation.
    """
    masks = []
    for _ in range(self.config.nleapfrog):
        # NumPy's RNG is used on purpose: the masks are drawn once here and
        # stored, whereas TF random ops could produce different values
        # across calls.
        chosen = np.random.permutation(np.arange(self.xdim))[:self.xdim // 2]
        mask = np.zeros((self.xdim,))
        mask[chosen] = 1.
        masks.append(tf.constant(mask, dtype=TF_FLOAT)[None, :])
    return masks
def _get_vnet(self, step: int) -> Callable:
"""Returns momentum network to be used for updating v."""
if not self._networks_built:
return self.vnet
vnet = self.vnet
# assert isinstance(vnet, (dict, tf.keras.Model))
if self.config.use_separate_networks and isinstance(vnet, dict):
return vnet[str(step)]
# assert isinstance(vnet, (CallableNetwork))
return self.vnet
def _get_xnets(
self,
step: int,
) -> list:
xnets = [
self._get_xnet(step, first=True)
]
if self.config.use_separate_networks:
xnets.append(
self._get_xnet(step, first=False)
)
return xnets
def _get_all_xnets(self) -> list[Model]:
xnets = []
for step in range(self.config.nleapfrog):
nets = self._get_xnets(step)
for net in nets:
xnets.append(net)
return xnets
def _get_all_vnets(self) -> list:
return [
self._get_vnet(step)
for step in range(self.config.nleapfrog)
]
# for step in range(self.config.nleapfrog):
# nets = self._get_vnet(step)
@staticmethod
def rename_weight(
        name: str,
        sep: Optional[str] = None,
) -> str:
    """Normalize a TF variable name into a weight-dict key.

    Steps:
      1. Drop a trailing ':0' suffix. Only that exact suffix is removed;
         ``str.rstrip(':0')`` strips a character *set* and also ate
         trailing '0's of the name itself (e.g. 'w10:0' -> 'w1').
      2. Rename 'kernel' to 'weight'.
      3. Remove '_<digit>' occurrences (presumably Keras-style
         uniquifying suffixes such as 'dense_1').
      4. If ``sep`` is given, map '.' and '/' to ``sep``. The replaced
         string is now kept; before, the ``str.replace`` results were
         discarded and ``sep`` had no effect.
    """
    new_name = name.removesuffix(':0').replace('kernel', 'weight')
    new_name = re.sub(r'\_\d', '', new_name)
    if sep is not None:
        new_name = new_name.replace('.', '/').replace('/', sep)
    return new_name
def get_all_weights(self) -> dict:
    """Map renamed weight names to weights for every x- and v-network.

    x-networks are processed before v-networks. On name collisions the
    later entry wins. The result is passed through ``cfgs.flatten_dict``.
    """
    networks = [*self._get_all_xnets(), *self._get_all_vnets()]
    weights = {}
    for net in networks:
        for w in net.weights:
            weights[self.rename_weight(w.name)] = w
    return cfgs.flatten_dict(weights)
def _get_xnet(
self,
step: int,
first: bool
) -> Callable:
"""Returns position network to be used for updating x."""
if not self._networks_built:
return self.xnet
xnet = self.xnet
# assert isinstance(xnet, (tf.keras.Model, dict))
if self.config.use_separate_networks and isinstance(xnet, dict):
xnet = xnet[str(step)]
if self.config.use_split_xnets:
if first:
return xnet['first']
return xnet['second']
return xnet
return xnet
def _stack_as_xy(self, x: Tensor) -> Tensor:
    """Stack ``[cos(x), sin(x)]`` along a new trailing axis."""
    cos_x = tf.math.cos(x)
    sin_x = tf.math.sin(x)
    return tf.stack([cos_x, sin_x], axis=-1)
def _call_vnet(
        self,
        step: int,
        inputs: tuple[Tensor, Tensor],  # (x, ∂S/∂x)
        training: bool
) -> tuple[Tensor, Tensor, Tensor]:
    """Calls the momentum network used to update v.

    For SU3, both ``x`` and the force are first converted with
    ``self.group_to_vec``.

    Args:
        step: Leapfrog step whose v-network is used.
        inputs: (x, force) tuple.
        training: Forwarded positionally to the network.

    Returns:
        s, t, q: Scaling, Translation, and Transformation functions
    """
    x, force = inputs
    if self.config.group == 'SU3':
        x = self.group_to_vec(x)
        force = self.group_to_vec(force)
    net = self._get_vnet(step)
    assert callable(net)
    scale, translation, transform = net((x, force), training)
    return (scale, translation, transform)
def _call_xnet(
        self,
        step: int,
        inputs: tuple[Tensor, Tensor],  # (m * x, v)
        first: bool,
        training: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
    """Call the position network used to update x.

    For U1, ``x`` is mapped through ``self.g.group_to_vec``. For SU3, ``x``
    is unflattened and split into (real, imag) parts stacked along axis 1;
    ``v`` is split the same way without being unflattened first.

    Args:
        step: Leapfrog step whose x-network is used.
        inputs: (m * x, v) tuple, where (m * x) is a masking operation.
        first: Selects the first or second x-network of the step.
        training: Forwarded to the network as a keyword.

    Returns:
        s, t, q: Scaling, Translation, and Transformation functions
    """
    x, v = inputs
    assert isinstance(x, Tensor) and isinstance(v, Tensor)
    net = self._get_xnet(step, first)
    group = self.config.group
    if group == 'SU3':
        x_full = self.unflatten(x)
        x = tf.stack([tf.math.real(x_full), tf.math.imag(x_full)], 1)
        v = tf.stack([tf.math.real(v), tf.math.imag(v)], 1)
    elif group == 'U1':
        x = self.g.group_to_vec(x)
    scale, translation, transform = net((x, v), training=training)
    return (scale, translation, transform)
def _forward_lf(
        self,
        step: int,
        state: State,
        training: bool = True,
) -> tuple[State, Tensor]:
    """Complete update (leapfrog step) in the forward direction.

    Sub-update order: v, then x on mask ``m`` (first x-network), then x on
    the complementary mask ``mb`` (second x-network), then v again.

    Returns:
        The updated state and the sum of the four sub-updates'
        log-determinants, accumulated into a (batch,)-shaped tensor.
    """
    m, mb = self._get_mask(step)
    state, logdet = self._update_v_fwd(step, state, training=training)
    # Accumulator is created after the first update so it takes logdet's dtype.
    sumlogdet = tf.zeros((state.x.shape[0],), dtype=logdet.dtype)
    sumlogdet = sumlogdet + logdet
    state, logdet = self._update_x_fwd(step, state, m,
                                       first=True, training=training)
    sumlogdet = sumlogdet + logdet
    state, logdet = self._update_x_fwd(step, state, mb,
                                       first=False, training=training)
    sumlogdet = sumlogdet + logdet
    state, logdet = self._update_v_fwd(step, state, training=training)
    sumlogdet = sumlogdet + logdet
    return state, sumlogdet
def _backward_lf(
        self,
        step: int,
        state: State,
        training: bool = True,
) -> tuple[State, Tensor]:
    """Complete update (leapfrog step) in the backward direction.

    Backward counterpart of the forward step. It uses the mirrored step
    index ``step_r = nleapfrog - step - 1`` and runs the sub-updates in
    reverse order: v, x on ``mb`` (second x-network), x on ``m`` (first
    x-network), v.

    Returns:
        The updated state and the sum of the four sub-updates'
        log-determinants, accumulated into a (batch,)-shaped tensor.
    """
    # Note: Reverse the step count, i.e. count from end of trajectory.
    step_r = self.config.nleapfrog - step - 1
    m, mb = self._get_mask(step_r)
    state, logdet = self._update_v_bwd(step_r, state, training=training)
    # Accumulator is created after the first update so it takes logdet's dtype.
    sumlogdet = tf.zeros((state.x.shape[0],), dtype=logdet.dtype)
    sumlogdet = sumlogdet + logdet
    state, logdet = self._update_x_bwd(step_r, state, mb,
                                       first=False, training=training)
    sumlogdet = sumlogdet + logdet
    state, logdet = self._update_x_bwd(step_r, state, m,
                                       first=True, training=training)
    sumlogdet = sumlogdet + logdet
    state, logdet = self._update_v_bwd(step_r, state, training=training)
    sumlogdet = sumlogdet + logdet
    return state, sumlogdet
def unflatten(self, x: Tensor) -> Tensor:
    """Reshape ``x`` to ``(batch, *xshape[1:])``, keeping its leading dimension."""
    target = (x.shape[0], *self.xshape[1:])
    return tf.reshape(x, target)
def group_to_vec(self, x: Tensor) -> Tensor:
    """Unflatten ``x`` and convert it with ``self.g.group_to_vec``.

    Per the original note, for SU(3) elements this yields an 8-component
    real-valued vector.
    """
    unflat = self.unflatten(x)
    return self.g.group_to_vec(unflat)
def vec_to_group(self, x: Tensor) -> Tensor:
    """Map a vector representation back to group elements.

    ``x`` is first unflattened to ``(batch, *xshape[1:])`` if its trailing
    shape differs. For SU3 the conversion is delegated to
    ``self.g.vec_to_group``. Otherwise the last axis of ``x`` must hold the
    (real, imag) parts (size 2), which are combined into a complex tensor.

    ``tf.unstack`` along the last axis yields the same two slices as the
    previous transpose / Python tuple-unpack / transpose sequence. Unlike
    iterating over a tensor, it also works on symbolic tensors inside a
    ``tf.function``.
    """
    if x.shape[1:] != self.xshape[1:]:
        x = self.unflatten(x)
    if self.config.group == 'SU3':
        return self.g.vec_to_group(x)
    x_re, x_im = tf.unstack(x, num=2, axis=-1)
    return tf.complex(x_re, x_im)