From d802f0eb3a0f9011423f5a6894f212b3aa824cc8 Mon Sep 17 00:00:00 2001 From: Hugh Sorby Date: Wed, 16 Oct 2024 15:22:57 +1300 Subject: [PATCH 1/2] Improve log message output. Refactor finding first coordinate field. --- src/scaffoldfitter/fitter.py | 69 ++++++++++++++++++-------- src/scaffoldfitter/fitterexceptions.py | 3 ++ src/scaffoldfitter/fitterstepfit.py | 6 +-- 3 files changed, 53 insertions(+), 25 deletions(-) create mode 100644 src/scaffoldfitter/fitterexceptions.py diff --git a/src/scaffoldfitter/fitter.py b/src/scaffoldfitter/fitter.py index 7546195..5687378 100644 --- a/src/scaffoldfitter/fitter.py +++ b/src/scaffoldfitter/fitter.py @@ -14,6 +14,8 @@ from cmlibs.zinc.element import Elementbasis, Elementfieldtemplate from cmlibs.zinc.field import Field, FieldFindMeshLocation, FieldGroup from cmlibs.zinc.result import RESULT_OK, RESULT_WARNING_PART_DONE + +from scaffoldfitter.fitterexceptions import FitterModelCoordinateField from scaffoldfitter.fitterstep import FitterStep from scaffoldfitter.fitterstepconfig import FitterStepConfig from scaffoldfitter.fitterstepfit import FitterStepFit @@ -480,7 +482,8 @@ def _loadData(self): sir.createStreamresourceMemoryBuffer(buffer) result = self._region.read(sir) if result != RESULT_OK: - self.printLog() + print("Node to datapoints log:") + self.print_log() raise AssertionError("Failed to load nodes as datapoints") # transfer datapoints to self._region sir = self._rawDataRegion.createStreaminformationRegion() @@ -493,7 +496,7 @@ def _loadData(self): sir.createStreamresourceMemoryBuffer(buffer) result = self._region.read(sir) if result != RESULT_OK: - self.printLog() + self.print_log() raise AssertionError("Failed to load datapoints, result " + str(result)) self._discoverDataCoordinatesField() self._discoverMarkerGroup() @@ -1028,6 +1031,27 @@ def setModelCoordinatesField(self, modelCoordinatesField: Field): def setModelCoordinatesFieldByName(self, modelCoordinatesFieldName): self.setModelCoordinatesField(self._fieldmodule.findFieldByName(modelCoordinatesFieldName)) + def _find_first_coordinate_type_field(self): + field = None + + mesh = self.getHighestDimensionMesh() + element = mesh.createElementiterator().next() + if element.isValid(): + fieldcache = self._fieldmodule.createFieldcache() + fieldcache.setElement(element) + fielditer = self._fieldmodule.createFielditerator() + field = fielditer.next() + while field.isValid(): + if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and \ + (field.castFiniteElement().isValid()): + if field.isDefinedAtLocation(fieldcache): + break + field = fielditer.next() + else: + field = None + + return field + def _discoverModelCoordinatesField(self): """ Choose default modelCoordinates field. @@ -1037,24 +1061,14 @@ def _discoverModelCoordinatesField(self): field = None if self._modelCoordinatesFieldName: field = self._fieldmodule.findFieldByName(self._modelCoordinatesFieldName) - else: - mesh = self.getHighestDimensionMesh() - element = mesh.createElementiterator().next() - if element.isValid(): - fieldcache = self._fieldmodule.createFieldcache() - fieldcache.setElement(element) - fielditer = self._fieldmodule.createFielditerator() - field = fielditer.next() - while field.isValid(): - if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and \ - (field.castFiniteElement().isValid()): - if field.isDefinedAtLocation(fieldcache): - break - field = fielditer.next() - else: - field = None - if field: + + if field is None or not field.isValid(): + field = self._find_first_coordinate_type_field() + + if field and field.isValid(): self.setModelCoordinatesField(field) + else: + raise FitterModelCoordinateField("No coordinate field found for model.") def getModelFitGroup(self): return self._modelFitGroup @@ -1521,11 +1535,22 @@ def evaluateNodeGroupMeanCoordinates(self, groupName, coordinatesFieldName, isDa coordinates = self._fieldmodule.findFieldByName(coordinatesFieldName) return evaluateFieldNodesetMean(coordinates, nodesetGroup) - def printLog(self): + def _log_message_type_to_text(self, message_type): + # 'MESSAGE_TYPE_ERROR', 'MESSAGE_TYPE_INFORMATION', 'MESSAGE_TYPE_INVALID', 'MESSAGE_TYPE_WARNING' + if self._logger.MESSAGE_TYPE_ERROR == message_type: + return "Error" + if self._logger.MESSAGE_TYPE_INFORMATION == message_type: + return "Information" + if self._logger.MESSAGE_TYPE_WARNING == message_type: + return "Warning" + + return "Invalid" + + def print_log(self): loggerMessageCount = self._logger.getNumberOfMessages() if loggerMessageCount > 0: for i in range(1, loggerMessageCount + 1): - print(self._logger.getMessageTypeAtIndex(i), self._logger.getMessageTextAtIndex(i)) + print(f"[Message {i}] {self._log_message_type_to_text(self._logger.getMessageTypeAtIndex(i))}: {self._logger.getMessageTextAtIndex(i)}") self._logger.removeAllMessages() def getDiagnosticLevel(self): @@ -1562,7 +1587,7 @@ def writeModel(self, modelFileName): if self._modelFitGroup: sir.setResourceGroupName(srf, self._modelFitGroup.getName()) result = self._region.write(sir) - # self.printLog() + # self.print_log() # restore original name self._modelCoordinatesField.setName(self._modelCoordinatesFieldName) diff --git a/src/scaffoldfitter/fitterexceptions.py b/src/scaffoldfitter/fitterexceptions.py new file mode 100644 index 0000000..2d18eb0 --- /dev/null +++ b/src/scaffoldfitter/fitterexceptions.py @@ -0,0 +1,3 @@ + +class FitterModelCoordinateField(Exception): + pass diff --git a/src/scaffoldfitter/fitterstepfit.py b/src/scaffoldfitter/fitterstepfit.py index a80dff2..82421c9 100644 --- a/src/scaffoldfitter/fitterstepfit.py +++ b/src/scaffoldfitter/fitterstepfit.py @@ -385,7 +385,7 @@ def run(self, modelFileNameStem=None): fieldcache, flattenGroupObjective.getNumberOfComponents()) print(" Flatten group objective", objectiveFormat.format(objective)) if self.getDiagnosticLevel() > 1: - self._fitter.printLog() + self._fitter.print_log() if self._updateReferenceState: self._fitter.updateModelReferenceCoordinates() @@ -449,7 +449,7 @@ def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainAc # convert to local fibre directions, with possible dimension reduction for 2D, 1D fibreAxes = fieldmodule.createFieldFibreAxes(fibreField, modelReferenceCoordinates) if not fibreAxes.isValid(): - self.getFitter().printLog() + self.getFitter().print_log() if dimension == 3: fibreAxesT = fieldmodule.createFieldTranspose(3, fibreAxes) elif dimension == 2: @@ -506,7 +506,7 @@ def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainAc deformationTerm = \ (deformationTerm + wtSqDeformationGradient2) if deformationTerm else wtSqDeformationGradient2 if not deformationTerm.isValid(): - self.getFitter().printLog() + self.getFitter().print_log() raise AssertionError("Scaffoldfitter: Failed to get deformation term") deformationPenaltyObjective = fieldmodule.createFieldMeshIntegral( From dc43bd5b949df777b936a3ab4b993d6c8d7fcca9 Mon Sep 17 00:00:00 2001 From: Hugh Sorby Date: Thu, 17 Oct 2024 14:45:42 +1300 Subject: [PATCH 2/2] Convert scaffold marker points to marker data points when using a scaffold as data. --- src/scaffoldfitter/fitter.py | 140 ++++++++++++++++++++++++++--------- 1 file changed, 107 insertions(+), 33 deletions(-) diff --git a/src/scaffoldfitter/fitter.py b/src/scaffoldfitter/fitter.py index 5687378..8efd0b2 100644 --- a/src/scaffoldfitter/fitter.py +++ b/src/scaffoldfitter/fitter.py @@ -10,6 +10,7 @@ from cmlibs.utils.zinc.finiteelement import evaluateFieldNodesetMean, evaluateFieldNodesetRange, \ findNodeWithName, getMaximumNodeIdentifier from cmlibs.utils.zinc.general import ChangeManager +from cmlibs.utils.zinc.region import write_to_buffer, read_from_buffer from cmlibs.zinc.context import Context from cmlibs.zinc.element import Elementbasis, Elementfieldtemplate from cmlibs.zinc.field import Field, FieldFindMeshLocation, FieldGroup @@ -21,6 +22,14 @@ from scaffoldfitter.fitterstepfit import FitterStepFit +def _next_available_identifier(node_set, candidate): + node = node_set.findNodeByIdentifier(candidate) + while node.isValid(): + candidate += 1 + node = node_set.findNodeByIdentifier(candidate) + return candidate + + class Fitter: def __init__(self, zincModelFileName: str, zincDataFileName: str): @@ -413,6 +422,84 @@ def _defineCommonDataFields(self): self._activeDataProjectionGroupFields.append(activeDataProjectionGroupField) self._activeDataProjectionMeshGroups.append(activeDataProjectionGroupField.getOrCreateMeshGroup(mesh)) + def _convert_marker_points(self): + """ + Convert any scaffold marker points to data marker points. + Removing any coordinate fields that are only defined on the scaffold marker points + and not defined on the mesh. + Assigns material coordinates to the scaffold marker points when converting to + data marker points if not defined. + Assumes 'coordinates' is not a material coordinate field. + """ + fm = self._rawDataRegion.getFieldmodule() + fc = fm.createFieldcache() + field_iter = fm.createFielditerator() + + markerGroupName = self._markerGroupName if self._markerGroupName else "marker" + markerGroup = fm.findFieldByName(markerGroupName).castGroup() + if not markerGroup.isValid(): + return + + nodes = fm.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + marker_nodeset_group = markerGroup.getNodesetGroup(nodes) + if not marker_nodeset_group.isValid(): + return + + marker_location_field = fm.findFieldByName("marker_location") + if not marker_location_field.isValid(): + return + + marker_node = marker_nodeset_group.createNodeiterator().next() + fc.setNode(marker_node) + field = field_iter.next() + defined_coordinate_fields = [] + potential_material_coordinate_fields = [] + while field.isValid(): + is_coordinate_field = field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and field.castFiniteElement().isValid() + if field.isDefinedAtLocation(fc) and is_coordinate_field: + defined_coordinate_fields.append(field) + if (field.getName() != "coordinates") and is_coordinate_field: + potential_material_coordinate_fields.append(field) + field = field_iter.next() + + undefined_potential_material_coordinate_fields = [] + for f in potential_material_coordinate_fields: + if f not in defined_coordinate_fields: + undefined_potential_material_coordinate_fields.append(f) + + if self._diagnosticLevel > 0: + print(f"Converting markers has {len(undefined_potential_material_coordinate_fields)} undefined potential material coordinate fields.") + print(f"The markers have {len(defined_coordinate_fields)} coordinate field(s) defined.") + + if len(undefined_potential_material_coordinate_fields) > 0: + if self._diagnosticLevel > 0: + print(f"Defining {[f.getName() for f in undefined_potential_material_coordinate_fields]} on markers.") + print(f"Un-defining {[f.getName() for f in defined_coordinate_fields]} on markers.") + + pending_assignments = [] + node_template = nodes.createNodetemplate() + for f in undefined_potential_material_coordinate_fields: + node_template.defineField(f) + host_coordinates = fm.createFieldEmbedded(f, marker_location_field) + field_assignment = f.createFieldassignment(host_coordinates) + field_assignment.setNodeset(marker_nodeset_group) + pending_assignments.append(field_assignment) + for f in defined_coordinate_fields: + node_template.undefineField(f) + f.setManaged(False) + + node_iter = marker_nodeset_group.createNodeiterator() + node = node_iter.next() + while node.isValid(): + node.merge(node_template) + node = node_iter.next() + + for f in pending_assignments: + f.assign() + + if self._diagnosticLevel > 0: + print(f"Assigned {len(pending_assignments)} coordinate field(s).") + def _loadData(self): """ Load zinc data file into self._rawDataRegion. @@ -455,46 +542,33 @@ def _loadData(self): if nodes.getSize() > 0: datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) if datapoints.getSize() > 0: - maximumDatapointIdentifier = max(0, getMaximumNodeIdentifier(datapoints)) - maximumNodeIdentifier = max(0, getMaximumNodeIdentifier(nodes)) - # this assumes identifiers are in low ranges and can be improved if there is a problem: - identifierOffset = 100000 - while (maximumDatapointIdentifier > identifierOffset) or (maximumNodeIdentifier > identifierOffset): - assert identifierOffset < 1000000000, "Invalid node and datapoint identifier ranges" - identifierOffset *= 10 - while True: - # logic relies on datapoints being in identifier order - datapoint = datapoints.createNodeiterator().next() - identifier = datapoint.getIdentifier() - if identifier >= identifierOffset: - break - result = datapoint.setIdentifier(identifier + identifierOffset) - assert result == RESULT_OK, "Failed to offset datapoint identifier" + datapoint_iterator = datapoints.createNodeiterator() + datapoint = datapoint_iterator.next() + latest = 1 + datapoint_new_identifier_map = {} + while datapoint.isValid(): + identifier = _next_available_identifier(nodes, latest) + datapoint_new_identifier_map[identifier] = datapoint + latest = identifier + 1 + datapoint = datapoint_iterator.next() + + for new_identifier, datapoint in datapoint_new_identifier_map.items(): + datapoint.setIdentifier(new_identifier) + + self._convert_marker_points() # transfer nodes as datapoints to self._region - sir = self._rawDataRegion.createStreaminformationRegion() - srm = sir.createStreamresourceMemory() - sir.setResourceDomainTypes(srm, Field.DOMAIN_TYPE_NODES) - self._rawDataRegion.write(sir) - result, buffer = srm.getBuffer() - assert result == RESULT_OK, "Failed to write nodes" + buffer = write_to_buffer(self._rawDataRegion, resource_domain_type=Field.DOMAIN_TYPE_NODES) + assert buffer is not None, "Failed to write nodes" buffer = buffer.replace(bytes("!#nodeset nodes", "utf-8"), bytes("!#nodeset datapoints", "utf-8")) - sir = self._region.createStreaminformationRegion() - sir.createStreamresourceMemoryBuffer(buffer) - result = self._region.read(sir) + result = read_from_buffer(self._region, buffer) if result != RESULT_OK: print("Node to datapoints log:") self.print_log() raise AssertionError("Failed to load nodes as datapoints") # transfer datapoints to self._region - sir = self._rawDataRegion.createStreaminformationRegion() - srm = sir.createStreamresourceMemory() - sir.setResourceDomainTypes(srm, Field.DOMAIN_TYPE_DATAPOINTS) - self._rawDataRegion.write(sir) - result, buffer = srm.getBuffer() - assert result == RESULT_OK, "Failed to write datapoints" - sir = self._region.createStreaminformationRegion() - sir.createStreamresourceMemoryBuffer(buffer) - result = self._region.read(sir) + buffer = write_to_buffer(self._rawDataRegion, resource_domain_type=Field.DOMAIN_TYPE_DATAPOINTS) + assert buffer is not None, "Failed to write datapoints" + result = read_from_buffer(self._region, buffer) if result != RESULT_OK: self.print_log() raise AssertionError("Failed to load datapoints, result " + str(result))