Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for logging and dealing with scaffold markers #35

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 154 additions & 55 deletions src/scaffoldfitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,26 @@
from cmlibs.utils.zinc.finiteelement import evaluateFieldNodesetMean, evaluateFieldNodesetRange, \
findNodeWithName, getMaximumNodeIdentifier
from cmlibs.utils.zinc.general import ChangeManager
from cmlibs.utils.zinc.region import write_to_buffer, read_from_buffer
from cmlibs.zinc.context import Context
from cmlibs.zinc.element import Elementbasis, Elementfieldtemplate
from cmlibs.zinc.field import Field, FieldFindMeshLocation, FieldGroup
from cmlibs.zinc.result import RESULT_OK, RESULT_WARNING_PART_DONE

from scaffoldfitter.fitterexceptions import FitterModelCoordinateField
from scaffoldfitter.fitterstep import FitterStep
from scaffoldfitter.fitterstepconfig import FitterStepConfig
from scaffoldfitter.fitterstepfit import FitterStepFit


def _next_available_identifier(node_set, candidate):
node = node_set.findNodeByIdentifier(candidate)
while node.isValid():
candidate += 1
node = node_set.findNodeByIdentifier(candidate)
return candidate


class Fitter:

def __init__(self, zincModelFileName: str, zincDataFileName: str):
Expand Down Expand Up @@ -411,6 +422,84 @@ def _defineCommonDataFields(self):
self._activeDataProjectionGroupFields.append(activeDataProjectionGroupField)
self._activeDataProjectionMeshGroups.append(activeDataProjectionGroupField.getOrCreateMeshGroup(mesh))

def _convert_marker_points(self):
    """
    Convert any scaffold marker points in the raw data region to data marker points.

    Removes (un-defines) any coordinate fields that are only defined on the scaffold
    marker points and not on the mesh. Assigns material coordinates to the scaffold
    marker points, evaluated via the embedded "marker_location" field, for any
    3-component finite element coordinate field not already defined there.
    Assumes 'coordinates' is not a material coordinate field.
    Does nothing if there is no marker group, no marker nodeset group, or no
    "marker_location" field in the raw data region.
    """
    fm = self._rawDataRegion.getFieldmodule()
    fc = fm.createFieldcache()
    field_iter = fm.createFielditerator()

    # Fall back to the conventional group name when none has been discovered yet.
    markerGroupName = self._markerGroupName if self._markerGroupName else "marker"
    markerGroup = fm.findFieldByName(markerGroupName).castGroup()
    if not markerGroup.isValid():
        return

    nodes = fm.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES)
    marker_nodeset_group = markerGroup.getNodesetGroup(nodes)
    if not marker_nodeset_group.isValid():
        return

    # Without an embedded location field there is nothing to evaluate material
    # coordinates from, so leave the markers untouched.
    marker_location_field = fm.findFieldByName("marker_location")
    if not marker_location_field.isValid():
        return

    # Classify fields by probing the first marker node only; assumes all marker
    # nodes have the same fields defined — TODO confirm this holds for all inputs.
    marker_node = marker_nodeset_group.createNodeiterator().next()
    fc.setNode(marker_node)
    field = field_iter.next()
    defined_coordinate_fields = []
    potential_material_coordinate_fields = []
    while field.isValid():
        # "Coordinate field" here means: flagged as type coordinate, exactly
        # 3 components, and a finite element field.
        is_coordinate_field = field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and field.castFiniteElement().isValid()
        if field.isDefinedAtLocation(fc) and is_coordinate_field:
            defined_coordinate_fields.append(field)
        if (field.getName() != "coordinates") and is_coordinate_field:
            potential_material_coordinate_fields.append(field)
        field = field_iter.next()

    # Material coordinate candidates not yet defined on the marker points.
    # NOTE(review): relies on zinc Field equality semantics for the `in` test
    # — presumably Field.__eq__ compares the underlying field; verify.
    undefined_potential_material_coordinate_fields = []
    for f in potential_material_coordinate_fields:
        if f not in defined_coordinate_fields:
            undefined_potential_material_coordinate_fields.append(f)

    if self._diagnosticLevel > 0:
        print(f"Converting markers has {len(undefined_potential_material_coordinate_fields)} undefined potential material coordinate fields.")
        print(f"The markers have {len(defined_coordinate_fields)} coordinate field(s) defined.")

    if len(undefined_potential_material_coordinate_fields) > 0:
        if self._diagnosticLevel > 0:
            print(f"Defining {[f.getName() for f in undefined_potential_material_coordinate_fields]} on markers.")
            print(f"Un-defining {[f.getName() for f in defined_coordinate_fields]} on markers.")

        # Build one node template that both defines the missing material
        # coordinate fields and un-defines the marker-only coordinate fields.
        pending_assignments = []
        node_template = nodes.createNodetemplate()
        for f in undefined_potential_material_coordinate_fields:
            node_template.defineField(f)
            # Evaluate the field at the marker's embedded mesh location to
            # obtain its material coordinates.
            host_coordinates = fm.createFieldEmbedded(f, marker_location_field)
            field_assignment = f.createFieldassignment(host_coordinates)
            field_assignment.setNodeset(marker_nodeset_group)
            pending_assignments.append(field_assignment)
        for f in defined_coordinate_fields:
            node_template.undefineField(f)
            # Unmanage so the field can be cleaned up once no longer used.
            f.setManaged(False)

        # Apply the template to every marker node.
        node_iter = marker_nodeset_group.createNodeiterator()
        node = node_iter.next()
        while node.isValid():
            node.merge(node_template)
            node = node_iter.next()

        # Assignments are deferred until after the merge so the target fields
        # are defined on the nodes being assigned to.
        for f in pending_assignments:
            f.assign()

        if self._diagnosticLevel > 0:
            print(f"Assigned {len(pending_assignments)} coordinate field(s).")

