Skip to content

Commit

Permalink
fix: make sure to migrate previous versions (#683)
Browse files Browse the repository at this point in the history
* fix: make sure to migrate previous versions

* format
  • Loading branch information
tobymao authored Apr 11, 2023
1 parent 3bbf353 commit cd79b03
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 18 deletions.
3 changes: 3 additions & 0 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class SnapshotDataVersion(PydanticModel, frozen=True):
version: str
change_category: t.Optional[SnapshotChangeCategory]

def snapshot_id(self, name: str) -> SnapshotId:
return SnapshotId(name=name, identifier=self.fingerprint.to_identifier())

@property
def data_version(self) -> SnapshotDataVersion:
return self
Expand Down
83 changes: 65 additions & 18 deletions sqlmesh/core/state_sync/engine_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,26 @@
import json
import logging
import typing as t
from copy import deepcopy

from sqlglot import __version__ as SQLGLOT_VERSION
from sqlglot import exp

from sqlmesh.core.audit import Audit
from sqlmesh.core.dialect import select_from_values
from sqlmesh.core.engine_adapter import EngineAdapter, TransactionType
from sqlmesh.core.environment import Environment
from sqlmesh.core.model import Model
from sqlmesh.core.snapshot import (
Snapshot,
SnapshotDataVersion,
SnapshotFingerprint,
SnapshotId,
SnapshotIdLike,
SnapshotNameVersionLike,
fingerprint_from_model,
)
from sqlmesh.core.snapshot.definition import _parents_from_model
from sqlmesh.core.state_sync.base import SCHEMA_VERSION, StateSync, Versions
from sqlmesh.core.state_sync.common import CommonStateSyncMixin, transactional
from sqlmesh.utils.date import now_timestamp
Expand Down Expand Up @@ -404,10 +411,10 @@ def _migrate_rows(self) -> None:
for snapshot in all_snapshots.values():
seen = set()
queue = {snapshot.snapshot_id}
env: t.Dict[str, t.Dict] = {
"models": {},
"audits": {},
}
model = snapshot.model
models: t.Dict[str, Model] = {}
audits: t.Dict[str, Audit] = {}
env: t.Dict[str, t.Dict] = {"models": models, "audits": audits}

while queue:
snapshot_id = queue.pop()
Expand All @@ -426,37 +433,77 @@ def _migrate_rows(self) -> None:
cached_env = cache.get(snapshot_id)

if cached_env:
env["models"].update(cached_env["models"])
env["audits"].update(cached_env["audits"])
models.update(cached_env["models"])
audits.update(cached_env["audits"])
else:
env["models"][s.name] = s.model
models[s.name] = s.model

for audit in s.audits:
env["audits"][audit.name] = audit
audits[audit.name] = audit

cache[snapshot_id] = env

new_snapshot = Snapshot.from_model(
snapshot.model,
new_snapshot = deepcopy(snapshot)

fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {}

new_snapshot.fingerprint = fingerprint_from_model(
model,
physical_schema=snapshot.physical_schema,
models=env["models"],
ttl=snapshot.ttl,
version=snapshot.version,
audits=env["audits"],
models=models,
audits=audits,
)

if new_snapshot == snapshot or new_snapshot in all_snapshots:
logger.debug(f"{snapshot.snapshot_id} is unchaged")
new_snapshot.parents = tuple(
SnapshotId(
name=name,
identifier=fingerprint_from_model(
models[name],
physical_schema=snapshot.physical_schema,
models=models,
audits=audits,
cache=fingerprint_cache,
).to_identifier(),
)
for name in _parents_from_model(model, models)
)

if new_snapshot == snapshot:
logger.debug(f"{new_snapshot.snapshot_id} is unchanged.")
continue
if new_snapshot.snapshot_id in all_snapshots:
logger.debug(f"{new_snapshot.snapshot_id} exists.")
continue

new_snapshot.merge_intervals(snapshot)
snapshot_mapping[snapshot.snapshot_id] = new_snapshot
logger.debug(f"{snapshot.snapshot_id} mapped to {new_snapshot.snapshot_id}")
logger.debug(f"{snapshot.snapshot_id} mapped to {new_snapshot.snapshot_id}.")

if not snapshot_mapping:
logger.debug("No changes to snapshots detected.")
return

def map_data_versions(
name: str, versions: t.Sequence[SnapshotDataVersion]
) -> t.Tuple[SnapshotDataVersion, ...]:
version_ids = ((version.snapshot_id(name), version) for version in versions)

return tuple(
snapshot_mapping[version_id].data_version
if version_id in snapshot_mapping
else version
for version_id, version in version_ids
)

for from_snapshot_id, to_snapshot in snapshot_mapping.items():
from_snapshot = all_snapshots[from_snapshot_id]
to_snapshot.previous_versions = map_data_versions(
from_snapshot.name, from_snapshot.previous_versions
)
to_snapshot.indirect_versions = {
name: map_data_versions(name, versions)
for name, versions in from_snapshot.indirect_versions.items()
}

self.delete_snapshots(snapshot_mapping)
self._push_snapshots(snapshot_mapping.values())

Expand Down

0 comments on commit cd79b03

Please sign in to comment.