Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Agglomerative clustering. #1384

Merged
merged 2 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ option(SHERPA_ONNX_ENABLE_WASM_VAD_ASR "Whether to enable WASM for VAD+ASR" OFF)
option(SHERPA_ONNX_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF)
option(SHERPA_ONNX_ENABLE_BINARY "Whether to build binaries" ON)
option(SHERPA_ONNX_ENABLE_TTS "Whether to build TTS related code" ON)
option(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION "Whether to build speaker diarization related code" ON)
option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON)
option(SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON)
option(SHERPA_ONNX_ENABLE_SANITIZER "Whether to enable ubsan and asan" OFF)
Expand Down Expand Up @@ -142,6 +143,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_WASM_VAD_ASR ${SHERPA_ONNX_ENABLE_WASM_VAD_AS
message(STATUS "SHERPA_ONNX_ENABLE_WASM_NODEJS ${SHERPA_ONNX_ENABLE_WASM_NODEJS}")
message(STATUS "SHERPA_ONNX_ENABLE_BINARY ${SHERPA_ONNX_ENABLE_BINARY}")
message(STATUS "SHERPA_ONNX_ENABLE_TTS ${SHERPA_ONNX_ENABLE_TTS}")
message(STATUS "SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION ${SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION}")
message(STATUS "SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY ${SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY}")
message(STATUS "SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}")
message(STATUS "SHERPA_ONNX_ENABLE_SANITIZER: ${SHERPA_ONNX_ENABLE_SANITIZER}")
Expand Down Expand Up @@ -341,6 +343,10 @@ if(SHERPA_ONNX_ENABLE_TTS)
include(cppjieba) # For Chinese TTS. It is a header-only C++ library
endif()

# Pull in cmake/hclust-cpp.cmake only when speaker diarization is enabled.
# That module fetches the hclust-cpp sources (or reuses a pre-downloaded
# archive) and adds them to the include path.
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
include(hclust-cpp)
endif()

# if(NOT MSVC AND CMAKE_BUILD_TYPE STREQUAL Debug AND (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"))
if(SHERPA_ONNX_ENABLE_SANITIZER)
message(WARNING "enable ubsan and asan")
Expand Down
45 changes: 45 additions & 0 deletions cmake/hclust-cpp.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Fetch hclust-cpp and expose its headers to the rest of the build.
#
# The archive is taken from the first existing entry of
# `possible_file_locations` (useful on machines without Internet access);
# if none exists it is downloaded from GitHub. In both cases the archive is
# verified against `hclust_cpp_HASH`. After population the source directory
# is added to the global include path via include_directories().
function(download_hclust_cpp)
  include(FetchContent)

  # The latest commit as of 2024.09.29
  set(hclust_cpp_URL "https://github.com/csukuangfj/hclust-cpp/archive/refs/tags/2024-09-29.tar.gz")
  set(hclust_cpp_HASH "SHA256=abab51448a3cb54272aae07522970306e0b2cc6479d59d7b19e7aee4d6cedd33")

  # If you don't have access to the Internet,
  # please pre-download hclust-cpp
  set(possible_file_locations
    $ENV{HOME}/Downloads/hclust-cpp-2024-09-29.tar.gz
    ${CMAKE_SOURCE_DIR}/hclust-cpp-2024-09-29.tar.gz
    ${CMAKE_BINARY_DIR}/hclust-cpp-2024-09-29.tar.gz
    /tmp/hclust-cpp-2024-09-29.tar.gz
    /star-fj/fangjun/download/github/hclust-cpp-2024-09-29.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
    # Quote the path so that an unusual value is still a single argument.
    if(EXISTS "${f}")
      # Prefer the local archive; normalise it to a CMake-style path.
      file(TO_CMAKE_PATH "${f}" hclust_cpp_URL)
      message(STATUS "Found local downloaded hclust_cpp: ${hclust_cpp_URL}")
      break()
    endif()
  endforeach()

  # Only one URL is declared: no mirror URL is defined for this dependency,
  # so the previously referenced (and never set) hclust_cpp_URL2 is dropped.
  FetchContent_Declare(hclust_cpp
    URL "${hclust_cpp_URL}"
    URL_HASH ${hclust_cpp_HASH}
  )

  FetchContent_GetProperties(hclust_cpp)
  if(NOT hclust_cpp_POPULATED)
    message(STATUS "Downloading hclust_cpp from ${hclust_cpp_URL}")
    FetchContent_Populate(hclust_cpp)
  endif()

  message(STATUS "hclust_cpp is downloaded to ${hclust_cpp_SOURCE_DIR}")
  message(STATUS "hclust_cpp's binary dir is ${hclust_cpp_BINARY_DIR}")
  include_directories(${hclust_cpp_SOURCE_DIR})
endfunction()

download_hclust_cpp()
13 changes: 13 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ if(SHERPA_ONNX_ENABLE_TTS)
)
endif()

# The clustering sources are compiled only when speaker diarization is
# enabled.
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND sources
fast-clustering-config.cc
fast-clustering.cc
)
endif()

