diff --git a/src/build123d/importers.py b/src/build123d/importers.py index 36eb0a39..319c0dee 100644 --- a/src/build123d/importers.py +++ b/src/build123d/importers.py @@ -69,7 +69,7 @@ from svgpathtools import svg2paths import unicodedata -from build123d.geometry import Color +from build123d.geometry import Color, Location from build123d.topology import ( Compound, Edge, @@ -133,6 +133,7 @@ def import_step(filename: str) -> Compound: """ def get_name(label: TDF_Label) -> str: + """Extract name and format""" name = "" std_name = TDataStd_Name() if label.FindAttribute(TDataStd_Name.GetID_s(), std_name): @@ -142,6 +143,7 @@ def get_name(label: TDF_Label) -> str: return clean_name.translate(str.maketrans(" .()", "____")) def get_color(shape: TopoDS_Shape) -> Quantity_ColorRGBA: + """Get the color - take that of the largest Face if multiple""" def get_col(obj: TopoDS_Shape) -> Quantity_ColorRGBA: col = Quantity_ColorRGBA() @@ -172,37 +174,34 @@ def get_col(obj: TopoDS_Shape) -> Quantity_ColorRGBA: return shape_color - def build_assembly( - assembly: Compound, parent_tdf_label: Optional[TDF_Label] = None - ) -> list[Shape]: - tdf_labels = TDF_LabelSequence() + def build_assembly(parent_tdf_label: Optional[TDF_Label] = None) -> list[Shape]: + """Recursively extract object into an assembly""" + sub_tdf_labels = TDF_LabelSequence() if parent_tdf_label is None: - shape_tool.GetFreeShapes(tdf_labels) + shape_tool.GetFreeShapes(sub_tdf_labels) else: - shape_tool.GetComponents_s(parent_tdf_label, tdf_labels) + shape_tool.GetComponents_s(parent_tdf_label, sub_tdf_labels) sub_shapes: list[Shape] = [] - for i in range(tdf_labels.Length()): - sub_tdf_label = tdf_labels.Value(i + 1) + for i in range(sub_tdf_labels.Length()): + sub_tdf_label = sub_tdf_labels.Value(i + 1) if shape_tool.IsReference_s(sub_tdf_label): ref_tdf_label = TDF_Label() shape_tool.GetReferredShape_s(sub_tdf_label, ref_tdf_label) else: ref_tdf_label = sub_tdf_label - topo_shape = downcast(shape_tool.GetShape_s(ref_tdf_label)) - sub_shape_type = topods_lut[type(topo_shape)] - sub_shape_loc = shape_tool.GetLocation_s(sub_tdf_label) - # The location of this part is relative to its parent - if assembly.wrapped is not None: - sub_shape_loc = assembly.location.wrapped.Multiplied(sub_shape_loc) - sub_shape: Shape = sub_shape_type() - sub_shape.wrapped = downcast(topo_shape.Moved(sub_shape_loc)) - sub_shape.color = Color(get_color(topo_shape)) - sub_shape.label = get_name(ref_tdf_label) - sub_shape.parent = assembly + sub_topo_shape = downcast(shape_tool.GetShape_s(ref_tdf_label)) if shape_tool.IsAssembly_s(ref_tdf_label): - sub_shape.children = build_assembly(sub_shape, ref_tdf_label) + sub_shape = Compound() + sub_shape.children = build_assembly(ref_tdf_label) + else: + sub_shape = topods_lut[type(sub_topo_shape)](sub_topo_shape) + + sub_shape.color = Color(get_color(sub_topo_shape)) + sub_shape.label = get_name(ref_tdf_label) + sub_shape.move(Location(shape_tool.GetLocation_s(sub_tdf_label))) + sub_shapes.append(sub_shape) return sub_shapes @@ -222,7 +221,7 @@ def build_assembly( root = Compound() root.for_construction = None - build_assembly(root) + root.children = build_assembly() # Remove empty Compound wrapper if single free object if len(root.children) == 1: root = root.children[0] diff --git a/tests/test_importers.py b/tests/test_importers.py index db724cd8..5dc252a1 100644 --- a/tests/test_importers.py +++ b/tests/test_importers.py @@ -1,6 +1,7 @@ from io import StringIO import os import unittest +import urllib.request from build123d import BuildLine, Color, Line, Bezier, RadiusArc, Solid, Compound from build123d.importers import ( import_svg_as_buildline_code, @@ -103,6 +104,21 @@ def test_bad_filename(self): class ImportSTEP(unittest.TestCase): + + @classmethod + def setUpClass(cls): + """setUpClass is a class method that is executed once for the entire test + class before any of the test methods in the class are executed. It's intended + for setting up expensive resources or doing tasks that are required by all + tests in the class, such as downloading a large file, establishing a database + connection, or starting a server. + """ + url = "https://raw.githubusercontent.com/tpaviot/pythonocc-demos/master/assets/models/as1-oc-214.stp" + file_path = "/tmp/as1-oc-214.stp" + if not os.path.exists(file_path): + urllib.request.urlretrieve(url, file_path) + cls.large_step_file_path = file_path + def test_single_object(self): export_step(Solid.make_box(1, 1, 1), "test.step") box = import_step("test.step") @@ -146,6 +162,12 @@ def test_single_label_color(self): self.assertEqual(imported_assembly.children[1].label, "box") self.assertEqual(tuple(imported_assembly.children[1].color), (0, 0, 1, 1)) + def test_assembly_with_oriented_parts(self): + assembly = import_step(self.large_step_file_path) + fused = Solid.fuse(*assembly.solids()) + # If the parts where placed correctly they all touch and can be fused + self.assertEqual(len(fused.solids()), 1) + if __name__ == "__main__": unittest.main()