From 062488f000582d108daf2ef644d79e91fe7dbf58 Mon Sep 17 00:00:00 2001
From: Jacob Walls <jwalls@fargeo.com>
Date: Tue, 19 Nov 2024 12:20:40 -0500
Subject: [PATCH] Move new column to nodegroup

---
 arches/app/models/graph.py                    | 16 ++--
 .../migrations/11613_node_grouping_node.py    | 77 -------------------
 .../migrations/11613_nodegroup_root_node.py   | 69 +++++++++++++++++
 .../11613_nodegroup_root_node_constraints.py  | 33 ++++++++
 arches/app/models/models.py                   | 59 +++++++++-----
 releases/8.0.0.md                             |  2 +-
 tests/models/graph_tests.py                   | 14 +++-
 tests/models/resource_test.py                 |  4 +-
 tests/models/tile_model_tests.py              |  3 +-
 tests/views/graph_manager_tests.py            |  1 -
 10 files changed, 162 insertions(+), 116 deletions(-)
 delete mode 100644 arches/app/models/migrations/11613_node_grouping_node.py
 create mode 100644 arches/app/models/migrations/11613_nodegroup_root_node.py
 create mode 100644 arches/app/models/migrations/11613_nodegroup_root_node_constraints.py

diff --git a/arches/app/models/graph.py b/arches/app/models/graph.py
index 19b08298748..0735fc3ee35 100644
--- a/arches/app/models/graph.py
+++ b/arches/app/models/graph.py
@@ -349,7 +349,6 @@ def add_node(self, node, nodegroups=None):
             node.ontologyclass = nodeobj.get("ontologyclass", "")
             node.datatype = nodeobj.get("datatype", "")
             node.nodegroup_id = nodeobj.get("nodegroup_id", "")
-            node.grouping_node_id = nodeobj.get("grouping_node_id", node.nodegroup_id)
             node.config = nodeobj.get("config", None)
             node.issearchable = nodeobj.get("issearchable", True)
             node.isrequired = nodeobj.get("isrequired", False)
@@ -373,7 +372,7 @@ def add_node(self, node, nodegroups=None):
                 node.nodegroup = self.get_or_create_nodegroup(
                     nodegroupid=node.nodegroup_id
                 )
-                node.grouping_node_id = node.nodegroup_id
+                node.nodegroup.root_node_id = node.nodegroup_id
                 if nodegroups is not None and str(node.nodegroup_id) in nodegroups:
                     node.nodegroup.cardinality = nodegroups[str(node.nodegroup_id)][
                         "cardinality"
@@ -386,7 +385,6 @@ def add_node(self, node, nodegroups=None):
                     ]["parentnodegroup_id"]
             else:
                 node.nodegroup = None
-                node.grouping_node = None
 
         node.graph = self
 
@@ -616,7 +614,6 @@ def save(self, validate=True, nodeid=None):
 
             if nodeid is not None:
                 node = self.nodes[nodeid]
-                node.grouping_node_id = node.nodegroup_id
                 branch_publication_id = node.sourcebranchpublication_id
                 self.update_es_node_mapping(node, datatype_factory, se)
                 self.create_node_alias(node)
@@ -645,7 +642,6 @@ def save(self, validate=True, nodeid=None):
 
             else:
                 for node in self.nodes.values():
-                    node.grouping_node_id = node.nodegroup_id
                     self.update_es_node_mapping(node, datatype_factory, se)
                     node.save()
 
@@ -1120,7 +1116,7 @@ def flatten_tree(tree, node_id_list=[]):
             if is_collector:
                 old_nodegroup_id = node.nodegroup_id
                 node.nodegroup = models.NodeGroup(
-                    pk=node.pk, cardinality=node.nodegroup.cardinality
+                    pk=node.pk, cardinality=node.nodegroup.cardinality, root_node=node
                 )
                 if old_nodegroup_id not in nodegroup_map:
                     nodegroup_map[old_nodegroup_id] = node.nodegroup_id
