Skip to content

Commit

Permalink
Fix streamlit-folium incompatibility (add layer to map with new class) (
Browse files Browse the repository at this point in the history
#1834)

* Add Layer to map with MacroElement class

* run ruff

* add ElementAddToElement only once

* Fix marker cluster test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Conengmo and pre-commit-ci[bot] authored Nov 27, 2023
1 parent 6040f42 commit d250952
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 38 deletions.
20 changes: 19 additions & 1 deletion folium/elements.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Tuple

from branca.element import CssLink, Element, Figure, JavascriptLink
from branca.element import CssLink, Element, Figure, JavascriptLink, MacroElement
from jinja2 import Template


class JSCSSMixin(Element):
Expand All @@ -22,3 +23,20 @@ def render(self, **kwargs) -> None:
figure.header.add_child(CssLink(url), name=name)

super().render(**kwargs)


class ElementAddToElement(MacroElement):
"""Abstract class to add an element to another element."""

_template = Template(
"""
{% macro script(this, kwargs) %}
{{ this.element_name }}.addTo({{ this.element_parent_name }});
{% endmacro %}
"""
)

def __init__(self, element_name: str, element_parent_name: str):
super().__init__()
self.element_name = element_name
self.element_parent_name = element_parent_name
27 changes: 9 additions & 18 deletions folium/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from branca.element import Element, Figure, Html, MacroElement
from jinja2 import Template

from folium.elements import ElementAddToElement
from folium.utilities import (
TypeBounds,
TypeJsonValue,
camelize,
escape_backticks,
get_and_assert_figure_root,
parse_options,
validate_location,
)
Expand Down Expand Up @@ -51,24 +51,15 @@ def __init__(
self.show = show

def render(self, **kwargs):
super().render(**kwargs)
if self.show:
self._add_layer_to_map()

def _add_layer_to_map(self, **kwargs):
"""Show the layer on the map by adding it to its parent in JS."""
template = Template(
"""
{%- macro script(this, kwargs) %}
{{ this.get_name() }}.addTo({{ this._parent.get_name() }});
{%- endmacro %}
"""
)
script = template.module.__dict__["script"]
figure = get_and_assert_figure_root(self)
figure.script.add_child(
Element(script(self, kwargs)), name=self.get_name() + "_add"
)
self.add_child(
ElementAddToElement(
element_name=self.get_name(),
element_parent_name=self._parent.get_name(),
),
name=self.get_name() + "_add",
)
super().render(**kwargs)


class FeatureGroup(Layer):
Expand Down
38 changes: 19 additions & 19 deletions tests/plugins/test_marker_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,7 @@ def test_marker_cluster():
m = folium.Map([45.0, 3.0], zoom_start=4)
mc = plugins.MarkerCluster(data).add_to(m)

out = normalize(m._parent.render())

# We verify that imports
assert (
'<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/leaflet.markercluster.js"></script>' # noqa
in out
) # noqa
assert (
'<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/MarkerCluster.css"/>' # noqa
in out
) # noqa
assert (
'<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/MarkerCluster.Default.css"/>' # noqa
in out
) # noqa

# Verify the script part is okay.
tmpl = Template(
tmpl_for_expected = Template(
"""
var {{this.get_name()}} = L.markerClusterGroup(
{{ this.options|tojson }}
Expand All @@ -60,7 +43,24 @@ def test_marker_cluster():
{{ this.get_name() }}.addTo({{ this._parent.get_name() }});
"""
)
expected = normalize(tmpl.render(this=mc))
expected = normalize(tmpl_for_expected.render(this=mc))

out = normalize(m._parent.render())

# We verify that imports
assert (
'<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/leaflet.markercluster.js"></script>' # noqa
in out
) # noqa
assert (
'<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/MarkerCluster.css"/>' # noqa
in out
) # noqa
assert (
'<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.markercluster/1.1.0/MarkerCluster.Default.css"/>' # noqa
in out
) # noqa

assert expected in out

bounds = m.get_bounds()
Expand Down

0 comments on commit d250952

Please sign in to comment.