if(SHERPA_ONNX_ENABLE_CHECK)
list(APPEND sources log.cc)
endif()
Expand Down Expand Up @@ -523,6 +530,12 @@ if(SHERPA_ONNX_ENABLE_TESTS)
)
endif()

# The clustering unit test is added only when speaker diarization is
# enabled, matching the condition under which fast-clustering.cc is built.
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND sherpa_onnx_test_srcs
fast-clustering-test.cc
)
endif()

list(APPEND sherpa_onnx_test_srcs
speaker-embedding-manager-test.cc
)
Expand Down
45 changes: 45 additions & 0 deletions sherpa-onnx/csrc/fast-clustering-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// sherpa-onnx/csrc/fast-clustering-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/fast-clustering-config.h"

#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {
// Returns a one-line, human-readable description of this config in the form
// "FastClusteringConfig(num_clusters=<n>, threshold=<t>)".
std::string FastClusteringConfig::ToString() const {
  std::ostringstream out;

  out << "FastClusteringConfig("
      << "num_clusters=" << num_clusters << ", "
      << "threshold=" << threshold << ")";

  return out.str();
}

// Registers the clustering options with the given option parser.
//
// The options are registered directly on `po` (without an extra prefix) so
// that their names match the ones referenced in the help text below. A
// caller that wants them namespaced can pass in a ParseOptions object that
// already carries the desired prefix.
//
// @param po  Parser to register the options with. Must not be null.
void FastClusteringConfig::Register(ParseOptions *po) {
  po->Register("num-clusters", &num_clusters,
               "Number of clusters. If greater than 0, then "
               "--cluster-threshold is ignored");

  po->Register("cluster-threshold", &threshold,
               "If --num-clusters is not specified, then it specifies the "
               "distance threshold for clustering.");
}

bool FastClusteringConfig::Validate() const {
if (num_clusters < 1 && threshold < 0) {
SHERPA_ONNX_LOGE("Please provide either num_clusters or threshold");
return false;
}

return true;
}

} // namespace sherpa_onnx
28 changes: 28 additions & 0 deletions sherpa-onnx/csrc/fast-clustering-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/fast-clustering-config.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_
#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Options for FastClustering. Either a fixed number of clusters or a
// distance threshold selects where the clustering stops.
struct FastClusteringConfig {
// Requested number of clusters. If greater than 0, then threshold is
// ignored. The default of -1 means "not specified".
int32_t num_clusters = -1;

// Distance threshold; used when num_clusters is not greater than 0.
float threshold = 0.5;

// Returns a human-readable description of the current values.
std::string ToString() const;

// Registers the command-line options backing the fields above with po.
void Register(ParseOptions *po);
// Returns false (and logs an error) when num_clusters < 1 and
// threshold < 0 at the same time; returns true otherwise.
bool Validate() const;
};

} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_
69 changes: 69 additions & 0 deletions sherpa-onnx/csrc/fast-clustering-test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// sherpa-onnx/csrc/fast-clustering-test.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/fast-clustering.h"

#include <vector>

#include "gtest/gtest.h"

