-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import re | ||
import statistics | ||
import subprocess | ||
|
||
|
||
def measure_import_times(n_runs=30, module="haystack"): | ||
""" | ||
Measure the time it takes to import a module. | ||
""" | ||
user_times = [] | ||
sys_times = [] | ||
|
||
print(f"Running {n_runs} measurements...") | ||
|
||
for i in range(n_runs): | ||
# Run the import command and capture output | ||
result = subprocess.run(["time", "python", "-c", f"import {module}"], capture_output=True, text=True) | ||
|
||
# Check both stdout and stderr | ||
time_output = result.stderr | ||
|
||
# Extract times using regex - matches patterns like "3.21user 0.17system" | ||
time_pattern = r"([\d.]+)user\s+([\d.]+)system" | ||
match = re.search(time_pattern, time_output) | ||
|
||
if match: | ||
user_time = float(match.group(1)) | ||
sys_time = float(match.group(2)) | ||
|
||
user_times.append(user_time) | ||
sys_times.append(sys_time) | ||
|
||
# print(user_times) | ||
|
||
if (i + 1) % 10 == 0: | ||
print(f"Completed {i + 1} runs...") | ||
|
||
# Calculate statistics | ||
avg_user = statistics.mean(user_times) | ||
avg_sys = statistics.mean(sys_times) | ||
avg_total = avg_user + avg_sys | ||
|
||
# Calculate standard deviations | ||
std_user = statistics.stdev(user_times) | ||
std_sys = statistics.stdev(sys_times) | ||
|
||
print("\nResults:") | ||
print(f"Average user time: {avg_user:.3f}s ± {std_user:.3f}s") | ||
print(f"Average sys time: {avg_sys:.3f}s ± {std_sys:.3f}s") | ||
print(f"Average total (user + sys): {avg_total:.3f}s") | ||
|
||
|
||
if __name__ == "__main__": | ||
measure_import_times() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import importlib.abc | ||
import importlib.util | ||
import sys | ||
import types | ||
from pathlib import Path | ||
|
||
|
||
class ImportTracker(importlib.abc.MetaPathFinder): | ||
def find_spec(self, fullname, path, target=None): | ||
""" | ||
If the module name contains "torch", print the full name and the stack trace. | ||
""" | ||
if "torch" in fullname: | ||
print(f"\nAttempting to import: {fullname}") | ||
import traceback | ||
|
||
for frame in traceback.extract_stack()[:-1]: # Exclude this frame | ||
if "haystack" in frame.filename: | ||
print(f" In Haystack file: {frame.filename}:{frame.lineno}") | ||
print(f" {frame.line}") | ||
|
||
|
||
# Install the import tracker | ||
sys.meta_path.insert(0, ImportTracker()) | ||
|
||
# Record modules before import | ||
print("Recording initial modules...") | ||
modules_before = set(sys.modules.keys()) | ||
|
||
# Import haystack | ||
print("Importing haystack...") | ||
import haystack | ||
|
||
# Find new modules after import | ||
print("Analyzing new modules...") | ||
modules_after = set(sys.modules.keys()) | ||
new_modules = modules_after - modules_before | ||
|
||
# Filter for haystack modules that imported torch | ||
haystack_importers = {} | ||
|
||
for name in new_modules: | ||
if name.startswith("haystack"): | ||
module = sys.modules[name] | ||
# Check if this module uses torch | ||
module_dict = getattr(module, "__dict__", {}) | ||
for value in module_dict.values(): | ||
if isinstance(value, types.ModuleType) and "torch" in value.__name__: | ||
if name not in haystack_importers: | ||
haystack_importers[name] = set() | ||
haystack_importers[name].add(value.__name__) | ||
|
||
if haystack_importers: | ||
print("\nFound haystack modules that imported torch:") | ||
for module_name, torch_modules in sorted(haystack_importers.items()): | ||
print(f"\n{module_name}:") | ||
for torch_module in sorted(torch_modules): | ||
print(f" - {torch_module}") | ||
else: | ||
print("\nNo haystack modules imported torch") |