Skip to content

Commit

Permalink
scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Feb 7, 2025
1 parent 14716f2 commit ce35151
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 0 deletions.
54 changes: 54 additions & 0 deletions measure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import re
import statistics
import subprocess


def measure_import_times(n_runs=30, module="haystack"):
    """
    Measure the CPU time it takes to import a module in a fresh interpreter.

    Each run executes ``time python -c "import <module>"`` as a subprocess and parses the
    ``<N>user <N>system`` figures from the subprocess's stderr (the format GNU ``time``
    prints by default). The mean and sample standard deviation of the user and system
    times are printed at the end. Runs whose import fails or whose output cannot be
    parsed are reported and left out of the statistics.

    :param n_runs: Number of fresh-interpreter imports to time.
    :param module: Dotted name of the module to import. The interpreter used is whatever
        ``python`` resolves to on PATH.
    :raises FileNotFoundError: If no ``time`` executable can be launched (propagated from
        ``subprocess.run``).
    """
    user_times = []
    sys_times = []

    # Compiled once instead of on every iteration - matches patterns like "3.21user 0.17system"
    time_pattern = re.compile(r"([\d.]+)user\s+([\d.]+)system")

    print(f"Running {n_runs} measurements...")

    for i in range(n_runs):
        # Run the import command and capture output
        result = subprocess.run(["time", "python", "-c", f"import {module}"], capture_output=True, text=True)

        # Only stderr is inspected: that is where the timing report is expected
        match = time_pattern.search(result.stderr)

        if result.returncode != 0:
            # Timing a failed import would silently skew the averages
            print(f"Run {i + 1}: importing {module} failed (exit code {result.returncode}), sample skipped")
        elif match is None:
            print(f"Run {i + 1}: could not find user/system times in the output, sample skipped")
        else:
            user_times.append(float(match.group(1)))
            sys_times.append(float(match.group(2)))

        if (i + 1) % 10 == 0:
            print(f"Completed {i + 1} runs...")

    # statistics.mean() raises on an empty sequence, so bail out before computing anything
    if not user_times:
        print("\nNo timing samples were collected; nothing to report.")
        return

    # Calculate statistics
    avg_user = statistics.mean(user_times)
    avg_sys = statistics.mean(sys_times)
    avg_total = avg_user + avg_sys

    # The sample standard deviation needs at least two samples; report 0.0 for a single one
    has_spread = len(user_times) > 1
    std_user = statistics.stdev(user_times) if has_spread else 0.0
    std_sys = statistics.stdev(sys_times) if has_spread else 0.0

    print("\nResults:")
    print(f"Average user time: {avg_user:.3f}s ± {std_user:.3f}s")
    print(f"Average sys time: {avg_sys:.3f}s ± {std_sys:.3f}s")
    print(f"Average total (user + sys): {avg_total:.3f}s")


# Run the measurement with its default arguments when this file is executed as a script.
if __name__ == "__main__":
    measure_import_times()
60 changes: 60 additions & 0 deletions torch_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import importlib.abc
import importlib.util
import sys
import types
from pathlib import Path


class ImportTracker(importlib.abc.MetaPathFinder):
    """
    Meta path finder used purely for reporting: it prints torch-related import requests
    and never supplies a module spec itself.
    """

    def find_spec(self, fullname, path, target=None):
        """
        Report an import request whose module name contains "torch".

        Prints the requested name, then the location and source line of every frame in the
        current call stack whose filename contains "haystack". Always returns None.
        """
        if "torch" not in fullname:
            return None

        print(f"\nAttempting to import: {fullname}")
        import traceback

        # The last stack entry is this method itself, so it is left out
        callers = traceback.extract_stack()[:-1]
        for entry in (c for c in callers if "haystack" in c.filename):
            print(f" In Haystack file: {entry.filename}:{entry.lineno}")
            print(f" {entry.line}")

        return None


# Register the tracker at the front of sys.meta_path
sys.meta_path.insert(0, ImportTracker())

# Snapshot the names of the modules loaded so far
print("Recording initial modules...")
modules_before = set(sys.modules)

# Import haystack
print("Importing haystack...")
import haystack

# Work out which modules appeared in sys.modules during the import
print("Analyzing new modules...")
modules_after = set(sys.modules)
new_modules = modules_after - modules_before

# Map each newly loaded haystack module to the torch modules bound in its namespace
haystack_importers = {}

for name in new_modules:
    if not name.startswith("haystack"):
        continue
    namespace = getattr(sys.modules[name], "__dict__", {})
    for value in namespace.values():
        # Only module objects whose name contains "torch" are of interest
        if not isinstance(value, types.ModuleType):
            continue
        if "torch" not in value.__name__:
            continue
        haystack_importers.setdefault(name, set()).add(value.__name__)

if not haystack_importers:
    print("\nNo haystack modules imported torch")
else:
    print("\nFound haystack modules that imported torch:")
    for module_name in sorted(haystack_importers):
        print(f"\n{module_name}:")
        for torch_module in sorted(haystack_importers[module_name]):
            print(f" - {torch_module}")

0 comments on commit ce35151

Please sign in to comment.