Skip to content

Commit

Permalink
Update prophet required version to 1.1 (#110)
Browse files Browse the repository at this point in the history
* Update prophet required version to 1.1.

Prophet 1.1 removes the dependency on pystan. Since pystan installation
is quite complicated, this change makes Merlion as a whole much simpler
to install.

* Simplify Dockerfile.

* Add Python 3.10 support to the test bench.

* Bump version 1.2.3.

* Correction for Python 3.10.

* Add prophet=1.0 backward compatibility for Python 3.6.
  • Loading branch information
aadyotb authored Jun 28, 2022
1 parent 251f4e9 commit 593ecb9
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand Down
22 changes: 2 additions & 20 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,12 @@ FROM apache/spark-py:v3.2.1
# Change to root user for installation steps
USER 0

# Uninstall existing python and replace it with miniconda.
# This is to get the right version of Python in Debian, since Prophet doesn't play nice with Python 3.9+.
# FIXME: maybe optimize the size? this image is currently 3.2GB.
RUN apt-get update && \
apt-get remove -y python3 python3-pip && \
apt-get install -y --no-install-recommends curl && \
apt-get autoremove -yqq --purge && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN curl -fsSL -v -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
# Install prophet while we're at it, since this is easier to conda install than pip install
/opt/conda/bin/conda install -y prophet && \
/opt/conda/bin/conda clean -ya
ENV PATH="/opt/conda/bin:${SPARK_HOME}/bin:${PATH}"

# Install (for spark-sql) and Merlion; get pyspark & py4j from the PYTHONPATH
# Install pyarrow (for spark-sql) and Merlion; get pyspark & py4j from the PYTHONPATH
ENV PYTHONPATH="${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-0.10.9.3-src.zip:${PYTHONPATH}"
COPY *.md ./
COPY setup.py ./
COPY merlion merlion
RUN pip install pyarrow "./[prophet]" && pip uninstall -y py4j
RUN pip install pyarrow "./" && pip uninstall -y py4j

# Copy Merlion pyspark apps
COPY spark /opt/spark/apps
Expand Down
3 changes: 1 addition & 2 deletions merlion/models/automl/autosarima.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
"""
Automatic hyperparameter selection for SARIMA.
"""
from collections import Iterator
from copy import copy, deepcopy
import logging
from typing import Any, Optional, Tuple, Union
from typing import Any, Iterator, Optional, Tuple, Union

import numpy as np

Expand Down
33 changes: 16 additions & 17 deletions merlion/models/forecast/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,15 @@
"""
Wrapper around Facebook's popular Prophet model for time series forecasting.
"""
import copy
import logging
import os
from typing import Iterable, List, Tuple, Union

try:
import prophet
except ImportError as e:
err_msg = (
"Try installing Merlion with optional dependencies using `pip install salesforce-merlion[prophet]` or "
"`pip install `salesforce-merlion[all]`"
)
raise ImportError(str(e) + ". " + err_msg)

import numpy as np
import pandas as pd
import prophet
import prophet.serialize

from merlion.models.automl.seasonality import SeasonalityModel
from merlion.models.forecast.base import ForecasterBase, ForecasterConfig
Expand Down Expand Up @@ -144,14 +138,19 @@ def require_even_sampling(self) -> bool:
return False

def __getstate__(self):
stan_backend = self.model.stan_backend
if hasattr(stan_backend, "logger"):
model_logger = self.model.stan_backend.logger
self.model.stan_backend.logger = None
state_dict = super().__getstate__()
if hasattr(stan_backend, "logger"):
self.model.stan_backend.logger = model_logger
return state_dict
try:
model = prophet.serialize.model_to_json(self.model)
except ValueError: # prophet.serialize only works for fitted models, so deepcopy as a backup
model = copy.deepcopy(self.model)
return {k: model if k == "model" else copy.deepcopy(v) for k, v in self.__dict__.items()}

def __setstate__(self, state):
if "model" in state:
model = state["model"]
if isinstance(model, str):
state = copy.copy(state)
state["model"] = prophet.serialize.model_from_json(model)
super().__setstate__(state)

@property
def yearly_seasonality(self):
Expand Down
11 changes: 4 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@
]

# optional dependencies
extra_require = {
"plot": ["plotly>=4.13"],
"prophet": ["prophet", "pystan<3.0"], # pystan >= 3.0 doesn't work with prophet
"deep-learning": ["torch>=1.1.0"],
"spark": ["pyspark[sql]>=3"],
}
extra_require = {"plot": ["plotly>=4.13"], "deep-learning": ["torch>=1.1.0"], "spark": ["pyspark[sql]>=3"]}
extra_require["all"] = sum(extra_require.values(), [])


Expand All @@ -29,7 +24,7 @@ def read_file(fname):

setup(
name="salesforce-merlion",
version="1.2.2",
version="1.2.3",
author=", ".join(read_file("AUTHORS.md").split("\n")),
author_email="[email protected]",
description="Merlion: A Machine Learning Framework for Time Series Intelligence",
Expand All @@ -52,6 +47,8 @@ def read_file(fname):
"numpy>=1.19; python_version < '3.7'", # however, numpy 1.20+ requires python 3.7+
"packaging",
"pandas>=1.1.0", # >=1.1.0 for origin kwarg to df.resample()
"prophet>=1.1; python_version >= '3.7'", # 1.1 removes dependency on pystan
"prophet==1.0.1; python_version < '3.7'", # however, prophet 1.1 requires python 3.7+
"scikit-learn>=0.22", # >=0.22 for changes to isolation forest algorithm
"scipy>=1.6.0; python_version >= '3.7'", # 1.6.0 adds multivariate_t density to scipy.stats
"scipy>=1.5.0; python_version < '3.7'", # however, scipy 1.6.0 requires python 3.7+
Expand Down

0 comments on commit 593ecb9

Please sign in to comment.