From 869b7b54cc5daa5412bc7f4ff87cb9337fcac11e Mon Sep 17 00:00:00 2001
From: Abhinav Shrivastava <880094+abhi2610@users.noreply.github.com>
Date: Thu, 22 Oct 2020 18:17:44 -0400
Subject: [PATCH] Add image_classification CLASSIFICATION feedback. Remove
 DETECTION feedback. Minor lint changes.

---
 sail_on/api/file_provider.py                   | 171 +++++++++++-------
 .../OND.1.1.1234_metadata.json                 |   2 +-
 tests/test_api.py                              |  72 +++++++-
 3 files changed, 166 insertions(+), 79 deletions(-)

diff --git a/sail_on/api/file_provider.py b/sail_on/api/file_provider.py
index 67150fb..ea6a68e 100644
--- a/sail_on/api/file_provider.py
+++ b/sail_on/api/file_provider.py
@@ -22,8 +22,6 @@
 from io import BytesIO
 
-
-
 
 def get_session_info(folder: str, session_id: str) -> Dict[str, Any]:
     """Retrieve session info."""
     path = os.path.join(folder, f"{str(session_id)}.json")
@@ -78,6 +76,8 @@
         json.dump(structure, session_file, indent=2)
 
 # region Feedback related functions
+
+
 def read_feedback_file(
     csv_reader: reader,
     feedback_ids: List[str],
@@ -128,6 +128,7 @@
         for x in [[n.strip(" \"'") for n in y] for y in lines][round_pos:round_pos + int(metadata["feedback_max_ids"])]
     }
 
+
 def get_classification_feedback(
     gt_files: List[str],
     result_files: List[str],
@@ -136,20 +137,26 @@
     round_id: int,
 ) -> Dict[str, Any]:
     """Calculates and returns the proper feedback for classification type feedback"""
-    with open(gt_files[0], "r") as f:
-        gt_reader = csv.reader(f, delimiter=",")
-        ground_truth = read_feedback_file(gt_reader, feedback_ids, metadata, True, round_id)
+    # Read detection files
     with open(result_files[0], "r") as rf:
-        result_reader = csv.reader(rf, delimiter=",")
-        results = read_feedback_file(result_reader, feedback_ids, metadata, False)
+        detection_result_reader = csv.reader(rf, delimiter=",")
+        detection_results = read_feedback_file(
+            detection_result_reader, feedback_ids, metadata, False)
+
+    # If detection results were monotonically increasing, checking only the first entry would suffice; we check all.
+    assert all([v[0] > metadata["threshold"] for (k, v) in detection_results.items(
+    )]), "Novelty Detection score needs to be \"> threshold\" to request feedback. Discuss with TA1s to disable this."
 
+    # Read classification files
+    with open(gt_files[1], "r") as f:
+        gt_reader = csv.reader(f, delimiter=",")
+        ground_truth = read_feedback_file(
+            gt_reader, feedback_ids, metadata, True, round_id)
+    # Don't need to read this.
+    # with open(result_files[1], "r") as rf:
+    #     result_reader = csv.reader(rf, delimiter=",")
+    #     results = read_feedback_file(result_reader, feedback_ids, metadata, False)
 
-    return {
-        x: 0
-        if ground_truth[x][1:].index(max(ground_truth[x][1:])) !=
-        results[x][1:].index(max(results[x][1:]))
-        else 1
-        for x in ground_truth.keys()
-    }
+    return {x: np.argmax(ground_truth[x]) for x in ground_truth.keys()}
 
 
 def get_detection_feedback(
@@ -164,16 +171,20 @@
 
     with open(gt_files[0], "r") as f:
         gt_reader = csv.reader(f, delimiter=",")
-        ground_truth = read_feedback_file(gt_reader, feedback_ids, metadata, True, round_id)
+        ground_truth = read_feedback_file(
+            gt_reader, feedback_ids, metadata, True, round_id)
     with open(result_files[0], "r") as rf:
         result_reader = csv.reader(rf, delimiter=",")
         results = read_feedback_file(result_reader, feedback_ids, metadata, False)
 
+    # This is incorrect, but since detection feedback is not allowed, leaving it as is.
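+    # The legacy threshold comparison below is kept for reference but never executes.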
+    raise NameError('DetectionFeedback is not supported.')
     return {
         x: 0 if abs(ground_truth[x][0] - results[x][0]) > threshold else 1
         for x in ground_truth.keys()
     }
 
+
 def get_characterization_feedback(
     gt_files: List[str],
     result_files: List[str],
@@ -186,12 +197,13 @@
 
     with open(gt_files[0], "r") as f:
         gt_reader = csv.reader(f, delimiter=",")
-        ground_truth = read_feedback_file(gt_reader, feedback_ids, metadata, True, round_id)
+        ground_truth = read_feedback_file(
+            gt_reader, feedback_ids, metadata, True, round_id)
     with open(result_files[0], "r") as rf:
         result_reader = csv.reader(rf, delimiter=",")
         results = read_feedback_file(result_reader, feedback_ids, metadata, False)
 
-    # If ground truth is not novel, returns 1 is prediction is correct, 
+    # If ground truth is not novel, returns 1 if prediction is correct,
     # otherwise returns 1 if prediction is not a known class
     return {
         x: 0
@@ -203,6 +215,7 @@
         for x in ground_truth.keys()
     }
 
+
 def get_levenshtein_feedback(
     gt_files: List[str],
     result_files: List[str],
@@ -213,16 +226,19 @@
     """Calculates and returns the proper feedback for levenshtein type feedback"""
     with open(gt_files[0], "r") as f:
         gt_reader = csv.reader(f, delimiter=",")
-        ground_truth = read_feedback_file(gt_reader, feedback_ids, metadata, True, round_id)
+        ground_truth = read_feedback_file(
+            gt_reader, feedback_ids, metadata, True, round_id)
     with open(result_files[0], "r") as rf:
         result_reader = csv.reader(rf, delimiter=",")
         results = read_feedback_file(result_reader, feedback_ids, metadata, False)
 
     return {
-        x: [nltk.edit_distance(ground_truth[x][i], results[x][i]) for i,_ in enumerate(ground_truth[x])]
+        x: [nltk.edit_distance(ground_truth[x][i], results[x][i])
+            for i, _ in enumerate(ground_truth[x])]
         for x in ground_truth.keys()
     }
 
+
 def get_cluster_feedback(
     gt_files: List[str],
     result_files: List[str],
@@ -233,7 +249,8 @@
     """Calculates and returns the proper feedback for cluster type feedback"""
     with open(gt_files[0], "r") as f:
         gt_reader = csv.reader(f, delimiter=",")
-        ground_truth = read_feedback_file(gt_reader, feedback_ids, metadata, True, round_id)
+        ground_truth = read_feedback_file(
+            gt_reader, feedback_ids, metadata, True, round_id)
     with open(result_files[0], "r") as rf:
         result_reader = csv.reader(rf, delimiter=",")
         results = read_feedback_file(result_reader, feedback_ids, metadata, False)
@@ -248,8 +265,9 @@
             gt_list.append(ground_truth[key])
             r_list.append(results[key])
     except:
-        raise ServerError("MissingIds", "Some requested Ids are missing from either ground truth or results file for the current round")
-    
+        raise ServerError(
+            "MissingIds", "Some requested Ids are missing from either ground truth or results file for the current round")
+
     gt_np = np.argmax(np.array(gt_list), axis=1)
     r_np = np.argmax(np.array(r_list), axis=1)
 
@@ -259,10 +277,12 @@
 
     for i in np.unique(r_np):
         places = np.where(r_np == i)[0]
-        return_dict[str(i)] = (max(np.unique(gt_np[places],return_counts=True)[1])/places.shape[0])
-    
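+        # For each predicted cluster, report the share of members whose ground-truth label matches the cluster's majority label.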
+        return_dict[str(i)] = (
+            max(np.unique(gt_np[places], return_counts=True)[1])/places.shape[0])
+
     return return_dict
 
+
 def psuedo_label_feedback(
     gt_files: List[str],
     feedback_ids: List[str],
@@ -275,7 +295,8 @@
     "Grabs psuedo label feedback for requested ids"
     with open(gt_files[0], "r") as f:
         gt_reader = csv.reader(f, delimiter=",")
-        ground_truth = read_feedback_file(gt_reader, feedback_ids, metadata, True, round_id)
+        ground_truth = read_feedback_file(
+            gt_reader, feedback_ids, metadata, True, round_id)
 
     structure = get_session_info(folder, session_id)
@@ -339,7 +360,7 @@ def get_test_metadata(self, session_id: str, test_id: str, api_call: bool = True
 
         if session_id is not None:
             structure = get_session_info(self.results_folder, session_id)
-            hints = structure.get('activity',{}).get('created', {}).get('hints',[])
+            hints = structure.get('activity', {}).get('created', {}).get('hints', [])
 
             approved_metadata.extend([data for data in ["red_light"] if data in hints])
@@ -365,7 +386,8 @@ def _strip_id(filename):
         elif not os.path.exists(os.path.join(self.folder, protocol, domain)):
             msg = f"domain {domain} for {protocol} not configured"
         else:
-            test_ids = [_strip_id(f) for f in glob.glob(os.path.join(self.folder, protocol, domain,'*.csv'))]
+            test_ids = [_strip_id(f) for f in glob.glob(
+                os.path.join(self.folder, protocol, domain, '*.csv'))]
             return {"test_ids": test_ids, "generator_seed": "1234"}
         raise ProtocolError(
             "BadDomain",
@@ -443,7 +465,8 @@ def dataset_request(self, session_id: str, test_id: str, round_id: int) -> FileR
                     f"Round id {str(round_id)} is out of scope for test id {test_id}. Check the metadata round_size.",  # noqa: E501,
                     traceback.format_stack(),
                 )
-            temp_file_path.write(''.join(lines[round_pos:round_pos + int(metadata["round_size"])]).encode('utf-8'))
+            temp_file_path.write(
+                ''.join(lines[round_pos:round_pos + int(metadata["round_size"])]).encode('utf-8'))
             temp_file_path.seek(0)
         else:
             temp_file_path = open(file_locations[0], 'rb')
@@ -469,22 +492,22 @@
     #     }
     # }
     feedback_request_mapping = {
-        "image_classification" : {
+        "image_classification": {
             ProtocolConstants.CLASSIFICATION: {
-                "function" : get_classification_feedback,
-                "files" : [ProtocolConstants.DETECTION, ProtocolConstants.CLASSIFICATION],
+                "function": get_classification_feedback,
+                "files": [ProtocolConstants.DETECTION, ProtocolConstants.CLASSIFICATION],
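+                # files[0] (DETECTION) gates feedback on the novelty threshold; files[1] (CLASSIFICATION) supplies the labels.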
"files" : [ProtocolConstants.CLASSIFICATION] + "function": psuedo_label_feedback, + "files": [ProtocolConstants.CLASSIFICATION] } } } @@ -533,13 +556,14 @@ def get_feedback( metadata = self.get_test_metadata(session_id, test_id, False) structure = get_session_info(self.results_folder, session_id) - # Gets the amount of ids already requested for this type of feedback this round and + # Gets the amount of ids already requested for this type of feedback this round and # determines whether the limit has alreayd been reached try: - feedback_count = structure["activity"]["get_feedback"]["tests"][test_id]["rounds"][round_id].get(feedback_type, 0) + feedback_count = structure["activity"]["get_feedback"]["tests"][test_id]["rounds"][round_id].get( + feedback_type, 0) if feedback_count >= metadata["feedback_max_ids"]: raise ProtocolError( - "FeedbackBudgetExceeded", + "FeedbackBudgetExceeded", f"Feedback of type {feedback_type} has already been requested on the maximum number of ids" ) except KeyError: @@ -559,7 +583,7 @@ def get_feedback( file_types = self.feedback_request_mapping[domain][feedback_type]["files"] except: raise ProtocolError( - "InvalidFeedbackType", + "InvalidFeedbackType", f"Invalid feedback type requested for the test id {test_id} with domain {domain}", traceback.format_stack(), ) @@ -580,12 +604,13 @@ def get_feedback( results_files = [] for t in file_types: results_files.extend(glob.glob( - os.path.join(self.results_folder,"**",f"{str(session_id)}.{str(test_id)}_{t}.csv"), + os.path.join(self.results_folder, "**", + f"{str(session_id)}.{str(test_id)}_{t}.csv"), recursive=True, )) else: raise ProtocolError( - "BadDomain", + "BadDomain", f"The set domain does not match a proper domain type. Please check the metadata file for test {test_id}", traceback.format_stack(), ) @@ -593,24 +618,28 @@ def get_feedback( # Check to make sure the round id being requested is both the latest and the highest round submitted try: if structure["activity"]["post_results"]["tests"][test_id]["last round"] != str(round_id): - raise RoundError("NotLastRound", "Attempted to get feedback on an older round. Feedback can only be retrieved on the most recent round submission.") - - rounds_subbed = [int(r) for r in structure["activity"]["post_results"]["tests"][test_id]["rounds"].keys()] + raise RoundError( + "NotLastRound", "Attempted to get feedback on an older round. Feedback can only be retrieved on the most recent round submission.") + + rounds_subbed = [int(r) for r in structure["activity"] + ["post_results"]["tests"][test_id]["rounds"].keys()] if int(round_id) != max(rounds_subbed): - raise RoundError("NotMaxRound", "Attempted to get feedback on a round that wasn't the max round submitted (most likely a resubmitted round).") + raise RoundError( + "NotMaxRound", "Attempted to get feedback on a round that wasn't the max round submitted (most likely a resubmitted round).") except RoundError as e: raise e except Exception as e: - raise RoundError("SessionLogError", "Error checking session log for round history. Ensure results have been posted before requesting feedback") + raise RoundError( + "SessionLogError", "Error checking session log for round history. 
 
         # If detection is required, ensure detection has been posted for the requested round
         if self.feedback_request_mapping[metadata["domain"]][feedback_type].get("detection_req", False):
             if "detection file path" not in structure["activity"]["post_results"]["tests"][test_id]["rounds"][round_id]:
                 raise ProtocolError(
-                    "DetectionPostRequired", 
+                    "DetectionPostRequired",
                     "A detection file is required to be posted before feedback can be requested on a round. Please submit Detection results before requesting feedback"
                 )
-        
+
         if len(results_files) < 1:
             raise ServerError(
                 "result_file_not_found",
                 traceback.format_stack(),
             )
-
         # Get feedback from specified test
         try:
             if "psuedo" in feedback_type:
                 feedback = self.feedback_request_mapping[domain][feedback_type]["function"](
                     ground_truth_files, feedback_ids, metadata, self.folder, session_id, test_id, self.results_folder
                 )
             else:
                 feedback = self.feedback_request_mapping[domain][feedback_type]["function"](
                     ground_truth_files, results_files, feedback_ids, metadata, round_id
                 )
         except KeyError as e:
             raise ProtocolError(
                 "feedback_type_invalid",
                 f"Feedback type {feedback_type} is not valid. Make sure the provider's feedback_algorithms variable is properly set up",
                 traceback.format_stack()
             )
 
         log_session(
             self.results_folder,
             session_id=session_id,
             activity="get_feedback",
             test_id=test_id,
             round_id=round_id,
             content={feedback_type: feedback_count},
         )
-
         feedback_csv = BytesIO()
         for key in feedback.keys():
             if type(feedback[key]) is not list:
                 feedback_csv.write(f"{key},{feedback[key]}\n".encode('utf-8'))
             else:
-                feedback_csv.write(f"{key},{','.join(str(x) for x in feedback[key])}\n".encode('utf-8'))
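+                # List-valued feedback (e.g., Levenshtein distances) is flattened into one comma-separated row.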
"OND", "known_classes": 3, "max_novel_classes": 2, diff --git a/tests/test_api.py b/tests/test_api.py index 5100a0f..7c35bf9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -17,6 +17,8 @@ from json import JSONDecodeError # Helpers + + def _check_response(response: Response) -> None: """ Raise the appropriate ApiError based on response error code. @@ -40,7 +42,7 @@ def _check_response(response: Response) -> None: def get(path: str, **params: Dict[str, Any]) -> Response: - return requests.get(f"http://localhost:12345{path}", **params) + return requests.get(f"http://localhost:12345{path}", **params) def post(path: str, **params: Dict[str, Any]) -> Response: @@ -51,7 +53,9 @@ def delete(path: str, **params: Dict[str, Any]) -> Response: return requests.delete(f"http://localhost:12345{path}", **params) -SERVER_RESULTS_DIR = os.path.join(os.path.dirname(__file__), f"server_results_unit_tests") +SERVER_RESULTS_DIR = os.path.join( + os.path.dirname(__file__), f"server_results_unit_tests") + class TestApi(unittest.TestCase): """Test the API.""" @@ -68,14 +72,16 @@ def setUpClass(cls): shutil.rmtree(SERVER_RESULTS_DIR) os.mkdir(SERVER_RESULTS_DIR) server.set_provider( - FileProvider(os.path.join(os.path.dirname(__file__), "data"), SERVER_RESULTS_DIR) + FileProvider(os.path.join(os.path.dirname( + __file__), "data"), SERVER_RESULTS_DIR) ) api_thread = threading.Thread(target=server.init, args=("localhost", 12345)) api_thread.daemon = True api_thread.start() directory = os.path.join(os.path.dirname(__file__), "session_state_files") for filename in os.listdir(directory): - shutil.copy(os.path.join(directory, filename), os.path.join(SERVER_RESULTS_DIR)) + shutil.copy(os.path.join(directory, filename), + os.path.join(SERVER_RESULTS_DIR)) # Test Ids Request Tests def test_test_ids_request_success(self): @@ -129,14 +135,16 @@ def test_session_request_success(self): session_id = response.json()["session_id"] self.assertEqual(session_id, str(uuid.UUID(session_id))) - self.assertTrue(os.path.exists(os.path.join(SERVER_RESULTS_DIR, f"{session_id}.json"))) + self.assertTrue(os.path.exists(os.path.join( + SERVER_RESULTS_DIR, f"{session_id}.json"))) # Dataset Request Tests def test_dataset_request_success_with_round_id(self): """Test dataset request with rounds.""" response = get( "/session/dataset", - params={"session_id": "data_request", "test_id": "OND.1.1.1234", "round_id": 0}, + params={"session_id": "data_request", + "test_id": "OND.1.1.1234", "round_id": 0}, ) _check_response(response) @@ -176,7 +184,8 @@ def test_get_feedback_failure_invalid_type(self): def test_get_feedback_success_multiple_types(self): """Test get_feedback with multiple types.""" - feedback_types = [ProtocolConstants.CLASSIFICATION, ProtocolConstants.CHARACTERIZATION] + feedback_types = [ProtocolConstants.CLASSIFICATION, + ProtocolConstants.CHARACTERIZATION] response = get( "/session/feedback", params={ @@ -191,7 +200,8 @@ def test_get_feedback_success_multiple_types(self): multipart_data = decoder.MultipartDecoder.from_response(response) result_dicts = [] for i in range(len(feedback_types)): - header = multipart_data.parts[i].headers[b"Content-Disposition"].decode("utf-8") + header = multipart_data.parts[i].headers[b"Content-Disposition"].decode( + "utf-8") header_dict = { x[0].strip(): x[1].strip(" \"'") for x in [part.split("=") for part in header.split(";") if "=" in part] @@ -202,11 +212,52 @@ def test_get_feedback_success_multiple_types(self): actual = [] for i, part in enumerate(multipart_data.parts): actual = 
part.content.decode("utf-8") + print('-----') + print(actual) + print('-----') self.assertEqual(expected[i], actual) for i, head in enumerate(result_dicts): self.assertEqual(feedback_types[i], head["name"]) - self.assertEqual(f"get_feedback.OND.1.1.1234.1_{feedback_types[i]}.csv", head["filename"]) + self.assertEqual( + f"get_feedback.OND.1.1.1234.1_{feedback_types[i]}.csv", head["filename"]) + + def test_get_feedback_success_classification(self): + """Test get_feedback with classification.""" + feedback_types = [ProtocolConstants.CLASSIFICATION] + response = get( + "/session/feedback", + params={ + "feedback_type": feedback_types[0], + "session_id": "get_feedback", + "test_id": "OND.1.1.1234", + "round_id": 1, + }, + ) + + _check_response(response) + expected = "n01484850_4515.JPEG,0\nn01484850_45289.JPEG,2\n" + actual = response.content.decode("utf-8") + self.assertEqual(expected, actual) + + def test_get_feedback_success_single_classification(self): + """Test get_feedback with classification.""" + feedback_types = [ProtocolConstants.CLASSIFICATION] + response = get( + "/session/feedback", + params={ + "feedback_type": feedback_types[0], + "feedback_ids": ["n01484850_45289.JPEG"], + "session_id": "get_feedback", + "test_id": "OND.1.1.1234", + "round_id": 1, + }, + ) + + _check_response(response) + expected = "n01484850_45289.JPEG,2\n" + actual = actual = response.content.decode("utf-8") + self.assertEqual(expected, actual) def test_get_feedback_failure_no_round_id(self): """Test get_feedback fails with no round id.""" @@ -402,7 +453,8 @@ def test_evaluate_success_with_round_id(self): """Test evaluate with rounds.""" response = get( "/session/evaluations", - params={"session_id": "evaluation", "test_id": "OND.1.1.1234", "round_id": 0}, + params={"session_id": "evaluation", + "test_id": "OND.1.1.1234", "round_id": 0}, ) _check_response(response)