Skip to content

Commit

Permalink
Merge pull request #45 from trollLemon/kmean_optim
Browse files Browse the repository at this point in the history
Kmean optim
  • Loading branch information
trollLemon authored Jun 12, 2024
2 parents 4fb0c94 + a87b918 commit 619df79
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 143 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@
#build directory
build
Build

#
.cache
38 changes: 1 addition & 37 deletions src/includes/k_mean.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,7 @@
#ifndef KMEAN
#define KMEAN
#include "color.h"
#include <queue>

// Pairs a cluster id with the distance from a point to that cluster's
// centroid. Instances are heap-allocated and handed to minHeap.

struct cluster_distance {

// Id of the cluster this entry refers to.
int cluster;
// Distance from the point to the cluster's centroid. minHeap::push flips the
// sign of this field in place before storing the entry.
double distance;
// User-declared destructor; defined out of line in the implementation file.
~cluster_distance();
};

// Custom comparator passed to std::sort to order the centroid colors that
// KMeans returns.
struct ColorSort {

// Returns true when `a` should be ordered before `b`.
bool operator()(const Color *a, const Color *b);
};

// Custom comparator for cluster_distance pointers; used as the ordering
// predicate of the std::priority_queue held by minHeap.
struct Comp {

// Compares the `distance` fields of the two entries.
bool operator()(const cluster_distance *a, const cluster_distance *b);
};

// Wrapper around std::priority_queue that behaves as a min-heap of
// cluster_distance entries.
//
// Ownership: the heap owns every pointer pushed into it. pop() and clear()
// delete the entries they remove, and the destructor clears whatever is left.
class minHeap {

private:
  std::priority_queue<cluster_distance *, std::vector<cluster_distance *>, Comp>
      distances;

public:
  minHeap() = default;
  // Copying is disabled: two heaps holding the same raw pointers would each
  // delete them on destruction.
  minHeap(const minHeap &) = delete;
  minHeap &operator=(const minHeap &) = delete;
  ~minHeap();

  // Takes ownership of `pair` and inserts it into the heap.
  void push(cluster_distance *pair);
  // Removes the top entry, frees it, and returns its cluster id.
  int pop();
  // Deletes and removes every entry still stored.
  void clear();
};

#include <vector>

/* *
* Returns the Euclidean Distance between two colors
Expand Down
150 changes: 44 additions & 106 deletions src/kmean/k_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,156 +2,94 @@
#include "cluster.h"
#include "color.h"
#include <algorithm>
#include <cmath>
#include <random>
#include <unordered_map>
#include <set>
#include <map>
#define MAX_ITERATIONS 12
// Nothing to release: the struct only holds an int and a double.
cluster_distance::~cluster_distance() = default;
// Inserts `pair` into the heap.
//
// The distance is negated in place before insertion (so the caller's object
// is modified); combined with Comp this makes the entry with the smallest
// original distance surface first.
void minHeap::push(cluster_distance *pair) {
  pair->distance = -pair->distance;
  distances.push(pair);
}

// Orders colors in descending order, comparing red first, then green, then
// blue.
//
// The comparison is lexicographic so that it forms a strict weak ordering,
// which std::sort requires. Requiring *all* channels to be greater (the
// previous behaviour) leaves most pairs incomparable in a non-transitive way,
// which is undefined behaviour when used with std::sort. Any pair that the old
// predicate ordered is still ordered the same way here.
bool ColorSort::operator()(const Color *a, const Color *b) {
  if (a->Red() != b->Red()) {
    return a->Red() > b->Red();
  }
  if (a->Green() != b->Green()) {
    return a->Green() > b->Green();
  }
  return a->Blue() > b->Blue();
}

// Ordering predicate for the priority queue: true when `a` holds the smaller
// stored distance. Stored distances were negated by minHeap::push.
bool Comp::operator()(const cluster_distance *a, const cluster_distance *b) {
  const bool lhsIsSmaller = a->distance < b->distance;
  return lhsIsSmaller;
}

// Frees every entry the heap still owns.
minHeap::~minHeap() {
  this->clear();
}

int minHeap::pop() {

if (distances.top() == nullptr) {
return -1;
}
int ClusterId = distances.top()->cluster;
delete distances.top();
distances.pop();

return ClusterId;
}
// Deletes every cluster_distance still stored in the heap and leaves the
// heap empty.
// NOTE(review): the #include / #define lines directly below sit inside the
// function body; they look like added lines of the rendered commit diff that
// were interleaved here rather than part of this function -- confirm against
// the real file.
void minHeap::clear() {
#include <unordered_map>
#define MAX_ITERATIONS 256

// Free the top entry, then drop its pointer, until nothing is left.
while (!distances.empty()) {
delete distances.top();
distances.pop();
}
}
// Comparator that orders Color pointers by descending asHex() value.
struct CompareColors {
  bool operator()(Color *a, Color *b) const {
    const auto lhsHex = a->asHex();
    const auto rhsHex = b->asHex();
    return lhsHex > rhsHex;
  }
};

// Returns the Euclidean distance between two colors, treating their red,
// green and blue channels as coordinates.
//
// Previously the function returned at the sum of squares, leaving the
// square-root step unreachable; the result was the *squared* distance, which
// contradicts the function's name and its header documentation.
double EuclideanDistance(Color *a, Color *b) {
  const double deltaR = a->Red() - b->Red();
  const double deltaG = a->Green() - b->Green();
  const double deltaB = a->Blue() - b->Blue();

  const double squaredDistance =
      (deltaR * deltaR) + (deltaG * deltaG) + (deltaB * deltaB);

  return std::sqrt(squaredDistance);
}

// Runs k-means clustering over `colors` with `k` clusters and returns the
// centroid Color of every cluster, sorted.
//
// NOTE(review): this body comes from a commit diff rendered without +/-
// markers, so lines of the pre-commit implementation (per-point minHeap in
// `data`, do/while loop, ColorSort) and of the post-commit implementation
// (direct nearest-centroid scan, for loop, CompareColors) are interleaved and
// the braces do not balance as shown. Confirm against the real file before
// relying on any single line.
// NOTE(review): each Cluster is deleted right after its centroid pointer is
// collected; presumably Cluster does not own that Color -- verify.
std::vector<Color *> KMeans(std::vector<Color *> &colors, int k) {

std::map<Color *, minHeap *> data; // Distances from points to each centroid
std::unordered_map<int /*Cluster ID*/, Cluster *> clusters; // clusters

