From d25095235f63902278ef49d76739ae91ca34910e Mon Sep 17 00:00:00 2001
From: Frank Anema <33519926+Conengmo@users.noreply.github.com>
Date: Mon, 27 Nov 2023 14:58:04 +0100
Subject: [PATCH] Fix streamlit-folium incompatibility (add layer to map with
new class) (#1834)
* Add Layer to map with MacroElement class
* run ruff
* add ElementAddToElement only once
* Fix marker cluster test
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
folium/elements.py | 20 ++++++++++++++-
folium/map.py | 27 +++++++-------------
tests/plugins/test_marker_cluster.py | 38 ++++++++++++++--------------
3 files changed, 47 insertions(+), 38 deletions(-)
diff --git a/folium/elements.py b/folium/elements.py
index bbb06c7d2..416ad3845 100644
--- a/folium/elements.py
+++ b/folium/elements.py
@@ -1,6 +1,7 @@
from typing import List, Tuple
-from branca.element import CssLink, Element, Figure, JavascriptLink
+from branca.element import CssLink, Element, Figure, JavascriptLink, MacroElement
+from jinja2 import Template
class JSCSSMixin(Element):
@@ -22,3 +23,20 @@ def render(self, **kwargs) -> None:
figure.header.add_child(CssLink(url), name=name)
super().render(**kwargs)
+
+
+class ElementAddToElement(MacroElement):
+ """Abstract class to add an element to another element."""
+
+ _template = Template(
+ """
+ {% macro script(this, kwargs) %}
+ {{ this.element_name }}.addTo({{ this.element_parent_name }});
+ {% endmacro %}
+ """
+ )
+
+ def __init__(self, element_name: str, element_parent_name: str):
+ super().__init__()
+ self.element_name = element_name
+ self.element_parent_name = element_parent_name
diff --git a/folium/map.py b/folium/map.py
index 8a8c054b6..e3356c275 100644
--- a/folium/map.py
+++ b/folium/map.py
@@ -9,12 +9,12 @@
from branca.element import Element, Figure, Html, MacroElement
from jinja2 import Template
+from folium.elements import ElementAddToElement
from folium.utilities import (
TypeBounds,
TypeJsonValue,
camelize,
escape_backticks,
- get_and_assert_figure_root,
parse_options,
validate_location,
)
@@ -51,24 +51,15 @@ def __init__(
self.show = show
def render(self, **kwargs):
- super().render(**kwargs)
if self.show:
- self._add_layer_to_map()
-
- def _add_layer_to_map(self, **kwargs):
- """Show the layer on the map by adding it to its parent in JS."""
- template = Template(
- """
- {%- macro script(this, kwargs) %}
- {{ this.get_name() }}.addTo({{ this._parent.get_name() }});
- {%- endmacro %}
- """
- )
- script = template.module.__dict__["script"]
- figure = get_and_assert_figure_root(self)
- figure.script.add_child(
- Element(script(self, kwargs)), name=self.get_name() + "_add"
- )
+ self.add_child(
+ ElementAddToElement(
+ element_name=self.get_name(),
+ element_parent_name=self._parent.get_name(),
+ ),
+ name=self.get_name() + "_add",
+ )
+ super().render(**kwargs)
class FeatureGroup(Layer):
diff --git a/tests/plugins/test_marker_cluster.py b/tests/plugins/test_marker_cluster.py
index 5eb586ac2..4c3d635ce 100644
--- a/tests/plugins/test_marker_cluster.py
+++ b/tests/plugins/test_marker_cluster.py
@@ -23,24 +23,7 @@ def test_marker_cluster():
m = folium.Map([45.0, 3.0], zoom_start=4)
mc = plugins.MarkerCluster(data).add_to(m)
- out = normalize(m._parent.render())
-
- # We verify that imports
- assert (
- '' # noqa
- in out
- ) # noqa
- assert (
- '' # noqa
- in out
- ) # noqa
- assert (
- '' # noqa
- in out
- ) # noqa
-
- # Verify the script part is okay.
- tmpl = Template(
+ tmpl_for_expected = Template(
"""
var {{this.get_name()}} = L.markerClusterGroup(
{{ this.options|tojson }}
@@ -60,7 +43,24 @@ def test_marker_cluster():
{{ this.get_name() }}.addTo({{ this._parent.get_name() }});
"""
)
- expected = normalize(tmpl.render(this=mc))
+ expected = normalize(tmpl_for_expected.render(this=mc))
+
+ out = normalize(m._parent.render())
+
+ # We verify that imports
+ assert (
+ '' # noqa
+ in out
+ ) # noqa
+ assert (
+ '' # noqa
+ in out
+ ) # noqa
+ assert (
+ '' # noqa
+ in out
+ ) # noqa
+
assert expected in out
bounds = m.get_bounds()