@@ -2446,6 +2442,7 @@ def _update_source_nodegroup_hierarchy(nodegroup):
 
                 source_nodegroup.cardinality = nodegroup.cardinality
                 source_nodegroup.legacygroupid = nodegroup.legacygroupid
+                source_nodegroup.root_node_id = source_nodegroup.pk
 
                 if nodegroup.parentnodegroup_id:
                     nodegroup_parent_node = models.Node.objects.get(
@@ -2647,14 +2644,13 @@ def _update_source_nodegroup_hierarchy(nodegroup):
                             "graph_id",
                             "nodeid",
                             "nodegroup_id",
-                            "grouping_node_id",
                             "source_identifier_id",
                             "is_collector",
                         ]:
                             setattr(source_node, key, getattr(future_node, key))
 
                     source_node.nodegroup_id = future_node.nodegroup_id
-                    source_node.grouping_node_id = source_node.nodegroup_id
+                    source_node.nodegroup.root_node_id = source_node.nodegroup_id
                     if (
                         future_node_nodegroup_node
                         and future_node_nodegroup_node.source_identifier_id
@@ -2662,7 +2658,7 @@ def _update_source_nodegroup_hierarchy(nodegroup):
                         source_node.nodegroup_id = (
                             future_node_nodegroup_node.source_identifier_id
                         )
-                        source_node.grouping_node_id = source_node.nodegroup_id
+                        source_node.nodegroup.root_node_id = source_node.nodegroup_id
 
                     self.nodes[source_node.pk] = source_node
                 else:  # newly-created node
@@ -2676,7 +2672,7 @@ def _update_source_nodegroup_hierarchy(nodegroup):
                         future_node.nodegroup_id = (
                             future_node_nodegroup_node.source_identifier_id
                         )
-                        future_node.grouping_node_id = future_node.nodegroup_id
+                        future_node.nodegroup.root_node_id = future_node.nodegroup_id
 
                     del editable_future_graph.nodes[future_node.pk]
                     self.nodes[future_node.pk] = future_node
diff --git a/arches/app/models/migrations/11613_node_grouping_node.py b/arches/app/models/migrations/11613_node_grouping_node.py
deleted file mode 100644
index e9318dffe89..00000000000
--- a/arches/app/models/migrations/11613_node_grouping_node.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Generated by Django 5.1.3 on 2024-11-11 07:27
-
-import django.db.models.deletion
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
-    dependencies = [
-        ("models", "10437_node_alias_not_null"),
-    ]
-
-    def set_grouping_node(apps, schema_editor):
-        Node = apps.get_model("models", "Node")
-        all_but_top_nodes = Node.objects.exclude(istopnode=True)
-        for node in all_but_top_nodes:
-            assert node.nodegroup_id is not None, f"Missing nodegroup for {node!r}"
-            node.grouping_node_id = node.nodegroup_id
-        Node.objects.bulk_update(all_but_top_nodes, ["grouping_node_id"])
-
-        PublishedGraph = apps.get_model("models", "PublishedGraph")
-        published_graphs = PublishedGraph.objects.all()
-        for published_graph in published_graphs:
-            for node_dict in published_graph.serialized_graph["nodes"]:
-                node_dict["grouping_node_id"] = node_dict["nodegroup_id"]
-        PublishedGraph.objects.bulk_update(published_graphs, ["serialized_graph"])
-
-    def remove_grouping_node(apps, schema_editor):
-        PublishedGraph = apps.get_model("models", "PublishedGraph")
-        published_graphs = PublishedGraph.objects.all()
-        for published_graph in published_graphs:
-            for node_dict in published_graph.serialized_graph["nodes"]:
-                node_dict.pop("grouping_node_id", None)
-        PublishedGraph.objects.bulk_update(published_graphs, ["serialized_graph"])
-
-    operations = [
-        migrations.AddField(
-            model_name="node",
-            name="grouping_node",
-            field=models.ForeignKey(
-                blank=True,
-                null=True,
-                on_delete=django.db.models.deletion.CASCADE,
-                related_name="sibling_nodes",
-                related_query_name="sibling_node",
-                to="models.node",
-            ),
-        ),
-        migrations.RunPython(set_grouping_node, remove_grouping_node),
-        migrations.AddConstraint(
-            model_name="node",
-            constraint=models.CheckConstraint(
-                condition=models.Q(
-                    ("istopnode", True), ("nodegroup__isnull", False), _connector="OR"
-                ),
-                name="has_nodegroup_or_istopnode",
-            ),
-        ),
-        migrations.AddConstraint(
-            model_name="node",
-            constraint=models.CheckConstraint(
-                condition=models.Q(
-                    ("istopnode", True),
-                    ("grouping_node__isnull", False),
-                    _connector="OR",
-                ),
-                name="has_grouping_node_or_istopnode",
-            ),
-        ),
-        migrations.AddConstraint(
-            model_name="node",
-            constraint=models.CheckConstraint(
-                condition=models.Q(("grouping_node_id", models.F("nodegroup_id"))),
-                name="grouping_node_matches_nodegroup",
-            ),
-        ),
-    ]
diff --git a/arches/app/models/migrations/11613_nodegroup_root_node.py b/arches/app/models/migrations/11613_nodegroup_root_node.py
new file mode 100644
index 00000000000..7381b6e6d41
--- /dev/null
+++ b/arches/app/models/migrations/11613_nodegroup_root_node.py
@@ -0,0 +1,69 @@
+# Generated by Django 5.1.3 on 2024-11-19 09:29
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("models", "10437_node_alias_not_null"),
+    ]
+
+    def add_root_node(apps, schema_editor):
+        NodeGroup = apps.get_model("models", "NodeGroup")
+        nodegroups_with_nodes = NodeGroup.objects.filter(node__gt=0).distinct()
+        for nodegroup in nodegroups_with_nodes:
+            nodegroup.root_node_id = nodegroup.pk
+        NodeGroup.objects.bulk_update(nodegroups_with_nodes, ["root_node_id"])
+
+        PublishedGraph = apps.get_model("models", "PublishedGraph")
+        published_graphs = PublishedGraph.objects.all()
+        for published_graph in published_graphs:
+            for node_dict in published_graph.serialized_graph["nodes"]:
+                node_dict["root_node_id"] = node_dict["nodegroup_id"]
+        PublishedGraph.objects.bulk_update(published_graphs, ["serialized_graph"])
+
+    def remove_root_node(apps, schema_editor):
+        PublishedGraph = apps.get_model("models", "PublishedGraph")
+        published_graphs = PublishedGraph.objects.all()
+        for published_graph in published_graphs:
+            for node_dict in published_graph.serialized_graph["nodegroups"]:
+                node_dict.pop("root_node_id", None)
+        PublishedGraph.objects.bulk_update(published_graphs, ["serialized_graph"])
+
+    operations = [
+        migrations.AddField(
+            model_name="nodegroup",
+            name="root_node",
+            field=models.OneToOneField(
+                blank=True,
+                db_column="rootnodeid",
+                null=True,
+                on_delete=django.db.models.deletion.SET_NULL,
+                related_name="grouping_node_nodegroup",
+                to="models.node",
+            ),
+        ),
+        migrations.AlterField(
+            model_name="nodegroup",
+            name="cardinality",
+            field=models.CharField(
+                blank=True, choices=[("1", "1"), ("n", "n")], default="1", max_length=1
+            ),
+        ),
+        migrations.AlterField(
+            model_name="nodegroup",
+            name="parentnodegroup",
+            field=models.ForeignKey(
+                blank=True,
+                db_column="parentnodegroupid",
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                related_name="children",
+                related_query_name="child",
+                to="models.nodegroup",
+            ),
+        ),
+        migrations.RunPython(add_root_node, migrations.RunPython.noop),
+    ]
diff --git a/arches/app/models/migrations/11613_nodegroup_root_node_constraints.py b/arches/app/models/migrations/11613_nodegroup_root_node_constraints.py
new file mode 100644
index 00000000000..c95f0223810
--- /dev/null
+++ b/arches/app/models/migrations/11613_nodegroup_root_node_constraints.py
@@ -0,0 +1,33 @@
+# Generated by Django 5.1.3 on 2024-11-19 09:33
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("models", "11613_nodegroup_root_node"),
+    ]
+
+    operations = [
+        migrations.AddConstraint(
+            model_name="node",
+            constraint=models.CheckConstraint(
+                condition=models.Q(
+                    ("istopnode", True), ("nodegroup__isnull", False), _connector="OR"
+                ),
+                name="has_nodegroup_or_istopnode",
+            ),
+        ),
+        migrations.AddConstraint(
+            model_name="nodegroup",
+            constraint=models.CheckConstraint(
+                condition=models.Q(
+                    ("root_node", models.F("pk")),
+                    ("root_node__isnull", True),
+                    _connector="OR",
+                ),
+                name="root_node_matches_pk_or_null",
+            ),
+        ),
+    ]
diff --git a/arches/app/models/models.py b/arches/app/models/models.py
index 22bd7c79324..ac0dafb6553 100644
--- a/arches/app/models/models.py
+++ b/arches/app/models/models.py
@@ -742,14 +742,26 @@ class Meta:
 class NodeGroup(models.Model):
     nodegroupid = models.UUIDField(primary_key=True)
     legacygroupid = models.TextField(blank=True, null=True)
