Ignore #2835 (Draft)
wants to merge 19 commits into base: master
Changes from 1 commit
Cutoff by dates
CeesJol committed Sep 18, 2022
commit 1e5956c03743c43f8f994b18bf1fd2d703a3a9dd
112 changes: 76 additions & 36 deletions samples/strawberry/strawberry.py
@@ -1,22 +1,40 @@
import os
CLUSTER = not 'Ceess-MacBook-Pro-2.local' in os.popen('hostname').read()

import sys
import json
import datetime
import numpy as np
import skimage.draw
from os.path import exists
import csv
# import csv
import glob

# DETECTIONS_PATH = '/Users/ceesjol/Documents/Thesis/data_original/Detections/'
DETECTIONS_PATH = '/tudelft.net/staff-umbrella/abeellabstudents/cfjol/data/Detections/'
if CLUSTER:
    DETECTIONS_PATH = '/tudelft.net/staff-umbrella/abeellabstudents/cfjol/data/Detections/'
else:
    DETECTIONS_PATH = '/Users/ceesjol/Documents/Thesis/data_original/Detections/'


# Root directory of the project
ROOT_DIR = os.path.abspath("../../")

# Main directory that contains project and data
MAIN_DIR = os.path.abspath("../../../")

# Date constants
CAM_ID = 3
CAM_NAME = ''
if CAM_ID % 2 == 0:
    CAM_NAME = f'OCNCAM{CAM_ID}'
else:
    CAM_NAME = f'RGBCAM{CAM_ID}'
if CAM_ID == 3:
    # One month for test, one month for val, almost 4 months for train
    CUTOFF_DATE_TRAIN = datetime.datetime(2021, 8, 13)
    CUTOFF_DATE_VAL = datetime.datetime(2021, 9, 26)
    CUTOFF_DAYS_MARGIN = 14

# Import Mask RCNN
sys.path.append(ROOT_DIR) # To find local version of the library
sys.path.append('../../mrcnn') # ...Actually use local version
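The CUTOFF_* constants defined above carve the CAM 3 timeline into almost four months of training data, one month of validation and one month of test, with a 14-day margin before the validation and test windows so neighbouring splits do not touch. A minimal sketch of how a capture date maps onto those windows, reusing the constants from this diff; the helper name split_for_date is hypothetical and not part of the PR:

import datetime

CUTOFF_DATE_TRAIN = datetime.datetime(2021, 8, 13)
CUTOFF_DATE_VAL = datetime.datetime(2021, 9, 26)
CUTOFF_DAYS_MARGIN = 14

def split_for_date(image_date):
    # Hypothetical helper mirroring the filtering applied later in load_strawberry.
    if image_date < CUTOFF_DATE_TRAIN:
        return "train"
    if CUTOFF_DATE_TRAIN + datetime.timedelta(days=CUTOFF_DAYS_MARGIN) <= image_date < CUTOFF_DATE_VAL:
        return "val"
    if image_date >= CUTOFF_DATE_VAL + datetime.timedelta(days=CUTOFF_DAYS_MARGIN):
        return "test"
    return None  # falls inside a margin window and is used by no split

# split_for_date(datetime.datetime(2021, 9, 1)) -> "val"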
@@ -49,10 +67,10 @@ class StrawberryConfig(Config):
    NUM_CLASSES = 1 + 3  # Background + strawberry + flower + note

    # Number of training steps per epoch
    STEPS_PER_EPOCH = 1
    STEPS_PER_EPOCH = 100

    # Validation steps per epoch
    VALIDATION_STEPS = 1
    VALIDATION_STEPS = 50

    # Skip detections with < 80% confidence
    DETECTION_MIN_CONFIDENCE = 0.8
@@ -65,18 +83,21 @@ class StrawberryConfig(Config):
labels = []
def get_labels():
    global labels
    if len(labels) == 0:
        labels = []
    if len(labels) > 0:
        return labels

        for filepath in glob.iglob(DETECTIONS_PATH + 'v1/*.json', recursive=True):
            print(filepath)
            my_json = json.load(open(str(filepath)))
            my_json = my_json["images"]
            labels += my_json
    print(f'Loading labels for {CAM_NAME}...')

    len(labels)
    for filepath in glob.iglob(DETECTIONS_PATH + '**/*.json', recursive=True):
        if not CAM_NAME in str(filepath):
            continue
        print(filepath)
        my_json = json.load(open(str(filepath)))
        my_json = my_json["images"]
        labels += my_json

    return labels
get_labels()

def get_polygon(label):
    try:
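get_labels now caches the parsed annotations in the module-level labels list, so the Detections tree is walked only once per process, and it keeps just the JSON files whose path contains the current camera name. A small standalone sketch of that load-once pattern under the same directory layout; the names load_labels, detections_root and cam_name are illustrative and not from the PR:

import glob
import json

_cache = []

def load_labels(detections_root, cam_name):
    # Parse every per-camera detections JSON exactly once per process.
    global _cache
    if _cache:
        return _cache  # later calls reuse the already-parsed annotations
    for path in glob.iglob(detections_root + '**/*.json', recursive=True):
        if cam_name not in path:
            continue  # skip files recorded by other cameras
        with open(path) as fh:
            _cache += json.load(fh)["images"]
    return _cache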
@@ -108,24 +129,39 @@ def load_strawberry(self, dataset_dir, subset):

        # Train or validation dataset?
        assert subset in ["train", "val"]
        dataset_dir = os.path.join(dataset_dir, subset)
        # dataset_dir = os.path.join(dataset_dir, subset)

        annotations = get_labels()

        # Add images
        for annotation in annotations:
            polygons = get_polygon(annotation)
            # Skip empty labels or annotations with very little data
            if len(polygons) <= 2:
                continue
            width, height = 4000, 3000
            image_name = annotation['file'].split("/")[-1]
            image_path = os.path.join(dataset_dir, image_name)

            # Skip non-existing images
            file_exists = exists(image_path)
            if not file_exists:
                continue

            # Check if image date is in the right range
            image_year = image_name.split("_")[0]
            image_monthday = image_name.split("_")[1]
            image_date = datetime.datetime.strptime(f'{image_year}{image_monthday}', '%Y%m%d')

            if subset == "train" and not (image_date < CUTOFF_DATE_TRAIN):
                continue
            elif subset == "val" and not (image_date < CUTOFF_DATE_VAL and image_date >= CUTOFF_DATE_TRAIN + datetime.timedelta(days=CUTOFF_DAYS_MARGIN)):
                continue
            # "test" is inference only
            # elif subset == "test" and not (image_date >= CUTOFF_DATE_VAL + datetime.timedelta(days=CUTOFF_DAYS_MARGIN)):
            #     continue

            polygons = get_polygon(annotation)
            # Skip empty labels or annotations with very little data
            if len(polygons) <= 2:
                continue
            width, height = 4000, 3000

            labels = [label.lower() for label in annotation['label']]
            self.add_image(
                "strawberry",
@@ -201,6 +237,10 @@ def train(model):
                learning_rate=config.LEARNING_RATE,
                epochs=1,
                layers='heads')
    # model.train(dataset_train, dataset_val,
    #             learning_rate=config.LEARNING_RATE,
    #             epochs=1,
    #             layers='all')

def crop(image, roi):
"""Crop an image to a region of interest."""
@@ -297,11 +337,11 @@ def detect_and_segment(model, image_paths, output_dir='output'):
os.mkdir(path)

# Create output csv
f = open(os.path.join(output_dir, date, 'output.csv'), 'w')
writer = csv.writer(f)
writer.writerow(["Track ID", "Name", "Pixel width"])
# f = open(os.path.join(output_dir, date, 'output.csv'), 'w')
# writer = csv.writer(f)
# writer.writerow(["Track ID", "Name", "Pixel width"])

track_ids = {}
# track_ids = {}

for image_path in image_paths:
# Skip JSON, DS_Store, etc.
@@ -400,19 +440,19 @@ def detect_and_segment(model, image_paths, output_dir='output'):
skimage.io.imsave(os.path.join(output_dir, date, file_name), seg)

# Store track id if it's not in track_ids, or if it's a newer image
if not track_id in track_ids or track_ids[track_id]['name'] < image_name:
track_ids[track_id] = {
'name': file_name,
'width': seg.shape[1],
}

for key in track_ids.keys():
arr = [key]
for val in track_ids[key]:
arr.append(track_ids[key][val])
writer.writerow(arr)

f.close()
# if not track_id in track_ids or track_ids[track_id]['name'] < image_name:
# track_ids[track_id] = {
# 'name': file_name,
# 'width': seg.shape[1],
# }

# for key in track_ids.keys():
# arr = [key]
# for val in track_ids[key]:
# arr.append(track_ids[key][val])
# writer.writerow(arr)

# f.close()

print('Done')
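The commented-out lines above previously kept, for each track ID, the file name and pixel width of its most recent crop and wrote one CSV row per track once the loop finished. For reference, a compact sketch of that accumulate-then-write step, reusing the column names from the disabled code; the write_track_widths helper and the example row are illustrative only:

import csv

def write_track_widths(track_ids, csv_path):
    # One row per track: its latest segment file name and pixel width.
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["Track ID", "Name", "Pixel width"])
        for track_id, info in track_ids.items():
            writer.writerow([track_id, info['name'], info['width']])

# write_track_widths({7: {'name': 'IMG_0042_seg.png', 'width': 412}}, 'output.csv')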