Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update reader.py #92

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 44 additions & 39 deletions connectomics/segclr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,44 @@


class EmbeddingReader:
    """Reader to load and parse embedding data from sharded ZIP archives.

    Each segment's embeddings live in a CSV member named ``<seg_id>.csv``
    inside the archive ``<zipdir>/<shard>.zip``, where ``shard`` is whatever
    ``sharder(seg_id)`` returns.

    Indexing the reader (``reader[seg_id]``) returns the parsed mapping from
    node XYZ coordinate to embedding vector, or the raw CSV text when the
    reader was constructed with ``return_csv=True``.
    """

    def __init__(self, filesystem, zipdir: str, sharder, return_csv=False):
        """Stores the collaborators needed to locate and read a segment.

        Args:
          filesystem: Object providing an ``open(path)`` method whose result
            is usable as a context manager and readable by ``zipfile``.
          zipdir: Directory (or URL prefix) containing the ``<shard>.zip``
            archives.
          sharder: Callable mapping a segment ID to its shard identifier.
          return_csv: If true, ``__getitem__`` returns the raw CSV text
            instead of the parsed mapping.
        """
        self._filesystem = filesystem  # Provides open(path) method.
        self._zipdir = zipdir
        self._sharder = sharder
        # Optionally hand back the raw CSV text rather than parsing it.
        self._return_csv = return_csv

    def _get_csv_data(self, seg_id: int) -> str:
        """Returns the UTF-8 decoded contents of the CSV for `seg_id`."""
        # Function-scope import: the file's top-level import block is not
        # touched by this change.
        import logging

        log = logging.getLogger(__name__)
        shard = self._sharder(seg_id)
        zip_path = os.path.join(self._zipdir, f'{shard}.zip')
        csv_name = f'{seg_id}.csv'
        # Path diagnostics are emitted at DEBUG level so that ordinary reads
        # do not write to stdout.
        log.debug('zip_path %s', zip_path)
        with self._filesystem.open(zip_path) as f:
            with zipfile.ZipFile(f) as z:
                with z.open(csv_name) as c:
                    log.debug('csv_path %s', os.path.join(zip_path, csv_name))
                    return c.read().decode('utf-8')

    def _parse_csv_data(
        self, csv_data: str
    ) -> Mapping[Tuple[float, float, float], List[float]]:
        """Parses CSV rows into mapping from node XYZ coord to embedding vector.

        Each non-blank row is split on commas: field 0 is the node ID (ignored),
        fields 1-3 are the XYZ coordinate, and the remaining fields are the
        embedding values.

        Args:
          csv_data: CSV text, one node per line.

        Returns:
          Dict keyed by ``(x, y, z)`` float tuples with the embedding as a
          list of floats.

        Raises:
          ValueError: If the same XYZ coordinate appears on more than one row,
            or if a field cannot be converted to float.
        """
        embeddings_from_xyz = {}
        for line in csv_data.split('\n'):
            # A trailing newline (or empty input) yields blank lines; without
            # this guard they would be stored under the empty-tuple key.
            if not line.strip():
                continue
            fields = line.split(',')
            # fields[0] is the node ID, which is not currently useful for much.
            xyz = tuple(float(f) for f in fields[1:4])
            if xyz in embeddings_from_xyz:
                # Explicit raise instead of assert so the check survives `-O`.
                raise ValueError(f'Duplicate node coordinate: {xyz}')
            embeddings_from_xyz[xyz] = [float(f) for f in fields[4:]]

        return embeddings_from_xyz

    def __getitem__(self, seg_id: int):
        """Returns embeddings for `seg_id`, parsed unless raw CSV was requested."""
        csv_data = self._get_csv_data(seg_id)
        if self._return_csv:
            return csv_data
        return self._parse_csv_data(csv_data)


# Unfortunately, this round of exports was accidentally run with bytewidth 64.
Expand All @@ -77,16 +82,16 @@ def __getitem__(self, seg_id: int):
)


def get_reader(
    key: str,
    filesystem,
    num_shards: int = 10_000,
    return_csv: bool = False,
):
    """Convenience helper to get reader for given dataset key.

    Args:
      key: Dataset key; must be present in `DATA_URL_FROM_KEY_BYTEWIDTH64`.
      filesystem: Passed through to `EmbeddingReader`.
      num_shards: Number of shards passed to `sharding.md5_shard`.
      return_csv: Passed through to `EmbeddingReader`; when true the reader
        is constructed to return raw CSV text.

    Returns:
      An `EmbeddingReader` for the URL registered under `key`.

    Raises:
      ValueError: If `key` is not a known dataset key.
    """
    if key not in DATA_URL_FROM_KEY_BYTEWIDTH64:
        raise ValueError(f'Key not found: {key}')

    url = DATA_URL_FROM_KEY_BYTEWIDTH64[key]
    # Every key in this table is sharded with bytewidth 64.
    bytewidth = 64

    def sharder(segment_id: int) -> int:
        """Maps a segment ID to its shard via `sharding.md5_shard`."""
        return sharding.md5_shard(
            segment_id, num_shards=num_shards, bytewidth=bytewidth)

    return EmbeddingReader(filesystem, url, sharder, return_csv=return_csv)