-    cardinality = models.TextField(blank=True, default="1")
+    cardinality = models.CharField(
+        max_length=1, blank=True, default="1", choices={"1": "1", "n": "n"}
+    )
     parentnodegroup = models.ForeignKey(
         "self",
         db_column="parentnodegroupid",
         blank=True,
         null=True,
         on_delete=models.CASCADE,
+        related_name="children",
+        related_query_name="child",
     )  # Allows nodegroups within nodegroups
+    root_node = models.OneToOneField(
+        "Node",
+        db_column="rootnodeid",
+        blank=True,
+        null=True,
+        on_delete=models.SET_NULL,
+        related_name="grouping_node_nodegroup",
+    )
 
     def __init__(self, *args, **kwargs):
         super(NodeGroup, self).__init__(*args, **kwargs)
@@ -759,6 +771,12 @@ def __init__(self, *args, **kwargs):
     class Meta:
         managed = True
         db_table = "node_groups"
+        constraints = [
+            models.CheckConstraint(
+                condition=Q(root_node=models.F("pk")) | Q(root_node__isnull=True),
+                name="root_node_matches_pk_or_null",
+            )
+        ]
 
         default_permissions = ()
         permissions = (
@@ -768,6 +786,26 @@ class Meta:
             ("no_access_to_nodegroup", "No Access"),
         )
 
