diff --git a/dpdata/qe/scf.py b/dpdata/qe/scf.py index f86708605..8c5299d07 100755 --- a/dpdata/qe/scf.py +++ b/dpdata/qe/scf.py @@ -12,47 +12,67 @@ kbar2evperang3 = 1e3 / 1.602176621e6 -def get_block(lines, keyword, skip=0): - ret = [] - for idx, ii in enumerate(lines): - if keyword in ii: - blk_idx = idx + 1 + skip - while len(lines[blk_idx]) == 0: - blk_idx += 1 - while len(lines[blk_idx]) != 0 and blk_idx != len(lines): - ret.append(lines[blk_idx]) - blk_idx += 1 +def get_block(lines, start_marker): + start_idx = None + for idx, line in enumerate(lines): + if start_marker in line: + start_idx = idx + 1 break - return ret + if start_idx is None: + raise RuntimeError(f"{start_marker} not found in the input lines.") + + block = [] + for line in lines[start_idx:]: + if line.strip() == "" or line.strip().startswith("&"): + break + block.append(line.strip()) + return block + + +def get_block(lines, start_marker, skip=0): + start_idx = None + for idx, line in enumerate(lines): + if start_marker in line: + start_idx = idx + 1 + skip + break + if start_idx is None: + raise RuntimeError(f"{start_marker} not found in the input lines.") + + block = [] + for line in lines[start_idx:]: + if line.strip() == "" or line.strip().startswith("&"): + break + block.append(line.strip()) + return block def get_cell(lines): - ret = [] - for idx, ii in enumerate(lines): - if "ibrav" in ii: + for idx, line in enumerate(lines): + if "ibrav" in line: + ibrav = int(line.replace(",", "").split("=")[-1]) break - blk = lines[idx : idx + 2] - ibrav = int(blk[0].replace(",", "").split("=")[-1]) + else: + raise RuntimeError("ibrav not found in the input lines.") + if ibrav == 0: - for iline in lines: - if "CELL_PARAMETERS" in iline and "angstrom" not in iline.lower(): + for line in lines: + if "CELL_PARAMETERS" in line and "angstrom" not in line.lower(): raise RuntimeError( "CELL_PARAMETERS must be written in Angstrom. Other units are not supported yet." ) blk = get_block(lines, "CELL_PARAMETERS") - for ii in blk: - ret.append([float(jj) for jj in ii.split()[0:3]]) + ret = [] + for line in blk: + ret.append([float(value) for value in line.split()[0:3]]) ret = np.array(ret) elif ibrav == 1: a = None - for iline in lines: - line = iline.replace("=", " ").replace(",", "").split() + for line in lines: + line = line.replace("=", " ").replace(",", "").split() if len(line) >= 2 and "a" == line[0]: - # print("line = ", line) a = float(line[1]) if len(line) >= 2 and "celldm(1)" == line[0]: a = float(line[1]) * bohr2ang - # print("a = ", a) if not a: raise RuntimeError("parameter 'a' or 'celldm(1)' cannot be found.") ret = np.array([[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]])