def _loadData(self):
"""
Load zinc data file into self._rawDataRegion.
Expand Down Expand Up @@ -453,47 +542,35 @@ def _loadData(self):
if nodes.getSize() > 0:
datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
if datapoints.getSize() > 0:
maximumDatapointIdentifier = max(0, getMaximumNodeIdentifier(datapoints))
maximumNodeIdentifier = max(0, getMaximumNodeIdentifier(nodes))
# this assumes identifiers are in low ranges and can be improved if there is a problem:
identifierOffset = 100000
while (maximumDatapointIdentifier > identifierOffset) or (maximumNodeIdentifier > identifierOffset):
assert identifierOffset < 1000000000, "Invalid node and datapoint identifier ranges"
identifierOffset *= 10
while True:
# logic relies on datapoints being in identifier order
datapoint = datapoints.createNodeiterator().next()
identifier = datapoint.getIdentifier()
if identifier >= identifierOffset:
break
result = datapoint.setIdentifier(identifier + identifierOffset)
assert result == RESULT_OK, "Failed to offset datapoint identifier"
datapoint_iterator = datapoints.createNodeiterator()
datapoint = datapoint_iterator.next()
latest = 1
datapoint_new_identifier_map = {}
while datapoint.isValid():
identifier = _next_available_identifier(nodes, latest)
datapoint_new_identifier_map[identifier] = datapoint
latest = identifier + 1
datapoint = datapoint_iterator.next()

for new_identifier, datapoint in datapoint_new_identifier_map.items():
datapoint.setIdentifier(new_identifier)

self._convert_marker_points()
# transfer nodes as datapoints to self._region
sir = self._rawDataRegion.createStreaminformationRegion()
srm = sir.createStreamresourceMemory()
sir.setResourceDomainTypes(srm, Field.DOMAIN_TYPE_NODES)
self._rawDataRegion.write(sir)
result, buffer = srm.getBuffer()
assert result == RESULT_OK, "Failed to write nodes"
buffer = write_to_buffer(self._rawDataRegion, resource_domain_type=Field.DOMAIN_TYPE_NODES)
assert buffer is not None, "Failed to write nodes"
buffer = buffer.replace(bytes("!#nodeset nodes", "utf-8"), bytes("!#nodeset datapoints", "utf-8"))
sir = self._region.createStreaminformationRegion()
sir.createStreamresourceMemoryBuffer(buffer)
result = self._region.read(sir)
result = read_from_buffer(self._region, buffer)
if result != RESULT_OK:
self.printLog()
print("Node to datapoints log:")
self.print_log()
raise AssertionError("Failed to load nodes as datapoints")
# transfer datapoints to self._region
sir = self._rawDataRegion.createStreaminformationRegion()
srm = sir.createStreamresourceMemory()
sir.setResourceDomainTypes(srm, Field.DOMAIN_TYPE_DATAPOINTS)
self._rawDataRegion.write(sir)
result, buffer = srm.getBuffer()
assert result == RESULT_OK, "Failed to write datapoints"
sir = self._region.createStreaminformationRegion()
sir.createStreamresourceMemoryBuffer(buffer)
result = self._region.read(sir)
buffer = write_to_buffer(self._rawDataRegion, resource_domain_type=Field.DOMAIN_TYPE_DATAPOINTS)
assert buffer is not None, "Failed to write datapoints"
result = read_from_buffer(self._region, buffer)
if result != RESULT_OK:
self.printLog()
self.print_log()
raise AssertionError("Failed to load datapoints, result " + str(result))
self._discoverDataCoordinatesField()
self._discoverMarkerGroup()
Expand Down Expand Up @@ -1028,6 +1105,27 @@ def setModelCoordinatesField(self, modelCoordinatesField: Field):
def setModelCoordinatesFieldByName(self, modelCoordinatesFieldName):
self.setModelCoordinatesField(self._fieldmodule.findFieldByName(modelCoordinatesFieldName))

def _find_first_coordinate_type_field(self):
    """
    Find the first 3-component finite element coordinate-type field defined
    on the first element of the highest dimension mesh.

    :return: Matching zinc Field, or None if the mesh has no elements or no
        field qualifies.
    """
    mesh = self.getHighestDimensionMesh()
    element = mesh.createElementiterator().next()
    if not element.isValid():
        # Empty mesh: nowhere to test field definition.
        return None

    fieldcache = self._fieldmodule.createFieldcache()
    fieldcache.setElement(element)
    field_iterator = self._fieldmodule.createFielditerator()
    candidate = field_iterator.next()
    while candidate.isValid():
        if (candidate.isTypeCoordinate() and (candidate.getNumberOfComponents() == 3) and
                candidate.castFiniteElement().isValid() and
                candidate.isDefinedAtLocation(fieldcache)):
            return candidate
        candidate = field_iterator.next()
    return None

def _discoverModelCoordinatesField(self):
"""
Choose default modelCoordinates field.
Expand All @@ -1037,24 +1135,14 @@ def _discoverModelCoordinatesField(self):
field = None
if self._modelCoordinatesFieldName:
field = self._fieldmodule.findFieldByName(self._modelCoordinatesFieldName)
else:
mesh = self.getHighestDimensionMesh()
element = mesh.createElementiterator().next()
if element.isValid():
fieldcache = self._fieldmodule.createFieldcache()
fieldcache.setElement(element)
fielditer = self._fieldmodule.createFielditerator()
field = fielditer.next()
while field.isValid():
if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and \
(field.castFiniteElement().isValid()):
if field.isDefinedAtLocation(fieldcache):
break
field = fielditer.next()
else:
field = None
if field:

