Skip to content

Commit

Permalink
Added GitRef support in init_run (#1292)
Browse files Browse the repository at this point in the history
Co-authored-by: Sabine <[email protected]>
  • Loading branch information
Raalsky and normandy7 authored Mar 15, 2023
1 parent ee723e2 commit 54801d2
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 63 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## [UNRELEASED] neptune 1.1.0

### Features
- Added ability to provide repository path with `GitRef` to `init_run` ([#1292](https://github.com/neptune-ai/neptune-client/pull/1292))
- Added `SupportsNamespaces` interface in `neptune.typing` for proper type annotations of Handler and Neptune objects ([#1280](https://github.com/neptune-ai/neptune-client/pull/1280))
- `Run`, `Model`, `ModelVersion` and `Project` could be created with constructor in addition to `init_*` functions ([#1246](https://github.com/neptune-ai/neptune-client/pull/1246))

Expand Down
26 changes: 13 additions & 13 deletions src/neptune/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@
from neptune.internal.operation_processors.operation_storage import OperationStorage
from neptune.internal.utils import base64_decode
from neptune.internal.utils.generic_attribute_mapper import map_attribute_result_to_value
from neptune.internal.utils.git import GitInfo
from neptune.internal.utils.paths import path_to_str
from neptune.internal.websockets.websockets_factory import WebsocketsFactory
from neptune.management.exceptions import ObjectNotFound
from neptune.types.atoms import GitRef
from neptune.version import version as neptune_client_version

if TYPE_CHECKING:
Expand Down Expand Up @@ -316,31 +316,31 @@ def get_metadata_container(
def create_run(
self,
project_id: UniqueId,
git_ref: Optional[GitRef] = None,
git_info: Optional[GitInfo] = None,
custom_run_id: Optional[str] = None,
notebook_id: Optional[str] = None,
checkpoint_id: Optional[str] = None,
) -> ApiExperiment:

git_info = (
git_info_serialized = (
{
"commit": {
"commitId": git_ref.commit_id,
"message": git_ref.message,
"authorName": git_ref.author_name,
"authorEmail": git_ref.author_email,
"commitDate": git_ref.commit_date,
"commitId": git_info.commit_id,
"message": git_info.message,
"authorName": git_info.author_name,
"authorEmail": git_info.author_email,
"commitDate": git_info.commit_date,
},
"repositoryDirty": git_ref.dirty,
"currentBranch": git_ref.branch,
"remotes": git_ref.remotes,
"repositoryDirty": git_info.dirty,
"currentBranch": git_info.branch,
"remotes": git_info.remotes,
}
if git_ref
if git_info
else None
)

additional_params = {
"gitInfo": git_info,
"gitInfo": git_info_serialized,
"customId": custom_run_id,
}

Expand Down
4 changes: 2 additions & 2 deletions src/neptune/internal/backends/neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
)
from neptune.internal.operation import Operation
from neptune.internal.operation_processors.operation_storage import OperationStorage
from neptune.internal.utils.git import GitInfo
from neptune.internal.websockets.websockets_factory import WebsocketsFactory
from neptune.types.atoms import GitRef


class NeptuneBackend:
Expand Down Expand Up @@ -95,7 +95,7 @@ def get_available_workspaces(self) -> List[Workspace]:
def create_run(
self,
project_id: UniqueId,
git_ref: Optional[GitRef] = None,
git_info: Optional[GitInfo] = None,
custom_run_id: Optional[str] = None,
notebook_id: Optional[str] = None,
checkpoint_id: Optional[str] = None,
Expand Down
7 changes: 3 additions & 4 deletions src/neptune/internal/backends/neptune_backend_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
from neptune.internal.types.file_types import FileType
from neptune.internal.utils import base64_decode
from neptune.internal.utils.generic_attribute_mapper import NoValue
from neptune.internal.utils.git import GitInfo
from neptune.internal.utils.paths import path_to_str
from neptune.types import (
Boolean,
Expand Down Expand Up @@ -191,17 +192,15 @@ def _get_container(self, container_id: UniqueId, container_type: ContainerType):
def create_run(
self,
project_id: UniqueId,
git_ref: Optional[GitRef] = None,
git_info: Optional[GitInfo] = None,
custom_run_id: Optional[str] = None,
notebook_id: Optional[str] = None,
checkpoint_id: Optional[str] = None,
) -> ApiExperiment:
sys_id = SysId(f"{self.PROJECT_KEY}-{self._next_run}")
self._next_run += 1
new_run_id = UniqueId(str(uuid.uuid4()))
container = self._create_container(new_run_id, ContainerType.RUN, sys_id=sys_id)
if git_ref:
container.set(["source_code", "git"], git_ref)
self._create_container(new_run_id, ContainerType.RUN, sys_id=sys_id)
return ApiExperiment(
id=new_run_id,
type=ContainerType.RUN,
Expand Down
3 changes: 2 additions & 1 deletion src/neptune/internal/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = ["ANONYMOUS_API_TOKEN_CONTENT"]
__all__ = ["ANONYMOUS_API_TOKEN_CONTENT", "DO_NOT_TRACK_GIT_REPOSITORY"]

ANONYMOUS_API_TOKEN_CONTENT = (
"eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS"
"5haSIsImFwaV9rZXkiOiJiNzA2YmM4Zi03NmY5LTRjMmUtOTM5ZC00YmEwMzZmOTMyZTQifQo="
)
DO_NOT_TRACK_GIT_REPOSITORY = "DO_NOT_TRACK_GIT_REPOSITORY"
55 changes: 34 additions & 21 deletions src/neptune/internal/utils/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["get_git_info", "discover_git_repo_location"]
__all__ = ["to_git_info", "GitInfo"]

import logging
import os
import warnings
from typing import Optional

from neptune.types.atoms import GitRef
from neptune.vendor.lib_programname import get_path_executed_script
from dataclasses import dataclass
from datetime import datetime
from typing import (
List,
Optional,
Union,
)

from neptune.types.atoms.git_ref import (
GitRef,
GitRefDisabled,
)

_logger = logging.getLogger(__name__)


@dataclass
class GitInfo:
commit_id: str
message: str
author_name: str
author_email: str
commit_date: datetime
dirty: bool
branch: Optional[str]
remotes: Optional[List[str]]


def get_git_repo(repo_path):
# WARN: GitPython asserts the existence of `git` executable
# which consists in failure during the preparation of conda package
Expand All @@ -37,10 +56,16 @@ def get_git_repo(repo_path):
warnings.warn("GitPython could not be initialized")


def get_git_info(repo_path=None):
def to_git_info(git_ref: Union[GitRef, GitRefDisabled]) -> Optional[GitInfo]:
try:
repo = get_git_repo(repo_path)
if git_ref == GitRef.DISABLED:
return None

initial_repo_path = git_ref.resolve_path()
if initial_repo_path is None:
return None

repo = get_git_repo(repo_path=initial_repo_path)
commit = repo.head.commit

active_branch = ""
Expand All @@ -53,7 +78,7 @@ def get_git_info(repo_path=None):

remote_urls = [remote.url for remote in repo.remotes]

return GitRef(
return GitInfo(
commit_id=commit.hexsha,
message=commit.message,
author_name=commit.author.name,
Expand All @@ -65,15 +90,3 @@ def get_git_info(repo_path=None):
)
except: # noqa: E722
return None


def get_git_repo_path(initial_path: str) -> Optional[str]:
try:
return get_git_repo(initial_path).git_dir
except: # noqa: E722
pass


def discover_git_repo_location() -> Optional[str]:
potential_initial_path = os.path.dirname(os.path.abspath(get_path_executed_script()))
return get_git_repo_path(initial_path=potential_initial_path)
28 changes: 20 additions & 8 deletions src/neptune/metadata_containers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,7 @@
verify_collection_type,
verify_type,
)
from neptune.internal.utils.git import (
discover_git_repo_location,
get_git_info,
)
from neptune.internal.utils.git import to_git_info
from neptune.internal.utils.hashing import generate_hash
from neptune.internal.utils.limits import custom_run_id_exceeds_length
from neptune.internal.utils.ping_background_job import PingBackgroundJob
Expand All @@ -87,8 +84,12 @@
from neptune.internal.utils.traceback_job import TracebackJob
from neptune.internal.websockets.websocket_signals_background_job import WebsocketSignalsBackgroundJob
from neptune.metadata_containers import MetadataContainer
from neptune.types import (
GitRef,
StringSeries,
)
from neptune.types.atoms.git_ref import GitRefDisabled
from neptune.types.mode import Mode
from neptune.types.series.string_series import StringSeries


class Run(MetadataContainer):
Expand Down Expand Up @@ -178,6 +179,7 @@ def __init__(
flush_period: float = DEFAULT_FLUSH_PERIOD,
proxies: Optional[dict] = None,
capture_traceback: bool = True,
git_ref: Optional[Union[GitRef, GitRefDisabled]] = None,
**kwargs,
):
"""Starts a new tracked run and adds it to the top of the runs table.
Expand Down Expand Up @@ -234,6 +236,10 @@ def __init__(
Defaults to True.
The tracked metadata is stored in the '<monitoring_namespace>/traceback' namespace (see the
'monitoring_namespace' parameter).
git_ref: GitRef object containing information about the Git repository path.
If None, Neptune looks for a repository in the path of the script that is executed.
To specify a different location, set to GitRef(repository_path="path/to/repo").
To turn off Git tracking for the run, set to GitRef.DISABLED.
Returns:
Run object that is used to manage the tracked run and log metadata to it.
Expand All @@ -248,11 +254,12 @@ def __init__(
... # (creates a run in the project specified by the NEPTUNE_PROJECT environment variable)
... run = neptune.init_run()
>>> # Create a tracked run with a name and description, and no sources files uploaded
>>> # Create a run with a name and description, with no sources files or Git info tracked:
>>> run = neptune.init_run(
... name="neural-net-mnist",
... description="neural net trained on MNIST",
... source_files=[],
... git_ref=GitRef.DISABLED,
... )
>>> # Log all .py files from all subdirectories, excluding hidden files
Expand All @@ -270,6 +277,7 @@ def __init__(
... source_files=["training_with_pytorch.py", "net.py"],
... monitoring_namespace="system_metrics",
... capture_stderr=False,
... git_ref=GitRef(repository_path="/Users/Jackie/repos/cls_project"),
... )
Connecting to an existing run:
Expand Down Expand Up @@ -299,6 +307,8 @@ def __init__(
verify_type("fail_on_exception", fail_on_exception, bool)
verify_type("monitoring_namespace", monitoring_namespace, (str, type(None)))
verify_type("capture_traceback", capture_traceback, bool)
verify_type("capture_traceback", capture_traceback, bool)
verify_type("git_ref", git_ref, (GitRef, str, type(None)))
if tags is not None:
if isinstance(tags, str):
tags = [tags]
Expand All @@ -321,6 +331,7 @@ def __init__(
self._source_files: Optional[List[str]] = source_files
self._fail_on_exception: bool = fail_on_exception
self._capture_traceback: bool = capture_traceback
self._git_ref: Optional[GitRef, GitRefDisabled] = git_ref

self._monitoring_namespace: str = (
monitoring_namespace
Expand Down Expand Up @@ -367,7 +378,8 @@ def _get_or_create_api_object(self) -> ApiExperiment:
if self._mode == Mode.READ_ONLY:
raise NeedExistingRunForReadOnlyMode()

git_ref = get_git_info(discover_git_repo_location())
git_ref = self._git_ref or GitRef()
git_info = to_git_info(git_ref=git_ref)

custom_run_id = self._custom_run_id
if custom_run_id_exceeds_length(self._custom_run_id):
Expand All @@ -377,7 +389,7 @@ def _get_or_create_api_object(self) -> ApiExperiment:

return self._backend.create_run(
project_id=self._project_api_object.id,
git_ref=git_ref,
git_info=git_info,
custom_run_id=custom_run_id,
notebook_id=notebook_id,
checkpoint_id=checkpoint_id,
Expand Down
41 changes: 27 additions & 14 deletions src/neptune/types/atoms/git_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["GitRef"]
__all__ = ["GitRef", "GitRefDisabled"]

from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import (
TYPE_CHECKING,
List,
NewType,
Optional,
TypeVar,
Union,
)

from neptune.types.atoms.atom import Atom
from neptune.vendor.lib_programname import get_path_executed_script

if TYPE_CHECKING:
from neptune.types.value_visitor import ValueVisitor

Ret = TypeVar("Ret")
GitRefDisabled = NewType("GitRefDisabled", str)


class WithDisabledMixin:
DISABLED: GitRefDisabled = GitRefDisabled("DO_NOT_TRACK_GIT_REPOSITORY")
"""Constant that can be used to disable Git repository tracking."""


@dataclass
class GitRef(Atom):
class GitRef(Atom, WithDisabledMixin):
"""
Represents Git repository metadata.
commit_id: str
message: str
author_name: str
author_email: str
commit_date: datetime
dirty: bool
branch: Optional[str]
remotes: Optional[List[str]]
Args:
repository_path: Path to the repository. If not provided,
the path to the script that is currently executed is used.
"""

repository_path: Optional[Union[str, Path]] = get_path_executed_script()

def accept(self, visitor: "ValueVisitor[Ret]") -> Ret:
return visitor.visit_git_ref(self)

def __str__(self):
return "GitRef({})".format(str(self.commit_id))
def __str__(self) -> str:
return f"GitRef({self.repository_path})"

def resolve_path(self) -> Optional[Path]:
if self.repository_path is None:
return None
return Path(self.repository_path).resolve()
Loading

0 comments on commit 54801d2

Please sign in to comment.