Add readNERSCGauge().
SaltyChiang committed Dec 5, 2024
1 parent 0f0bd2f commit 39ad13f
Showing 5 changed files with 98 additions and 22 deletions.
26 changes: 18 additions & 8 deletions pyquda_utils/io/__init__.py
@@ -218,20 +218,30 @@ def writeNPYPropagator(filename: str, propagator: LatticePropagator):
     write(filename, propagator.latt_info.global_size, propagator.lexico())


-def readOpenQCDGauge(filename: str, endian: Literal["<", ">"] = "<"):
+def readOpenQCDGauge(filename: str, return_plaquette: bool = False):
     from .openqcd import readGauge as read

-    latt_size, plaquette, gauge_raw = read(filename, endian)
-    gauge = LatticeGauge(LatticeInfo(latt_size), cb2(gauge_raw, [1, 2, 3, 4]))
-    print(gauge.plaquette()[0], plaquette)
-    return gauge
+    latt_size, plaquette, gauge_raw = read(filename)
+    if not return_plaquette:
+        return LatticeGauge(LatticeInfo(latt_size), cb2(gauge_raw, [1, 2, 3, 4]))
+    else:
+        return LatticeGauge(LatticeInfo(latt_size), cb2(gauge_raw, [1, 2, 3, 4])), plaquette


-def writeOpenQCDGauge(filename: str, gauge: LatticeGauge):
+def writeOpenQCDGauge(filename: str, gauge: LatticeGauge, plaquette: float):
     from .openqcd import writeGauge as write

-    print(gauge.plaquette())
-    write(filename, gauge.latt_info.global_size, gauge.plaquette()[0] * Nc, gauge.lexico())
+    write(filename, gauge.latt_info.global_size, plaquette, gauge.lexico())


+def readNERSCGauge(filename: str, return_plaquette: bool = False, link_trace: bool = True, checksum: bool = True):
+    from .nersc import readGauge as read
+
+    latt_size, plaquette, gauge_raw = read(filename, link_trace, checksum)
+    if not return_plaquette:
+        return LatticeGauge(LatticeInfo(latt_size), cb2(gauge_raw, [1, 2, 3, 4]))
+    else:
+        return LatticeGauge(LatticeInfo(latt_size), cb2(gauge_raw, [1, 2, 3, 4])), plaquette
+
+
 def readQIOGauge(filename: str):
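A minimal usage sketch of the reader added above, assuming PyQUDA has already been initialized with a process grid that divides the lattice; the file path is made up, and the file is assumed to carry the 4D_SU3_GAUGE_3x3 datatype:

    from pyquda_utils import io

    # Read the gauge field, verifying the LINK_TRACE and CHECKSUM header entries
    gauge = io.readNERSCGauge("conf/l16t32.nersc", link_trace=True, checksum=True)

    # Also return the plaquette value recorded in the NERSC header
    gauge, plaquette = io.readNERSCGauge("conf/l16t32.nersc", return_plaquette=True)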
4 changes: 2 additions & 2 deletions pyquda_utils/io/chroma.py
@@ -27,8 +27,8 @@ def checksum_qio(latt_size: List[int], data):
     )
     rank29 = (rank % 29).astype("<u4")
     rank31 = (rank % 31).astype("<u4")
-    sum29 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank29 | work >> (32 - rank29)).item(), MPI.BXOR)
-    sum31 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank31 | work >> (32 - rank31)).item(), MPI.BXOR)
+    sum29 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank29 | work >> (32 - rank29)), MPI.BXOR)
+    sum31 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank31 | work >> (32 - rank31)), MPI.BXOR)
     return sum29, sum31

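For context, the reduction above belongs to the QIO/SciDAC checksum: one 32-bit CRC per lattice site, rotated by the global site index modulo 29 and 31, then XOR-combined over all sites and ranks. A single-process sketch of the same arithmetic on made-up data (hypothetical site count and indices, no MPI reduction):

    import zlib

    import numpy

    # Made-up per-site data: 16 sites, each a 3x3 complex-double matrix
    data = numpy.ones((16, 3, 3), dtype="<c16")
    # One CRC32 word per site (zlib.crc32 reads the array's buffer directly)
    work = numpy.array([zlib.crc32(site) for site in data], dtype="<u4")
    # Hypothetical global site indices, kept nonzero so every shift stays below 32
    rank = numpy.arange(1, 17, dtype="<u4")
    rank29 = (rank % 29).astype("<u4")
    rank31 = (rank % 31).astype("<u4")
    # Rotate each CRC by its site index and XOR everything together;
    # the parallel code BXOR-allreduces this partial result across ranks
    sum29 = numpy.bitwise_xor.reduce(work << rank29 | work >> (32 - rank29))
    sum31 = numpy.bitwise_xor.reduce(work << rank31 | work >> (32 - rank31))
    print(hex(int(sum29)), hex(int(sum31)))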
10 changes: 4 additions & 6 deletions pyquda_utils/io/milc.py
@@ -13,7 +13,6 @@


 def checksum_milc(latt_size: List[int], data):
-    import numpy
     from mpi4py import MPI

     gx, gy, gz, gt = getGridCoord()
@@ -29,14 +28,13 @@ def checksum_qio(latt_size: List[int], data):
     )
     rank29 = (rank % 29).astype("<u4")
     rank31 = (rank % 31).astype("<u4")
-    sum29 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank29 | work >> (32 - rank29)).item(), MPI.BXOR)
-    sum31 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank31 | work >> (32 - rank31)).item(), MPI.BXOR)
+    sum29 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank29 | work >> (32 - rank29)), MPI.BXOR)
+    sum31 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank31 | work >> (32 - rank31)), MPI.BXOR)
     return sum29, sum31


 def checksum_qio(latt_size: List[int], data):
     import zlib
-    import numpy
     from mpi4py import MPI

     gx, gy, gz, gt = getGridCoord()
@@ -53,8 +51,8 @@ def checksum_qio(latt_size: List[int], data):
     )
     rank29 = rank % 29
     rank31 = rank % 31
-    sum29 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank29 | work >> (32 - rank29)).item(), MPI.BXOR)
-    sum31 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank31 | work >> (32 - rank31)).item(), MPI.BXOR)
+    sum29 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank29 | work >> (32 - rank29)), MPI.BXOR)
+    sum31 = getMPIComm().allreduce(numpy.bitwise_xor.reduce(work << rank31 | work >> (32 - rank31)), MPI.BXOR)
     return sum29, sum31

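With the .item() calls removed, the partial checksums reach allreduce as numpy uint32 scalars rather than Python ints. As far as mpi4py's object interface goes, the lowercase allreduce applies predefined operations such as MPI.BXOR at the Python level (a bitwise ^), which numpy scalars support, so the reduction goes through unchanged; a minimal sanity sketch, runnable on a single rank with made-up values:

    import numpy
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    # Made-up per-rank partial checksum as a numpy uint32 scalar
    local = numpy.uint32(0x1234ABCD) ^ numpy.uint32(comm.Get_rank())
    # Object-level allreduce combines the scalars with bitwise XOR across ranks
    total = comm.allreduce(local, MPI.BXOR)
    print(hex(int(total)))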
68 changes: 68 additions & 0 deletions pyquda_utils/io/nersc.py
@@ -0,0 +1,68 @@
+from os import path
+from typing import Dict
+
+import numpy
+
+from pyquda import getSublatticeSize, readMPIFile, writeMPIFile, getMPIComm, getMPIRank
+
+Nd, Ns, Nc = 4, 4, 3
+
+
+def checksum_nersc(data: numpy.ndarray) -> int:
+    from mpi4py import MPI
+
+    return getMPIComm().allreduce(numpy.sum(data.view("<u4"), dtype="<u4"), MPI.SUM)
+
+
+def readGauge(filename: str, link_trace: bool = True, checksum: bool = True):
+    filename = path.expanduser(path.expandvars(filename))
+    header: Dict[str, str] = {}
+    with open(filename, "rb") as f:
+        assert f.readline().decode() == "BEGIN_HEADER\n"
+        buffer = f.readline().decode()
+        while buffer != "END_HEADER\n":
+            key, val = buffer.split("=")
+            header[key.strip()] = val.strip()
+            buffer = f.readline().decode()
+        offset = f.tell()
+    latt_size = [
+        int(header["DIMENSION_1"]),
+        int(header["DIMENSION_2"]),
+        int(header["DIMENSION_3"]),
+        int(header["DIMENSION_4"]),
+    ]
+    Lx, Ly, Lz, Lt = getSublatticeSize(latt_size)
+    assert header["FLOATING_POINT"].startswith("IEEE")
+    if header["FLOATING_POINT"][6:] == "BIG":
+        endian = ">"
+    elif header["FLOATING_POINT"][6:] == "LITTLE":
+        endian = "<"
+    else:
+        raise ValueError(f"Unsupported endian: {header['FLOATING_POINT'][6:]}")
+    nbytes = int(header["FLOATING_POINT"][4:6]) // 8
+    dtype = f"{endian}c{2 * nbytes}"
+    plaquette = float(header["PLAQUETTE"])
+
+    if header["DATATYPE"] == "4D_SU3_GAUGE_3x3":
+        gauge = readMPIFile(filename, dtype, offset, (Lt, Lz, Ly, Lx, Nd, Nc, Nc), (3, 2, 1, 0))
+        if link_trace:
+            assert numpy.isclose(
+                numpy.einsum("tzyxdaa->", gauge.real) / (gauge.size // Nc), float(header["LINK_TRACE"])
+            ), f"Bad link trace for {filename}"
+        if checksum:
+            assert checksum_nersc(gauge.astype(f"<c{2 * nbytes}").reshape(-1)) == int(
+                header["CHECKSUM"], 16
+            ), f"Bad checksum for {filename}"
+        gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
+        return latt_size, plaquette, gauge
+    elif header["DATATYPE"] == "4D_SU3_GAUGE":
+        # gauge = readMPIFile(filename, dtype, offset, (Lt, Lz, Ly, Lx, Nd, Nc - 1, Nc), (3, 2, 1, 0))
+        # if checksum:
+        #     assert (
+        #         hex(checksum_nersc(gauge.astype(f"<c{2 * nbytes}").reshape(-1)))[2:] == header["CHECKSUM"]
+        #     ), f"Bad checksum for {filename}"
+        # gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
+        # return latt_size, gauge
+        raise NotImplementedError("SU3_GAUGE is not supported")
+    else:
+        raise ValueError(f"Unsupported datatype: {header['DATATYPE']}")
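For orientation, a hypothetical NERSC-format header of the kind this parser accepts (every value below is made up; only the keys read above matter, and any extra key = value lines are stored but otherwise ignored):

    BEGIN_HEADER
    DATATYPE = 4D_SU3_GAUGE_3x3
    DIMENSION_1 = 16
    DIMENSION_2 = 16
    DIMENSION_3 = 16
    DIMENSION_4 = 32
    FLOATING_POINT = IEEE64BIG
    PLAQUETTE = 0.58765432
    LINK_TRACE = 0.98765432
    CHECKSUM = 1a2b3c4d
    END_HEADER
    <binary data: Lt*Lz*Ly*Lx sites, 4 links per site, 3x3 complex matrices>

With FLOATING_POINT = IEEE64BIG the reader picks big-endian 16-byte complex numbers (">c16"), checks the link trace and the summed-32-bit-word checksum if requested, and returns the field reordered to (Nd, Lt, Lz, Ly, Lx, Nc, Nc) in little-endian double precision.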
12 changes: 6 additions & 6 deletions pyquda_utils/io/openqcd.py
@@ -1,6 +1,6 @@
 from os import path
 import struct
-from typing import List, Literal
+from typing import List

 import numpy

@@ -9,14 +9,14 @@
 Nd, Ns, Nc = 4, 4, 3


-def readGauge(filename: str, endian: Literal["<", ">"] = "<"):
+def readGauge(filename: str):
     filename = path.expanduser(path.expandvars(filename))
     with open(filename, "rb") as f:
-        latt_size = struct.unpack(f"{endian}iiii", f.read(16))[::-1]
-        plaquette = struct.unpack(f"{endian}d", f.read(8))[0]
+        latt_size = struct.unpack("<iiii", f.read(16))[::-1]
+        plaquette = struct.unpack("<d", f.read(8))[0] / Nc
         offset = f.tell()
     Lx, Ly, Lz, Lt = getSublatticeSize(latt_size)
-    dtype = f"{endian}c16"
+    dtype = "<c16"

     gauge = readMPIFile(filename, dtype, offset, (Lt, Lz, Ly, Lx, Nd, Nc, Nc), (3, 2, 1, 0))
     gauge = gauge.transpose(4, 0, 1, 2, 3, 5, 6).astype("<c16")
@@ -32,7 +32,7 @@ def writeGauge(filename: str, latt_size: List[int], plaquette: float, gauge: num
     if getMPIRank() == 0:
         with open(filename, "wb") as f:
             f.write(struct.pack("<iiii", *latt_size[::-1]))
-            f.write(struct.pack("<d", plaquette))
+            f.write(struct.pack("<d", plaquette * Nc))
             offset = f.tell()
     offset = getMPIComm().bcast(offset)

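The Nc factors added here keep the two plaquette conventions apart: PyQUDA's LatticeGauge.plaquette()[0] is normalized so that unit links give 1, while the OpenQCD header stores the trace-normalized average (unit links give Nc = 3). A round-trip sketch under that assumption (hypothetical file names, PyQUDA assumed initialized):

    from pyquda_utils import io

    # The underlying openqcd.readGauge divides the stored value by Nc, so the
    # plaquette returned here is on the same scale as gauge.plaquette()[0]
    gauge, plaquette = io.readOpenQCDGauge("conf/run1_n100", return_plaquette=True)

    # openqcd.writeGauge multiplies by Nc before writing the header, so passing
    # the value read above reproduces the file's original convention
    io.writeOpenQCDGauge("conf/run1_n100.copy", gauge, plaquette)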
