diff --git a/swmmio/core.py b/swmmio/core.py index b8c2815..05ab85f 100644 --- a/swmmio/core.py +++ b/swmmio/core.py @@ -96,9 +96,26 @@ class Model(object): J1 13.392 NaN 0.0 NaN 0.0 """ def __init__(self, in_file_path, crs=None, include_rpt=True): + """ + Initialize a swmmio.Model object by pointing it to a local INP file + or a URL to a remote INP file. + + Parameters + ---------- + in_file_path : str + Path to local INP file or URL to remote INP file + crs : str, optional + String representation of a coordinate reference system, by default None + include_rpt : bool, optional + whether to include data from an RPT (if an RPT exists), by default True + """ self.crs = None inp_path = None + + # if the input is a URL, download it to a temp location + in_file_path = functions.check_if_url_and_download(in_file_path) + if os.path.isdir(in_file_path): # a directory was passed in inps_in_dir = glob.glob1(in_file_path, "*.inp") diff --git a/swmmio/tests/test_functions.py b/swmmio/tests/test_functions.py index 81db0b1..e956d81 100644 --- a/swmmio/tests/test_functions.py +++ b/swmmio/tests/test_functions.py @@ -1,9 +1,12 @@ import pytest +import unittest +from unittest.mock import patch, mock_open, MagicMock + import swmmio from swmmio.tests.data import (MODEL_FULL_FEATURES__NET_PATH, OWA_RPT_EXAMPLE, RPT_FULL_FEATURES, MODEL_EX_1_PARALLEL_LOOP, MODEL_EX_1) -from swmmio.utils.functions import format_inp_section_header, find_network_trace +from swmmio.utils.functions import format_inp_section_header, find_network_trace, check_if_url_and_download from swmmio.utils import error from swmmio.utils.text import get_rpt_metadata @@ -97,3 +100,48 @@ def test_network_trace_bad_include_node(): path_selection = find_network_trace(m, start_node, end_node, include_nodes=["1000"]) + +class TestCheckIfUrlAndDownload(unittest.TestCase): + + @patch('requests.get') + @patch('tempfile.gettempdir') + @patch('builtins.open', new_callable=mock_open) + def test_download_file(self, mock_open, mock_gettempdir, mock_requests_get): + # Mock the response from requests.get + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b'Test content' + mock_requests_get.return_value = mock_response + + # Mock the temporary directory + mock_gettempdir.return_value = '/tmp' + + url = 'https://example.com/path/to/file.txt' + expected_path = '/tmp/file.txt' + + result = check_if_url_and_download(url) + + # Check if the file was written correctly + mock_open.assert_called_once_with(expected_path, 'wb') + mock_open().write.assert_called_once_with(b'Test content') + + self.assertEqual(result, expected_path) + + @patch('requests.get') + def test_download_file_failed(self, mock_requests_get): + # Mock the response from requests.get + mock_response = MagicMock() + mock_response.status_code = 404 + mock_requests_get.return_value = mock_response + + url = 'https://example.com/path/to/file.txt' + + with self.assertRaises(Exception) as context: + check_if_url_and_download(url) + + self.assertIn('Failed to download file: 404', str(context.exception)) + + def test_not_a_url(self): + non_url_string = '/Users/bingo/models/not_a_url.inp' + result = check_if_url_and_download(non_url_string) + self.assertEqual(result, non_url_string) \ No newline at end of file diff --git a/swmmio/utils/functions.py b/swmmio/utils/functions.py index f679fd8..17b1c81 100644 --- a/swmmio/utils/functions.py +++ b/swmmio/utils/functions.py @@ -1,6 +1,11 @@ import warnings import pandas as pd import networkx as nx +import os +from urllib.parse import urlparse +import tempfile + +import requests from swmmio.utils import error @@ -296,4 +301,39 @@ def summarize_model(model): if len(model.nodes.dataframe) != 0: model_summary['invert_range'] = model.nodes().InvertElev.max() - model.nodes().InvertElev.min() - return model_summary \ No newline at end of file + return model_summary + + +def check_if_url_and_download(url): + """ + Check if a given string is a URL and download the + file to a temporary directory if it is. + + Parameters + ---------- + url : str + string that may be a URL + + Returns + ------- + str + path to the downloaded file in the temporary directory or + the original string if it is not a URL + """ + + if url.startswith(('http://', 'https://')): + r = requests.get(url) + + # get the filename from the url + parsed_url = urlparse(url) + filename = parsed_url.path.split('/')[-1] + + temp_path = os.path.join(tempfile.gettempdir(), filename) + with open(temp_path, 'wb') as f: + if r.status_code == 200: + f.write(r.content) + else: + raise Exception(f"Failed to download file: {r.status_code}") + return temp_path + else: + return url \ No newline at end of file