Skip to content

Commit

Permalink
check test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
SumGuo-88 committed Jan 6, 2025
1 parent 78b2a10 commit 6a5d169
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions source/tests/pt/test_make_stat_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def count_non_zero_elements(self, tensor, threshold=1e-8):
return torch.sum(torch.abs(tensor) > threshold).item()

def test_make_stat_input(self):
#3 frames would be count
lst = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
Expand All @@ -78,7 +79,7 @@ def test_make_stat_input(self):
)
bias, _ = compute_output_stats(lst, ntypes=57)
energy = bias.get("energy")
self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.")
print(energy)
non_zero_count = self.count_non_zero_elements(energy)
self.assertEqual(
non_zero_count,
Expand All @@ -87,6 +88,9 @@ def test_make_stat_input(self):
)

def test_make_stat_input_nocomplete(self):
#missing element:13,31,37
#only one frame would be count

lst = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
Expand All @@ -96,14 +100,50 @@ def test_make_stat_input_nocomplete(self):
)
bias, _ = compute_output_stats(lst, ntypes=57)
energy = bias.get("energy")
self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.")
print(energy)
non_zero_count = self.count_non_zero_elements(energy)
self.assertLess(
non_zero_count,
self.real_ntypes,
f"Expected fewer than {self.real_ntypes} non-zero elements, but got {non_zero_count}.",
)

def test_bias(self):
lst_ori = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
nbatches=1,
min_frames_per_element_forstat=1,
enable_element_completion=False,
)
lst_all = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
nbatches=1,
min_frames_per_element_forstat=1,
enable_element_completion=True,
)
bias_ori, _ = compute_output_stats(lst_ori, ntypes=57)
bias_all, _ = compute_output_stats(lst_all, ntypes=57)
energy_ori = np.array(bias_ori.get("energy").cpu()).flatten()
energy_all = np.array(bias_all.get("energy").cpu()).flatten()

for i, (e_ori, e_all) in enumerate(zip(energy_ori, energy_all)):
if e_all == 0:
self.assertEqual(
e_ori,
0,
f"Index {i}: energy_all=0, but energy_ori={e_ori}"
)
else:
if e_ori != 0:
diff = abs(e_ori - e_all)
rel_diff = diff / abs(e_ori)
self.assertTrue(
rel_diff < 0.4,
f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, "
f"relative difference {rel_diff:.2%} is too large"
)

if __name__ == "__main__":
unittest.main()

0 comments on commit 6a5d169

Please sign in to comment.