namespace sherpa_onnx {

// Clusters four 2-D points into a fixed number (2) of clusters and checks
// that points with nearly the same direction end up with the same label.
TEST(FastClustering, TestTwoClusters) {
  // Row-major 4 x 2 matrix. In terms of direction (cosine similarity),
  // rows 0 and 3 are close to each other, and so are rows 1 and 2.
  std::vector<float> features = {
      // row 0 (group A)
      0.1,
      0.1,
      // row 1 (group B)
      0.4,
      -0.5,
      // row 2 (group B)
      0.6,
      -0.7,
      // row 3 (group A)
      0.2,
      0.3,
  };

  FastClusteringConfig config;
  config.num_clusters = 2;

  FastClustering clustering(config);
  // Note: Cluster() normalizes `features` in-place.
  auto labels = clustering.Cluster(features.data(), 4, 2);

  ASSERT_EQ(labels.size(), 4u);

  // Only the grouping is checked; the concrete label values are an
  // implementation detail of the clustering library.
  EXPECT_EQ(labels[0], labels[3]);
  EXPECT_EQ(labels[1], labels[2]);
  EXPECT_NE(labels[0], labels[1]);
}

// Clusters four 2-D points using a distance threshold (num_clusters is left
// at its default, i.e., unspecified) and checks the resulting grouping.
TEST(FastClustering, TestClusteringWithThreshold) {
  // Row-major 4 x 2 matrix. In terms of direction (cosine similarity),
  // rows 0 and 3 are close to each other, and so are rows 1 and 2. The
  // cosine dissimilarity inside each group is well below 0.5, while the
  // dissimilarity between the two groups is larger than 1.
  std::vector<float> features = {
      // row 0 (group A)
      0.1,
      0.1,
      // row 1 (group B)
      0.4,
      -0.5,
      // row 2 (group B)
      0.6,
      -0.7,
      // row 3 (group A)
      0.2,
      0.3,
  };

  FastClusteringConfig config;
  config.threshold = 0.5;

  FastClustering clustering(config);
  // Note: Cluster() normalizes `features` in-place.
  auto labels = clustering.Cluster(features.data(), 4, 2);

  ASSERT_EQ(labels.size(), 4u);

  // Only the grouping is checked; the concrete label values are an
  // implementation detail of the clustering library.
  EXPECT_EQ(labels[0], labels[3]);
  EXPECT_EQ(labels[1], labels[2]);
  EXPECT_NE(labels[0], labels[1]);
}

} // namespace sherpa_onnx
83 changes: 83 additions & 0 deletions sherpa-onnx/csrc/fast-clustering.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// sherpa-onnx/csrc/fast-clustering.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/fast-clustering.h"

#include <vector>

#include "Eigen/Dense"
#include "fastcluster-all-in-one.h" // NOLINT

namespace sherpa_onnx {

// Implementation of FastClustering.
//
// It L2-normalizes every feature row, builds the condensed pairwise
// cosine-dissimilarity matrix, runs single-linkage hierarchical clustering
// on it via fastclustercpp::hclust_fast, and finally cuts the resulting tree
// either into a fixed number of clusters or at a distance threshold.
class FastClustering::Impl {
 public:
  explicit Impl(const FastClusteringConfig &config) : config_(config) {}

