From 98df100e0f363e08878af9c8da4d7b97ad5e0c2e Mon Sep 17 00:00:00 2001 From: Shadi Date: Mon, 8 Apr 2024 09:37:40 +0900 Subject: [PATCH] Fixing the openmp interface calls to avoid build failures when openmp is not enabled. --- viprs/model/vi/e_step.pyx | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/viprs/model/vi/e_step.pyx b/viprs/model/vi/e_step.pyx index 6818ddf..1ec86e1 100644 --- a/viprs/model/vi/e_step.pyx +++ b/viprs/model/vi/e_step.pyx @@ -1,5 +1,4 @@ from cython.parallel import prange, parallel -cimport openmp from ...utils.math_utils cimport ( sigmoid, softmax, @@ -16,6 +15,24 @@ cimport numpy as np from cython cimport floating, integral +# A safe way to get the number of the thread currently executing the code: +# This is used to avoid compile-time errors when compiling the code with OpenMP support disabled. +# In earlier iterations, we used: +# cimport openmp +# openmp.omp_get_thread_num() +# But this tends to fail when OpenMP is not enabled. +# The code below is a safer way to get the thread number. +cdef extern from *: + """ + #ifdef _OPENMP + #include <omp.h> + #else + int omp_get_thread_num() { return 0; } + #endif + """ + int omp_get_thread_num() noexcept nogil + + @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) @@ -243,7 +260,7 @@ cpdef void e_step_mixture(int[::1] ld_left_bound, for j in prange(c_size, nogil=True, schedule='static', num_threads=threads): # Set the thread offset for the u_j array: - thread_offset = openmp.omp_get_thread_num() * (K + 1) + thread_offset = omp_get_thread_num() * (K + 1) # The start and end coordinates for the flattened LD matrix: ld_start = ld_indptr[j]