if field is None or not field.isValid():
field = self._find_first_coordinate_type_field()

if field and field.isValid():
self.setModelCoordinatesField(field)
else:
raise FitterModelCoordinateField("No coordinate field found for model.")

def getModelFitGroup(self):
return self._modelFitGroup
Expand Down Expand Up @@ -1521,11 +1609,22 @@ def evaluateNodeGroupMeanCoordinates(self, groupName, coordinatesFieldName, isDa
coordinates = self._fieldmodule.findFieldByName(coordinatesFieldName)
return evaluateFieldNodesetMean(coordinates, nodesetGroup)

def printLog(self):
def _log_message_type_to_text(self, message_type):
# 'MESSAGE_TYPE_ERROR', 'MESSAGE_TYPE_INFORMATION', 'MESSAGE_TYPE_INVALID', 'MESSAGE_TYPE_WARNING'
if self._logger.MESSAGE_TYPE_ERROR == message_type:
return "Error"
if self._logger.MESSAGE_TYPE_INFORMATION == message_type:
return "Information"
if self._logger.MESSAGE_TYPE_WARNING == message_type:
return "Warning"

return "Invalid"

def print_log(self):
    """
    Print all messages accumulated on the zinc logger, one line each in the
    form "[Message i] <Type>: <text>", then clear the logger's messages.
    Does nothing if the logger holds no messages.
    """
    message_count = self._logger.getNumberOfMessages()
    if message_count > 0:
        # Zinc logger message indexes are 1-based.
        for index in range(1, message_count + 1):
            label = self._log_message_type_to_text(self._logger.getMessageTypeAtIndex(index))
            text = self._logger.getMessageTextAtIndex(index)
            print(f"[Message {index}] {label}: {text}")
        self._logger.removeAllMessages()

def getDiagnosticLevel(self):
Expand Down Expand Up @@ -1562,7 +1661,7 @@ def writeModel(self, modelFileName):
if self._modelFitGroup:
sir.setResourceGroupName(srf, self._modelFitGroup.getName())
result = self._region.write(sir)
# self.printLog()
# self.print_log()

# restore original name
self._modelCoordinatesField.setName(self._modelCoordinatesFieldName)
Expand Down
3 changes: 3 additions & 0 deletions src/scaffoldfitter/fitterexceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

class FitterModelCoordinateField(Exception):
    """Raised when no suitable model coordinate field can be found for the scaffold model."""
    pass
6 changes: 3 additions & 3 deletions src/scaffoldfitter/fitterstepfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def run(self, modelFileNameStem=None):
fieldcache, flattenGroupObjective.getNumberOfComponents())
print(" Flatten group objective", objectiveFormat.format(objective))
if self.getDiagnosticLevel() > 1:
self._fitter.printLog()
self._fitter.print_log()

if self._updateReferenceState:
self._fitter.updateModelReferenceCoordinates()
Expand Down Expand Up @@ -449,7 +449,7 @@ def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainAc
# convert to local fibre directions, with possible dimension reduction for 2D, 1D
fibreAxes = fieldmodule.createFieldFibreAxes(fibreField, modelReferenceCoordinates)
if not fibreAxes.isValid():
self.getFitter().printLog()
self.getFitter().print_log()
if dimension == 3:
fibreAxesT = fieldmodule.createFieldTranspose(3, fibreAxes)
elif dimension == 2:
Expand Down Expand Up @@ -506,7 +506,7 @@ def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainAc
deformationTerm = \
(deformationTerm + wtSqDeformationGradient2) if deformationTerm else wtSqDeformationGradient2
if not deformationTerm.isValid():
self.getFitter().printLog()
self.getFitter().print_log()
raise AssertionError("Scaffoldfitter: Failed to get deformation term")

deformationPenaltyObjective = fieldmodule.createFieldMeshIntegral(
Expand Down