  // @param features  Row-major matrix of shape (num_rows, num_cols).
  //                  It is normalized in-place, row by row.
  // @param num_rows  Number of feature vectors.
  // @param num_cols  Dimension of each feature vector.
  // @return One label per row. Empty if num_rows <= 0; {0} if num_rows == 1.
  std::vector<int32_t> Cluster(float *features, int32_t num_rows,
                               int32_t num_cols) {
    if (num_rows <= 0) {
      return {};
    }

    if (num_rows == 1) {
      // A single point trivially forms one cluster; there are no pairwise
      // distances to compute.
      return {0};
    }

    Eigen::Map<
        Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
        m(features, num_rows, num_cols);
    // NOTE(review): a row with zero norm cannot be normalized meaningfully;
    // callers are presumably expected not to pass all-zero rows - confirm.
    m.rowwise().normalize();

    // Number of unordered pairs (i, j) with i < j. Computed in 64 bits,
    // because num_rows * (num_rows - 1) overflows int32_t once num_rows
    // exceeds 46341.
    const int64_t num_pairs =
        static_cast<int64_t>(num_rows) * (num_rows - 1) / 2;
    std::vector<double> distance(num_pairs);

    // distance is filled row by row over the strict upper triangle:
    // (0,1), (0,2), ..., (0,n-1), (1,2), ...
    int64_t k = 0;
    for (int32_t i = 0; i != num_rows; ++i) {
      auto v = m.row(i);
      for (int32_t j = i + 1; j != num_rows; ++j) {
        // Rows are unit vectors, so the dot product is the cosine similarity.
        double cosine_similarity = v.dot(m.row(j));
        double cosine_dissimilarity = 1 - cosine_similarity;

        // Rounding errors can make the similarity slightly larger than 1;
        // clamp so that distances are never negative.
        if (cosine_dissimilarity < 0) {
          cosine_dissimilarity = 0;
        }

        distance[k] = cosine_dissimilarity;
        ++k;
      }
    }

    // Output buffers of hclust_fast: 2 * (n - 1) merge entries and
    // (n - 1) merge heights.
    std::vector<int32_t> merge(2 * (static_cast<int64_t>(num_rows) - 1));
    std::vector<double> height(num_rows - 1);

    fastclustercpp::hclust_fast(num_rows, distance.data(),
                                fastclustercpp::HCLUST_METHOD_SINGLE,
                                merge.data(), height.data());

    std::vector<int32_t> labels(num_rows);
    if (config_.num_clusters > 0) {
      // The number of clusters is given explicitly; the threshold is ignored.
      fastclustercpp::cutree_k(num_rows, merge.data(), config_.num_clusters,
                               labels.data());
    } else {
      // Otherwise cut the tree using the distance threshold.
      fastclustercpp::cutree_cdist(num_rows, merge.data(), height.data(),
                                   config_.threshold, labels.data());
    }

    return labels;
  }

 private:
  FastClusteringConfig config_;
};

// Creates the clustering object. All work is delegated to a private Impl
// instance, which stores its own copy of config.
FastClustering::FastClustering(const FastClusteringConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

// Defined in this file, where Impl is a complete type, so that
// std::unique_ptr<Impl> can destroy it.
FastClustering::~FastClustering() = default;

// Forwards to Impl::Cluster(). Note that features is modified in-place
// (each row is L2-normalized). Returns one label per row.
std::vector<int32_t> FastClustering::Cluster(float *features, int32_t num_rows,
int32_t num_cols) {
return impl_->Cluster(features, num_rows, num_cols);
}
} // namespace sherpa_onnx
43 changes: 43 additions & 0 deletions sherpa-onnx/csrc/fast-clustering.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// sherpa-onnx/csrc/fast-clustering.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_
#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_

#include <memory>
#include <vector>

#include "sherpa-onnx/csrc/fast-clustering-config.h"

namespace sherpa_onnx {

// Hierarchical clustering of feature vectors using cosine dissimilarity.
//
// The implementation is hidden behind a pointer-to-implementation (Impl),
// so this header does not depend on the clustering library itself.
class FastClustering {
public:
// @param config  Selects either a fixed number of clusters or a distance
//                threshold. See FastClusteringConfig.
explicit FastClustering(const FastClusteringConfig &config);
~FastClustering();

/**
 * @param features  Pointer to a 2-D feature matrix in row major. Each row
 *                  is a feature frame. It is changed in-place. We will
 *                  convert each feature frame to a normalized vector.
 *                  That is, the L2-norm of each vector will be equal to 1.
 *                  It uses cosine dissimilarity,
 *                  which is 1 - (cosine similarity)
 * @param num_rows  Number of feature frames
 * @param num_cols  The feature dimension.
 *
 * @return Return a vector of size num_rows. ans[i] contains the label
 *         for the i-th feature frame, i.e., the i-th row of the feature
 *         matrix.
 */
std::vector<int32_t> Cluster(float *features, int32_t num_rows,
int32_t num_cols);

private:
// Defined in fast-clustering.cc.
class Impl;
std::unique_ptr<Impl> impl_;
};

} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_
2 changes: 0 additions & 2 deletions sherpa-onnx/csrc/offline-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

#include "sherpa-onnx/csrc/offline-stream.h"

#include <math.h>

#include <algorithm>
#include <cassert>
#include <cmath>
Expand Down
Loading
Loading