-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathproject_path.py
73 lines (62 loc) · 2.6 KB
/
project_path.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import sys
from pathlib import Path
from typing import List, Optional, Tuple
class ProjectPath:
    """
    Build filesystem paths relative to an automatically detected project root.

    The project root is the nearest directory, walking upward from a start
    directory, that contains at least one of the configured markers (files
    or directories such as ``.git`` or ``pyproject.toml``).

    Instances implement ``__fspath__``, so they can be passed directly to
    ``open()`` and to ``os`` / ``pathlib`` functions that accept path-like
    objects.
    """

    # Markers used when the caller does not supply any.
    DEFAULT_MARKERS = ('.git', '.hg', '.svn', 'pyproject.toml', 'setup.py', '.project')

    # Maps (resolved start directory as str, markers tuple) -> discovered root.
    # Shared by all instances; emptied by invalidate_cache().
    _project_root_cache = {}

    @classmethod
    def get_start_path(cls):
        """
        Return the directory from which the project-root search begins.

        :return: The directory containing this module when ``__file__`` is
            defined, otherwise the current working directory (e.g. in an
            interactive session or a notebook kernel).
        """
        # NOTE: ``__file__`` here is this module's own file, so the search
        # starts next to this module rather than next to the calling code.
        if '__file__' in globals():
            return Path(__file__).resolve().parent
        return Path(os.getcwd())

    @classmethod
    def find_project_root(cls, start_path: Path, markers: Tuple[str, ...]) -> Path:
        """
        Cache-enabled method to find the project root by looking for directory markers.

        :param start_path: Path to start search from. Relative paths are
            resolved against the current working directory before searching.
        :param markers: Markers to look for in directories. Any iterable of
            names is accepted; it is converted to a tuple internally.
        :return: Path object representing the project root directory.
        :raises FileNotFoundError: If no ancestor of ``start_path`` (including
            ``start_path`` itself) contains any of the markers.
        """
        # A tuple is hashable (usable in the cache key) even if a list was passed.
        markers = tuple(markers)
        # Resolve first: a relative path such as Path('.') is its own parent,
        # which would otherwise end the upward walk before it begins.
        resolved_start = Path(start_path).resolve()

        cache_key = (str(resolved_start), markers)
        cached_root = cls._project_root_cache.get(cache_key)
        if cached_root is not None:
            return cached_root

        # Check the start directory first, then each ancestor up to the
        # filesystem root.
        for candidate in (resolved_start, *resolved_start.parents):
            if any((candidate / marker).exists() for marker in markers):
                cls._project_root_cache[cache_key] = candidate
                return candidate

        raise FileNotFoundError(f"Project root not found using the provided markers: {markers}.")

    @classmethod
    def invalidate_cache(cls):
        """Forget every cached project root (shared across all instances)."""
        cls._project_root_cache.clear()

    def __init__(self, *path_parts, markers: Optional[List[str]] = None):
        """
        Locate the project root and remember the path parts below it.

        :param path_parts: Path components to append to the project root.
        :param markers: Names whose presence identifies the project root.
            Defaults to ``DEFAULT_MARKERS``.
        :raises FileNotFoundError: If no project root can be found.
        """
        if markers is None:
            markers = self.DEFAULT_MARKERS
        # Copy so that add_marker() never mutates the caller's list.
        self.markers = list(markers)
        self.path_parts = path_parts
        self.project_root = self._resolve_root()

    def _resolve_root(self) -> Path:
        """Find the project root for the current markers."""
        return self.find_project_root(self.get_start_path(), tuple(self.markers))

    def add_marker(self, marker: str):
        """
        Add a marker and re-resolve the project root with the new marker set.

        Does nothing if the marker is already present.

        :param marker: File or directory name that identifies a project root.
        """
        if marker in self.markers:
            return
        self.markers.append(marker)
        # Invalidate cache as the markers have changed
        self.invalidate_cache()
        # Without this the new marker would have no effect on this instance.
        self.project_root = self._resolve_root()

    def __repr__(self) -> str:
        parts = ', '.join(repr(part) for part in self.path_parts)
        separator = ', ' if parts else ''
        return f"{type(self).__name__}({parts}{separator}markers={self.markers!r})"

    def __str__(self) -> str:
        return str(self.path())

    def __fspath__(self) -> str:
        return str(self)

    def path(self) -> Path:
        """
        Return the Path object for the constructed path.
        """
        return self.project_root.joinpath(*self.path_parts)
# Example usage: build <project root>/data/mydata.csv.
# NOTE(review): these statements are at module level, so they also run
# whenever this module is imported.
project_file = ProjectPath(
    "data",
    "mydata.csv",
    markers=[".git", "my_project_marker.file"],
)
print(str(project_file))  # The path rendered as a string
print(project_file.path())  # The same location as a pathlib.Path object