allow loading from pre-computed graphs, and add script to preprocess dataset
jongyaoY committed Aug 29, 2024
1 parent 26722a1 commit f9eb5c1
Showing 6 changed files with 314 additions and 94 deletions.
Binary file added docs/img/collisions.png
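
The directory-loading path added in this commit expects one pickled Graph dict per training sample, with file names ending in ".pkl" and carrying the sample index after an underscore (e.g. graph_0.pkl, graph_1.pkl), since the loader sorts on int(Path(f).stem.split("_")[1]). The preprocessing script itself is not among the files shown here; the sketch below is only an illustration of what it could look like, derived from the loader code in this diff. The class name MujocoDataset, the mode string "sample", the config values, and the paths are assumptions, not taken from the repository.

import os
import pickle

from fignet.data_loader import MujocoDataset  # hypothetical class name

raw_path = "data/train.npz"    # raw trajectory file (assumed location)
out_dir = "data/train_graphs"  # directory the loader will read later
os.makedirs(out_dir, exist_ok=True)

dataset = MujocoDataset(
    path=raw_path,
    input_sequence_length=3,              # assumed value
    mode="sample",                        # assumed mode string
    config={"connectivity_radius": 0.01,  # example values only
            "noise_std": 0.0},
)

# Each sample is a Graph built on the fly from the raw trajectories;
# dump its plain-dict form so the loader can pickle.load it again.
for i in range(len(dataset)):
    graph = dataset[i]
    with open(os.path.join(out_dir, f"graph_{i}.pkl"), "wb") as f:
        pickle.dump(graph.to_dict(), f)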
186 changes: 113 additions & 73 deletions fignet/data_loader.py
@@ -21,7 +21,10 @@
# SOFTWARE.


import os
import pickle
from dataclasses import fields
from pathlib import Path
from typing import List

import numpy as np
@@ -32,8 +35,6 @@
from fignet.types import EdgeType, Graph, NodeType
from fignet.utils import dataclass_to_tensor, dict_to_tensor

# from torch.utils.data import DistributedSampler


def collate_fn(batch: List[Graph]):
"""Merge batch of graphs into one graph"""
@@ -127,27 +128,46 @@ def __init__(
config: dict = None,
transform=None,
):
# If raw data is given, need to calculate graph connectivity on the fly
if os.path.isfile(path):
self._load_raw_data = True

self._data = list(np.load(path, allow_pickle=True).values())[0]
self._dimension = self._data[0]["pos"].shape[-1]
self._target_length = 1
self._input_sequence_length = input_sequence_length

self._data_lengths = [
x["pos"].shape[0] - input_sequence_length - self._target_length
for x in self._data
]
self._length = sum(self._data_lengths)

self._data = list(np.load(path, allow_pickle=True).values())[0]
self._dimension = self._data[0]["pos"].shape[-1]
self._target_length = 1
self._input_sequence_length = input_sequence_length

self._data_lengths = [
x["pos"].shape[0] - input_sequence_length - self._target_length
for x in self._data
]
self._length = sum(self._data_lengths)
# pre-compute cumulative lengths
# to allow fast indexing in __getitem__
self._precompute_cumlengths = [
sum(self._data_lengths[:x])
for x in range(1, len(self._data_lengths) + 1)
]
self._precompute_cumlengths = np.array(
self._precompute_cumlengths, dtype=int
)
# Directly load pre-calculated graphs and save time
elif os.path.isdir(path):
self._load_raw_data = False

self._graph_ext = "pkl"
self._file_list = [
os.path.join(path, f)
for f in os.listdir(path)
if os.path.isfile(os.path.join(path, f))
and f.endswith(self._graph_ext)
]
self._file_list.sort(key=lambda f: int(Path(f).stem.split("_")[1]))
self._length = len(self._file_list)
else:
raise FileNotFoundError(f"{path} not found")

# pre-compute cumulative lengths
# to allow fast indexing in __getitem__
self._precompute_cumlengths = [
sum(self._data_lengths[:x])
for x in range(1, len(self._data_lengths) + 1)
]
self._precompute_cumlengths = np.array(
self._precompute_cumlengths, dtype=int
)
self._transform = transform
self._mode = mode
if config is not None:
@@ -164,61 +184,81 @@ def __getitem__(self, idx):
elif self._mode == "trajectory":
return self._get_trajectory(idx)

def _get_sample(self, idx):
"""Sample one step"""
trajectory_idx = np.searchsorted(
self._precompute_cumlengths - 1, idx, side="left"
)
# Compute index of pick along time-dimension of trajectory.
start_of_selected_trajectory = (
self._precompute_cumlengths[trajectory_idx - 1]
if trajectory_idx != 0
else 0
)
time_idx = self._input_sequence_length + (
idx - start_of_selected_trajectory
)

start = time_idx - self._input_sequence_length
end = time_idx
obj_ids = dict(self._data[trajectory_idx]["obj_id"].item())
positions = self._data[trajectory_idx]["pos"][
start:end
] # (seq_len, n_obj, 3) input sequence
quats = self._data[trajectory_idx]["quat"][
start:end
] # (seq_len, n_obj, 4) input sequence
target_positions = self._data[trajectory_idx]["pos"][time_idx]
target_quats = self._data[trajectory_idx]["quat"][time_idx]
poses = np.concatenate([positions, quats], axis=-1)
target_poses = np.concatenate(
[target_positions, target_quats], axis=-1
)

scene_config = dict(self._data[trajectory_idx]["meta_data"].item())

connectivity_radius = self._config.get("connectivity_radius")
if connectivity_radius is not None:
scene_config.update({"connectivity_radius": connectivity_radius})
def _load_graph(self, graph_file):
try:
with open(graph_file, "rb") as f:
sample_dict = pickle.load(f)
graph = Graph()
graph.from_dict(sample_dict)
return graph

noise_std = self._config.get("noise_std")
if noise_std is not None:
scene_config.update({"noise_std": noise_std})
except FileNotFoundError as e:
print(e)
return None

scn = Scene(scene_config)
scn.synchronize_states(
obj_poses=poses,
obj_ids=obj_ids,
)
graph = scn.to_graph(
target_poses=target_poses,
obj_ids=obj_ids,
noise=True,
)
def _get_sample(self, idx):
"""Sample one step"""
if self._load_raw_data:
trajectory_idx = np.searchsorted(
self._precompute_cumlengths - 1, idx, side="left"
)
# Compute index of pick along time-dimension of trajectory.
start_of_selected_trajectory = (
self._precompute_cumlengths[trajectory_idx - 1]
if trajectory_idx != 0
else 0
)
time_idx = self._input_sequence_length + (
idx - start_of_selected_trajectory
)

start = time_idx - self._input_sequence_length
end = time_idx
obj_ids = dict(self._data[trajectory_idx]["obj_id"].item())
positions = self._data[trajectory_idx]["pos"][
start:end
] # (seq_len, n_obj, 3) input sequence
quats = self._data[trajectory_idx]["quat"][
start:end
] # (seq_len, n_obj, 4) input sequence
target_positions = self._data[trajectory_idx]["pos"][time_idx]
target_quats = self._data[trajectory_idx]["quat"][time_idx]
poses = np.concatenate([positions, quats], axis=-1)
target_poses = np.concatenate(
[target_positions, target_quats], axis=-1
)

scene_config = dict(self._data[trajectory_idx]["meta_data"].item())

connectivity_radius = self._config.get("connectivity_radius")
if connectivity_radius is not None:
scene_config.update(
{"connectivity_radius": connectivity_radius}
)

if self._transform is not None:
graph = self._transform(graph)
return graph
noise_std = self._config.get("noise_std")
if noise_std is not None:
scene_config.update({"noise_std": noise_std})

scn = Scene(scene_config)
scn.synchronize_states(
obj_poses=poses,
obj_ids=obj_ids,
)
graph = scn.to_graph(
target_poses=target_poses,
obj_ids=obj_ids,
noise=True,
)

if self._transform is not None:
graph = self._transform(graph)
return graph
else:
if os.path.exists(self._file_list[idx]):
return self._load_graph(self._file_list[idx])
else:
raise FileNotFoundError

def _get_trajectory(self, idx):
"""Sample continuous steps"""
63 changes: 59 additions & 4 deletions fignet/types.py
@@ -21,33 +21,70 @@
# SOFTWARE.

import enum
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Dict, Union

import numpy as np
from dacite import from_dict
from torch import Tensor

TensorType = Union[np.ndarray, Tensor]


class MetaEnum(enum.EnumMeta):
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True


class KinematicType(enum.IntEnum):
STATIC = 0
DYNAMIC = 1
SIZE = 1


class NodeType(enum.Enum):
class NodeType(enum.Enum, metaclass=MetaEnum):
MESH = "mesh"
OBJECT = "object"


class EdgeType(enum.Enum):
class EdgeType(enum.Enum, metaclass=MetaEnum):
MESH_MESH = "mesh-mesh"
MESH_OBJ = "mesh-object"
OBJ_MESH = "object-mesh"
FACE_FACE = "face-face"


def key_to_string(key):
if isinstance(key, str):
return key
elif isinstance(key, enum.Enum):
return key.value
else:
raise ValueError(f"{type(key)} not supported")


def string_to_enum(k_str, enum_list):
for e in enum_list:
if k_str in e:
return e(k_str)


def to_dict(item):
d = {}
if isinstance(item, Dict):
for k, v in item.items():
d.update({key_to_string(k): to_dict(v)})
return d
elif is_dataclass(item):
return to_dict(asdict(item))
else:
return item


@dataclass
class NodeFeature:
position: TensorType
@@ -68,4 +105,22 @@ class Graph:
edge_sets: Dict[EdgeType, Edge] = field(default_factory=lambda: {})

def to_dict(self):
return asdict(self)
return to_dict(self)

def from_dict(self, d):
for f in fields(self):
field_name = f.name
if field_name not in d:
continue
if field_name == "node_sets":
d_cls = NodeFeature
elif field_name == "edge_sets":
d_cls = Edge
for k, v in d[field_name].items():
getattr(self, field_name).update(
{
string_to_enum(k, [NodeType, EdgeType]): from_dict(
d_cls, v
)
}
)
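
Together, key_to_string/string_to_enum and the MetaEnum metaclass let Graph.to_dict() turn the enum-keyed node and edge sets into a plain, pickle-friendly dict of strings, and Graph.from_dict() rebuild the typed dataclasses (via dacite.from_dict) when a pre-computed graph is read back. A small usage sketch, assuming fignet is importable:

from fignet.types import EdgeType, NodeType, key_to_string, string_to_enum

# MetaEnum.__contains__ makes membership tests against raw strings work.
assert "mesh" in NodeType
assert "face-face" in EdgeType
assert "banana" not in NodeType

# Enum keys become plain strings when a graph is serialized ...
assert key_to_string(NodeType.MESH) == "mesh"
assert key_to_string(EdgeType.MESH_OBJ) == "mesh-object"

# ... and are mapped back to the right enum when it is loaded again.
assert string_to_enum("mesh", [NodeType, EdgeType]) is NodeType.MESH
assert string_to_enum("mesh-object", [NodeType, EdgeType]) is EdgeType.MESH_OBJ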