+    @classmethod
+    def check(cls, **kwargs):
+        errors = super().check(**kwargs)
+        errors.extend(cls._check_root_node())
+        return errors
+
+    @classmethod
+    def _check_root_node(cls):
+        return [
+            checks.Error(
+                "Missing root node for nodegroup.",
+                hint="Set the node with the corresponding primary key as the root node of the group.",
+                obj=nodegroup,
+                id="arches.E005",
+            )
+            for nodegroup in cls.objects.filter(
+                node__gt=0, root_node__isnull=True
+            ).distinct()
+        ]
+
 
 class Node(models.Model):
     """
@@ -798,14 +836,6 @@ def __init__(self, *args, **kwargs):
     graph = models.ForeignKey(
         GraphModel, db_column="graphid", blank=True, null=True, on_delete=models.CASCADE
     )
-    grouping_node = models.ForeignKey(
-        "self",
-        blank=True,
-        null=True,
-        on_delete=models.CASCADE,
-        related_name="sibling_nodes",
-        related_query_name="sibling_node",
-    )
     config = I18n_JSONField(blank=True, null=True, db_column="config")
     issearchable = models.BooleanField(default=True)
     isrequired = models.BooleanField(default=False)
@@ -915,7 +945,6 @@ def __init__(self, *args, **kwargs):
     def clean(self):
         if not self.alias:
             Graph.objects.get(pk=self.graph_id).create_node_alias(self)
-        self.grouping_node_id = self.nodegroup_id
         if self.pk == self.source_identifier_id:
             self.source_identifier_id = None
 
@@ -923,8 +952,6 @@ def save(self, **kwargs):
         if not self.alias:
             add_to_update_fields(kwargs, "alias")
             add_to_update_fields(kwargs, "hascustomalias")
-        if self.grouping_node_id != self.nodegroup_id:
-            add_to_update_fields(kwargs, "grouping_node_id")
         if self.pk == self.source_identifier_id:
             add_to_update_fields(kwargs, "source_identifier_id")
 
@@ -946,14 +973,6 @@ class Meta:
                 condition=Q(istopnode=True) | Q(nodegroup__isnull=False),
                 name="has_nodegroup_or_istopnode",
             ),
-            models.CheckConstraint(
-                condition=Q(istopnode=True) | Q(grouping_node__isnull=False),
-                name="has_grouping_node_or_istopnode",
-            ),
-            models.CheckConstraint(
-                condition=Q(grouping_node_id=models.F("nodegroup_id")),
-                name="grouping_node_matches_nodegroup",
-            ),
         ]
 
 
diff --git a/releases/8.0.0.md b/releases/8.0.0.md
index 9af2f9be947..587e3f046bc 100644
--- a/releases/8.0.0.md
+++ b/releases/8.0.0.md
@@ -20,7 +20,7 @@ Arches 8.0.0 Release Notes
 - Add session-based REST APIs for login, logout [#11261](https://github.com/archesproject/arches/issues/11261)
 - Add system check advising next action when enabling additional languages without updating graphs [#10079](https://github.com/archesproject/arches/issues/10079)
 - Improve handling of longer model names [#11317](https://github.com/archesproject/arches/issues/11317)
-- New column `Node.grouping_node`: self-referring foreign key to the collector node [#11613](https://github.com/archesproject/arches/issues/11613)
+- New column `NodeGroup.root_node`: one-to-one field to the collector node [#11613](https://github.com/archesproject/arches/issues/11613)
 - Support more expressive plugin URLs [#11320](https://github.com/archesproject/arches/issues/11320)
 - Make node aliases not nullable [#10437](https://github.com/archesproject/arches/issues/10437)
 - Concepts API no longer responds with empty body for error conditions [#11519](https://github.com/archesproject/arches/issues/11519)
diff --git a/tests/models/graph_tests.py b/tests/models/graph_tests.py
index 4aa2f8a3311..ae06d45d045 100644
--- a/tests/models/graph_tests.py
+++ b/tests/models/graph_tests.py
@@ -119,7 +119,6 @@ def setUpTestData(cls):
                 "istopnode": True,
                 "name": "Node",
                 "nodegroup_id": "20000000-0000-0000-0000-100000000001",
-                "grouping_node_id": "20000000-0000-0000-0000-100000000001",
                 "nodeid": "20000000-0000-0000-0000-100000000001",
                 "ontologyclass": "http://www.cidoc-crm.org/cidoc-crm/E1_CRM_Entity",
             },
@@ -133,7 +132,6 @@ def setUpTestData(cls):
                 "istopnode": False,
                 "name": "Node Type",
                 "nodegroup_id": "20000000-0000-0000-0000-100000000001",
-                "grouping_node_id": "20000000-0000-0000-0000-100000000001",
                 "nodeid": "20000000-0000-0000-0000-100000000002",
                 "ontologyclass": "http://www.cidoc-crm.org/cidoc-crm/E55_Type",
             },
@@ -142,6 +140,10 @@ def setUpTestData(cls):
         for node in nodes:
             models.Node.objects.create(**node).save()
 
+        models.NodeGroup.objects.filter(
+            pk="20000000-0000-0000-0000-100000000001"
+        ).update(root_node_id="20000000-0000-0000-0000-100000000001")
+
         edges_dict = {
             "description": None,
             "domainnode_id": "20000000-0000-0000-0000-100000000001",
@@ -1281,7 +1283,6 @@ def test_update_empty_graph_from_editable_future_graph(self):
                 if key not in [
                     "graph_id",
                     "nodegroup_id",
-                    "grouping_node_id",
                     "nodeid",
                     "source_identifier_id",
                 ]:
@@ -1356,7 +1357,12 @@ def test_update_empty_graph_from_editable_future_graph(self):
 
             # ensures all relevant values are equal between graphs
             for key, value in editable_future_graph_serialized_nodegroup.items():
-                if key not in ["parentnodegroup_id", "nodegroupid", "legacygroupid"]:
+                if key not in [
+                    "parentnodegroup_id",
+                    "nodegroupid",
+                    "root_node_id",
+                    "legacygroupid",
+                ]:
                     if type(value) == "dict":
                         self.assertDictEqual(
                             value, updated_source_graph_serialized_nodegroup[key]
diff --git a/tests/models/resource_test.py b/tests/models/resource_test.py
index 0ff76e7cb81..eca453e8839 100644
--- a/tests/models/resource_test.py
+++ b/tests/models/resource_test.py
@@ -492,7 +492,6 @@ def test_self_referring_resource_instance_descriptor(self):
             pk=nodegroup.pk,
             graph=graph,
             nodegroup=nodegroup,
-            grouping_node_id=nodegroup.pk,
             name="String Node",
             datatype="string",
             istopnode=False,
@@ -500,11 +499,12 @@ def test_self_referring_resource_instance_descriptor(self):
         resource_instance_node = models.Node.objects.create(
             graph=graph,
             nodegroup=nodegroup,
-            grouping_node_id=nodegroup.pk,
             name="Resource Node",
             datatype="resource-instance",
             istopnode=False,
         )
+        nodegroup.root_node = string_node
+        nodegroup.save()
 
         # Configure the primary descriptor to use the string node
         models.FunctionXGraph.objects.create(
diff --git a/tests/models/tile_model_tests.py b/tests/models/tile_model_tests.py
index 5e518cb7df4..339650e74a8 100644
--- a/tests/models/tile_model_tests.py
+++ b/tests/models/tile_model_tests.py
@@ -737,11 +737,12 @@ def test_check_for_missing_nodes(self):
             name="Required file list",
             datatype="file-list",
             nodegroup=node_group,
-            grouping_node_id=node_group.pk,
             isrequired=True,
             istopnode=False,
         )
         required_file_list_node.save()
+        node_group.root_node = required_file_list_node
+        node_group.save()
 
         json = {
             "resourceinstance_id": "40000000-0000-0000-0000-000000000000",
diff --git a/tests/views/graph_manager_tests.py b/tests/views/graph_manager_tests.py
index 569b0c976f4..00129b869aa 100644
--- a/tests/views/graph_manager_tests.py
+++ b/tests/views/graph_manager_tests.py
@@ -138,7 +138,6 @@ def setUpTestData(cls):
                     "istopnode": False,
                     "name": "Node Type",
                     "nodegroup_id": "20000000-0000-0000-0000-100000000001",
-                    "grouping_node_id": "20000000-0000-0000-0000-100000000001",
                     "nodeid": "20000000-0000-0000-0000-100000000002",
                     "ontologyclass": "http://www.cidoc-crm.org/cidoc-crm/E55_Type",
                 },