Skip to content

Commit

Permalink
support recursive detection for the system of model_devi (#1424)
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored Jan 17, 2022
1 parent b88c1da commit a9d08a7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
2 changes: 1 addition & 1 deletion deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def parse_args(args: Optional[List[str]] = None):
"--system",
default=".",
type=str,
help="The system directory, not support recursive detection.",
help="The system directory. Recursively detect systems in this directory.",
)
parser_model_devi.add_argument(
"-S", "--set-prefix", default="set", type=str, help="The set prefix"
Expand Down
60 changes: 34 additions & 26 deletions deepmd/infer/model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .deep_pot import DeepPot
from ..utils.data import DeepmdData
from ..utils.batch_size import AutoBatchSize
from deepmd.common import expand_sys_str


def calc_model_devi_f(fs: np.ndarray):
Expand Down Expand Up @@ -56,11 +57,12 @@ def write_model_devi_out(devi: np.ndarray, fname: str):
header = "%10s" % "step"
for item in 'vf':
header += "%19s%19s%19s" % (f"max_devi_{item}", f"min_devi_{item}", f"avg_devi_{item}")
np.savetxt(fname,
devi,
fmt=['%12d'] + ['%19.6e' for _ in range(6)],
delimiter='',
header=header)
with open(fname, "ab") as fp:
np.savetxt(fp,
devi,
fmt=['%12d'] + ['%19.6e' for _ in range(6)],
delimiter='',
header=header)
return devi

def _check_tmaps(tmaps, ref_tmap=None):
Expand Down Expand Up @@ -185,25 +187,31 @@ def make_model_devi(
tmap = tmaps[0]
else:
raise RuntimeError("The models does not have the same type map.")

# create data-system
dp_data = DeepmdData(system, set_prefix, shuffle_test=False, type_map=tmap)
if dp_data.pbc:
nopbc = False
else:
nopbc = True

data_sets = [dp_data._load_set(set_name) for set_name in dp_data.dirs]
nframes_tot = 0
devis = []
for data in data_sets:
coord = data["coord"]
box = data["box"]
atype = data["type"][0]
devi = calc_model_devi(coord, box, atype, dp_models, nopbc=nopbc)
nframes_tot += coord.shape[0]
devis.append(devi)
devis = np.vstack(devis)
devis[:, 0] = np.arange(nframes_tot) * frequency
write_model_devi_out(devis, output)
return devis
all_sys = expand_sys_str(system)
if len(all_sys) == 0:
raise RuntimeError("Did not find valid system")
devis_coll = []
for system in all_sys:
# create data-system
dp_data = DeepmdData(system, set_prefix, shuffle_test=False, type_map=tmap)
if dp_data.pbc:
nopbc = False
else:
nopbc = True

data_sets = [dp_data._load_set(set_name) for set_name in dp_data.dirs]
nframes_tot = 0
devis = []
for data in data_sets:
coord = data["coord"]
box = data["box"]
atype = data["type"][0]
devi = calc_model_devi(coord, box, atype, dp_models, nopbc=nopbc)
nframes_tot += coord.shape[0]
devis.append(devi)
devis = np.vstack(devis)
devis[:, 0] = np.arange(nframes_tot) * frequency
write_model_devi_out(devis, output)
devis_coll.append(devis)
return devis_coll

0 comments on commit a9d08a7

Please sign in to comment.