From 14fe38bccb39563d18df73bbd05bb8eaf3efeb3b Mon Sep 17 00:00:00 2001
From: pan25-1 <143333033+pan25-1@users.noreply.github.com>
Date: Sun, 17 Sep 2023 16:29:04 +0800
Subject: [PATCH] Update reader.py

Add an optional `return_csv` flag to EmbeddingReader and get_reader. When
it is set, EmbeddingReader.__getitem__ returns the decoded CSV text of the
requested segment instead of the parsed XYZ -> embedding mapping. The
default (False) keeps the previous behaviour.

_parse_csv_data now skips blank rows, so a trailing newline in the CSV no
longer yields an entry keyed by the empty tuple.

All other -/+ pairs differ only in leading whitespace (the module is
re-indented).
---
 connectomics/segclr/reader.py | 83 +++++++++++++++++++----------------
 1 file changed, 44 insertions(+), 39 deletions(-)

diff --git a/connectomics/segclr/reader.py b/connectomics/segclr/reader.py
index 7f58e0e..6880064 100644
--- a/connectomics/segclr/reader.py
+++ b/connectomics/segclr/reader.py
@@ -20,39 +20,44 @@
 
 
 class EmbeddingReader:
-  """Reader to load and parse embedding data from sharded ZIP archives."""
+    """Reader to load and parse embedding data from sharded ZIP archives."""
 
-  def __init__(self, filesystem, zipdir: str, sharder):
-    self._filesystem = filesystem  # Provides open(path) method.
-    self._zipdir = zipdir
-    self._sharder = sharder
+    def __init__(self, filesystem, zipdir: str, sharder, return_csv: bool = False):
+        self._filesystem = filesystem  # Provides open(path) method.
+        self._zipdir = zipdir
+        self._sharder = sharder
+        self._return_csv = return_csv  # Return raw CSV text if True.
 
-  def _get_csv_data(self, seg_id: int) -> str:
-    shard = self._sharder(seg_id)
-    zip_path = os.path.join(self._zipdir, f'{shard}.zip')
-    with self._filesystem.open(zip_path) as f:
-      with zipfile.ZipFile(f) as z:
-        with z.open(f'{seg_id}.csv') as c:
-          return c.read().decode('utf-8')
+    def _get_csv_data(self, seg_id: int) -> str:
+        shard = self._sharder(seg_id)
+        zip_path = os.path.join(self._zipdir, f'{shard}.zip')
+        with self._filesystem.open(zip_path) as f:
+            with zipfile.ZipFile(f) as z:
+                with z.open(f'{seg_id}.csv') as c:
+                    return c.read().decode('utf-8')
 
-  def _parse_csv_data(
-      self, csv_data: str
-  ) -> Mapping[Tuple[float, float, float], List[float]]:
-    """Parses CSV rows into mapping from node XYZ coord to embedding vector."""
-    embeddings_from_xyz = {}
-    for l in csv_data.split('\n'):
-      fields = l.split(',')
-      # node_id = int(fields[0])  # This is not currently useful for much.
-      xyz = tuple(float(f) for f in fields[1:4])
-      embedding = [float(f) for f in fields[4:]]
-      assert xyz not in embeddings_from_xyz
-      embeddings_from_xyz[xyz] = embedding
+    def _parse_csv_data(
+        self, csv_data: str
+    ) -> Mapping[Tuple[float, float, float], List[float]]:
+        """Parses CSV rows into mapping from node XYZ coord to embedding vector."""
+        embeddings_from_xyz = {}
+        for line in csv_data.split('\n'):
+            if not line.strip():  # Skip blank rows (e.g. trailing newline).
+                continue
+            fields = line.split(',')
+            # node_id = int(fields[0])  # This is not currently useful for much.
+            xyz = tuple(float(f) for f in fields[1:4])
+            embedding = [float(f) for f in fields[4:]]
+            assert xyz not in embeddings_from_xyz
+            embeddings_from_xyz[xyz] = embedding
 
-    return embeddings_from_xyz
+        return embeddings_from_xyz
 
-  def __getitem__(self, seg_id: int):
-    csv_data = self._get_csv_data(seg_id)
-    return self._parse_csv_data(csv_data)
+    def __getitem__(self, seg_id: int):
+        csv_data = self._get_csv_data(seg_id)
+        if self._return_csv:
+            return csv_data
+        return self._parse_csv_data(csv_data)
 
 
 # Unfortunately, this round of exports were accidentally run with bytewidth 64.
@@ -77,16 +82,16 @@ def __getitem__(self, seg_id: int):
 )
 
 
-def get_reader(key: str, filesystem, num_shards: int = 10_000):
-  """Convenience helper to get reader for given dataset key."""
-  if key in DATA_URL_FROM_KEY_BYTEWIDTH64:
-    url = DATA_URL_FROM_KEY_BYTEWIDTH64[key]
-    bytewidth = 64
-  else:
-    raise ValueError(f'Key not found: {key}')
+def get_reader(key: str, filesystem, num_shards: int = 10_000, return_csv: bool = False):
+    """Convenience helper to get reader for given dataset key."""
+    if key in DATA_URL_FROM_KEY_BYTEWIDTH64:
+        url = DATA_URL_FROM_KEY_BYTEWIDTH64[key]
+        bytewidth = 64
+    else:
+        raise ValueError(f'Key not found: {key}')
 
-  def sharder(segment_id: int) -> int:
-    return sharding.md5_shard(
-        segment_id, num_shards=num_shards, bytewidth=bytewidth)
+    def sharder(segment_id: int) -> int:
+        return sharding.md5_shard(
+            segment_id, num_shards=num_shards, bytewidth=bytewidth)
 
-  return EmbeddingReader(filesystem, url, sharder)
+    return EmbeddingReader(filesystem, url, sharder, return_csv=return_csv)