diff --git a/gpu4pyscf/tdscf/_lr_eig.py b/gpu4pyscf/tdscf/_lr_eig.py index 13013143..7763cf6e 100644 --- a/gpu4pyscf/tdscf/_lr_eig.py +++ b/gpu4pyscf/tdscf/_lr_eig.py @@ -344,6 +344,8 @@ def eig(aop, x0, precond, tol_residual=1e-5, nroots=1, x0sym=None, pick=None, fresh_start = True for icyc in range(max_cycle): if fresh_start: + vlast = None + conv_last = conv = np.zeros(nroots, dtype=bool) xs = np.zeros((0, x0_size)) ax = np.zeros((0, x0_size)) row1 = 0 @@ -400,7 +402,7 @@ def eig(aop, x0, precond, tol_residual=1e-5, nroots=1, x0sym=None, pick=None, w, e, elast = w[:space_inc], w[:nroots], e v = v[:,:space_inc] - if not fresh_start: + if vlast is not None: elast, conv_last = _sort_elast(elast, conv, vlast, v[:,:nroots], log) vlast = v[:,:nroots] @@ -447,22 +449,30 @@ def eig(aop, x0, precond, tol_residual=1e-5, nroots=1, x0sym=None, pick=None, xt[:,:half_size] -= c.T.dot(xs[:,half_size:].conj()) xt[:,half_size:] -= c.T.dot(xs[:,:half_size].conj()) - if x0sym is None: - xt = _symmetric_orth(xt) - else: - xt_orth = [] - xt_orth_ir = [] - for ir in set(xt_ir): - idx = np.where(xt_ir == ir)[0] - xt_sub = _symmetric_orth(xt[idx]) - xt_orth.append(xt_sub) - xt_orth_ir.append([ir] * len(xt_sub)) - if xt_orth: - xt = np.vstack(xt_orth) - xs_ir = np.hstack([xs_ir, *xt_orth_ir]) + # Remove quasi linearly dependent bases, as they cause more numerical + # errors in _symmetric_orth + xt_norm = np.linalg.norm(xt, axis=1) + xt_to_keep = (dx_norm > tol_residual) & (xt_norm > max(lindep**.5, tol_residual)) + xt = xt[xt_to_keep] + if len(xt) > 0: + xt /= xt_norm[xt_to_keep, None] + if x0sym is None: + xt = _symmetric_orth(xt) else: - xt = [] - xt_orth = xt_orth_ir = xt_sub = None + xt_ir = xt_ir[xt_to_keep] + xt_orth = [] + xt_orth_ir = [] + for ir in set(xt_ir): + idx = np.where(xt_ir == ir)[0] + xt_sub = _symmetric_orth(xt[idx]) + xt_orth.append(xt_sub) + xt_orth_ir.append([ir] * len(xt_sub)) + if xt_orth: + xt = np.vstack(xt_orth) + xs_ir = np.hstack([xs_ir, *xt_orth_ir]) + else: + xt = [] + xt_orth = xt_orth_ir = xt_sub = None if len(xt) == 0: log.debug(f'Linear dependency in trial subspace. |r| for each state {dx_norm}') @@ -527,7 +537,7 @@ def real_eig(aop, x0, precond, tol_residual=1e-5, nroots=1, x0sym=None, pick=Non Eigenvectors. ''' - #assert pick is None + assert pick is None assert callable(precond) if isinstance(verbose, logger.Logger): @@ -789,13 +799,19 @@ def _qr(xs, lindep=1e-14): return xs[:nv], idx def _symmetric_orth(xt, lindep=1e-6): + xt = np.asarray(xt) + if xt.dtype == np.float64: + return _symmetric_orth_real(xt, lindep) + else: + return _symmetric_orth_cmplx(xt, lindep) + +def _symmetric_orth_real(xt, lindep=1e-6): ''' Symmetric orthogonalization for xt = {[X, Y]}, and its dual basis vectors {[Y, X]} ''' - xt = np.asarray(xt) x0_size = xt.shape[1] - s11 = xt.conj().dot(xt.T) + s11 = xt.dot(xt.T) s21 = _conj_dot(xt, xt) # Symmetric orthogonalize s, where # s = [[s11, s21.conj().T], @@ -813,15 +829,9 @@ def _symmetric_orth(xt, lindep=1e-6): n = csc.shape[0] for i in range(n): _s21 = csc[i:,i:] - if _s21.dtype == np.float64: - # s21 is symmetric for real vectors - w, u = np.linalg.eigh(_s21) - mask = 1 - abs(w) > lindep - else: - # svd(s[:n,n:]) => svd(_s21.conj().T) => u, w - w2, u = np.linalg.eigh(_s21.conj().T.dot(_s21)) - mask = 1 - w2**.5 > lindep - w = np.einsum('pi,pi->i', u.conj(), _s21.dot(u)) + # s21 is symmetric for real vectors + w, u = np.linalg.eigh(_s21) + mask = 1 - abs(w) > lindep if np.any(mask): c = c[:,i:] break @@ -836,22 +846,16 @@ def _symmetric_orth(xt, lindep=1e-6): e, c = np.linalg.eigh(c_orth.T.dot(s11).dot(c_orth)) c *= e**-.5 c_orth = c_orth.dot(c) - if s21.dtype == np.float64: - csc = c_orth.T.dot(s21).dot(c_orth) - w, u = np.linalg.eigh(csc) - c_orth = c_orth.dot(u) - else: - sc = s21.dot(c_orth) - w2, u = np.linalg.eigh(sc.conj().T.dot(sc)) - c_orth = c_orth.dot(u) - w = np.einsum('pi,pi->i', c_orth.conj(), sc.dot(u)) + csc = c_orth.T.dot(s21).dot(c_orth) + w, u = np.linalg.eigh(csc) + c_orth = c_orth.dot(u) # Symmetric diagonalize - # [1 w] => c = [a b] - # [w 1] [b a] + # [1 w.conj()] => c = [a b] + # [w 1 ] [b a] # where # a = ((1+w)**-.5 + (1-w)**-.5)/2 - # b = ((1+w)**-.5 - (1-w)**-.5)/2 + # b = (phase*(1+w)**-.5 - phase*(1-w)**-.5)/2 a1 = (1 + w)**-.5 a2 = (1 - w)**-.5 a = (a1 + a2) / 2 @@ -860,8 +864,31 @@ def _symmetric_orth(xt, lindep=1e-6): m = xt.shape[1] // 2 x_orth = (c_orth * a).T.dot(xt) # Contribution from the conjugated basis - x_orth[:,:m] += (c_orth * b).T.dot(xt[:,m:].conj()) - x_orth[:,m:] += (c_orth * b).T.dot(xt[:,:m].conj()) + x_orth[:,:m] += (c_orth * b).T.dot(xt[:,m:]) + x_orth[:,m:] += (c_orth * b).T.dot(xt[:,:m]) + return x_orth + +def _symmetric_orth_cmplx(xt, lindep=1e-6): + n, m = xt.shape + if n == 0: + raise RuntimeError('Linear dependency in trial bases') + m = m // 2 + # The conjugated basis np.hstack([xt[:,m:], xt[:,:m]]).conj() + s11 = xt.conj().dot(xt.T) + s21 = _conj_dot(xt, xt) + s = np.block([[s11, s21.conj().T], + [s21, s11.conj() ]]) + e, c = scipy.linalg.eigh(s) + if e[0] < lindep: + if n == 1: + return xt + return _symmetric_orth_cmplx(xt[:-1], lindep) + + c_orth = (c * e**-.5).dot(c[:n].conj().T) + x_orth = c_orth[:n].T.dot(xt) + # Contribution from the conjugated basis + x_orth[:,:m] += c_orth[n:].T.dot(xt[:,m:].conj()) + x_orth[:,m:] += c_orth[n:].T.dot(xt[:,:m].conj()) return x_orth def _sym_dot(V, U1, m0, m1):