for (Color *point : colors) {
data[point] = new minHeap();
}

/*
* Initialize the Clustering
*
* */
// randomly get k points
int size = colors.size();
std::unordered_map<int, Cluster *> clusters;

std::random_device rd;
std::seed_seq ss{rd(), rd(), rd(), rd(), rd(), rd(), rd(), rd()};

std::mt19937 mt{ss};
// NOTE(review): the upper bound of uniform_int_distribution is inclusive, so
// `size` itself can be drawn and later used as an index into `colors`.
std::uniform_int_distribution<> kPoints{size / 2, size};
std::set<int> seen; // make sure we have unique numbers
std::vector<int> colorIdx;
while (seen.size() != static_cast<long unsigned int>(k)) {
int num = kPoints(mt);

if (seen.count(num) != 0)
continue;
std::mt19937 gen(ss); // seeded from std::random_device values via seed_seq
// NOTE(review): colors.size() - 1 wraps around when `colors` is empty.
std::uniform_int_distribution<int> dist(0, colors.size() - 1);

seen.insert(num);

colorIdx.push_back(num);
// Seed each cluster with a randomly chosen color (duplicates are possible).
for (int i = 0; i < k; i++) {
clusters[i] = new Cluster(colors[dist(gen)], i);
}

for (int i = 0; i < k; ++i)
clusters[i] = new Cluster(colors[colorIdx[i]], i);

// Do the first iteration
for (Color *point : colors)
// Mark every point as not yet assigned to a cluster.
for (Color *point : colors) {
point->setClusterId(-1);
}

// recalculate distances for all points
// if any points move, we will put the affected clusters in a set,

// recalculate distances for all points
// if any points move, we will put the affected clusters in a set,

std::set<Cluster *> toRecalculate;
std::set<int> toRecalculate;
for (int i = 0; i < MAX_ITERATIONS; ++i) {

int itrs = 0;
do {
toRecalculate.clear();
for (Color *point : colors) {

minHeap *heap = data[point];
int id = -1;

for (auto cluster : clusters) {
// Despite its name, maxDistance tracks the smallest distance seen so far.
double maxDistance = INFINITY;
for (std::pair<int, Cluster *> clusterPairs : clusters) {

double distance =
EuclideanDistance(cluster.second->getCentroid(), point);
int id = cluster.second->getId();
cluster_distance *dist = new cluster_distance;
dist->cluster = id;
dist->distance = distance;
heap->push(dist);
}
Cluster *cluster = clusterPairs.second;
double distance = EuclideanDistance(cluster->getCentroid(), point);

if (distance < maxDistance) {
maxDistance = distance;
id = cluster->getId();
}
}

int id = heap->pop();
int pid = point->getClusterId();

if (id != point->getClusterId()) {
toRecalculate.insert(clusters[id]);
int pId = point->getClusterId();
if (pId != -1) {
toRecalculate.insert(clusters[pId]);
if (id != pid) {
point->setClusterId(id);
clusters[id]->addPoint(point);
toRecalculate.insert(id);
// edge case: on the first pass every point still has a cluster id of -1,
// so there is no previous cluster to recalculate
if (pid > -1) {
toRecalculate.insert(pid);
}
}
}

point->setClusterId(id);
if (toRecalculate.empty()) {
break; // if no points moved, we converged
}

for (Cluster *cluster : toRecalculate)
cluster->calcNewCentroid();

itrs++;
} while (itrs < MAX_ITERATIONS && toRecalculate.size() !=0);
for (int cluster : toRecalculate) {
clusters[cluster]->calcNewCentroid();
}
}

std::vector<std::string> palette;
std::vector<Color *> sortedColors;

// Collect every centroid, then free the Cluster that produced it.
for (auto cluster : clusters) {
sortedColors.push_back(cluster.second->getCentroid());
delete cluster.second;
}
std::sort(sortedColors.begin(), sortedColors.end(), ColorSort());

for (auto heap : data)
delete heap.second;
std::sort(std::begin(sortedColors), std::end(sortedColors), CompareColors());

return sortedColors;
}

0 comments on commit 619df79

Please sign in to comment.