From d25095235f63902278ef49d76739ae91ca34910e Mon Sep 17 00:00:00 2001 From: Frank Anema <33519926+Conengmo@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:58:04 +0100 Subject: [PATCH] Fix streamlit-folium incompatibility (add layer to map with new class) (#1834) * Add Layer to map with MacroElement class * run ruff * add ElementAddToElement only once * Fix marker cluster test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- folium/elements.py | 20 ++++++++++++++- folium/map.py | 27 +++++++------------- tests/plugins/test_marker_cluster.py | 38 ++++++++++++++-------------- 3 files changed, 47 insertions(+), 38 deletions(-) diff --git a/folium/elements.py b/folium/elements.py index bbb06c7d2..416ad3845 100644 --- a/folium/elements.py +++ b/folium/elements.py @@ -1,6 +1,7 @@ from typing import List, Tuple -from branca.element import CssLink, Element, Figure, JavascriptLink +from branca.element import CssLink, Element, Figure, JavascriptLink, MacroElement +from jinja2 import Template class JSCSSMixin(Element): @@ -22,3 +23,20 @@ def render(self, **kwargs) -> None: figure.header.add_child(CssLink(url), name=name) super().render(**kwargs) + + +class ElementAddToElement(MacroElement): + """Abstract class to add an element to another element.""" + + _template = Template( + """ + {% macro script(this, kwargs) %} + {{ this.element_name }}.addTo({{ this.element_parent_name }}); + {% endmacro %} + """ + ) + + def __init__(self, element_name: str, element_parent_name: str): + super().__init__() + self.element_name = element_name + self.element_parent_name = element_parent_name diff --git a/folium/map.py b/folium/map.py index 8a8c054b6..e3356c275 100644 --- a/folium/map.py +++ b/folium/map.py @@ -9,12 +9,12 @@ from branca.element import Element, Figure, Html, MacroElement from jinja2 import Template +from folium.elements import ElementAddToElement from folium.utilities import ( TypeBounds, TypeJsonValue, camelize, escape_backticks, - get_and_assert_figure_root, parse_options, validate_location, ) @@ -51,24 +51,15 @@ def __init__( self.show = show def render(self, **kwargs): - super().render(**kwargs) if self.show: - self._add_layer_to_map() - - def _add_layer_to_map(self, **kwargs): - """Show the layer on the map by adding it to its parent in JS.""" - template = Template( - """ - {%- macro script(this, kwargs) %} - {{ this.get_name() }}.addTo({{ this._parent.get_name() }}); - {%- endmacro %} - """ - ) - script = template.module.__dict__["script"] - figure = get_and_assert_figure_root(self) - figure.script.add_child( - Element(script(self, kwargs)), name=self.get_name() + "_add" - ) + self.add_child( + ElementAddToElement( + element_name=self.get_name(), + element_parent_name=self._parent.get_name(), + ), + name=self.get_name() + "_add", + ) + super().render(**kwargs) class FeatureGroup(Layer): diff --git a/tests/plugins/test_marker_cluster.py b/tests/plugins/test_marker_cluster.py index 5eb586ac2..4c3d635ce 100644 --- a/tests/plugins/test_marker_cluster.py +++ b/tests/plugins/test_marker_cluster.py @@ -23,24 +23,7 @@ def test_marker_cluster(): m = folium.Map([45.0, 3.0], zoom_start=4) mc = plugins.MarkerCluster(data).add_to(m) - out = normalize(m._parent.render()) - - # We verify that imports - assert ( - '' # noqa - in out - ) # noqa - assert ( - '' # noqa - in out - ) # noqa - assert ( - '' # noqa - in out - ) # noqa - - # Verify the script part is okay. - tmpl = Template( + tmpl_for_expected = Template( """ var {{this.get_name()}} = L.markerClusterGroup( {{ this.options|tojson }} @@ -60,7 +43,24 @@ def test_marker_cluster(): {{ this.get_name() }}.addTo({{ this._parent.get_name() }}); """ ) - expected = normalize(tmpl.render(this=mc)) + expected = normalize(tmpl_for_expected.render(this=mc)) + + out = normalize(m._parent.render()) + + # We verify that imports + assert ( + '' # noqa + in out + ) # noqa + assert ( + '' # noqa + in out + ) # noqa + assert ( + '' # noqa + in out + ) # noqa + assert expected in out bounds = m.get_bounds()