Skip to content

Commit

Permalink
minor bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Eugene Istrati committed Oct 31, 2024
1 parent 21a98ef commit 6e7817d
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions app/anomaly-detector/anomaly_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,30 @@
from os import path, makedirs
from concurrent.futures import ProcessPoolExecutor
from psycopg_pool import AsyncConnectionPool
from psycopg.adapt import adapt, Adapter
from psycopg.types.string import StrDumper
from boto3 import client as boto3_client
from pandas import get_dummies, to_datetime, concat, read_csv
from numpy import concatenate
from numpy import concatenate, ndarray
from sklearn.preprocessing import StandardScaler
from sentence_transformers import SentenceTransformer
from kubernetes import client as k8s_client, config as k8s_config

# Custom adapter for numpy arrays and mixed-type lists
class MixedTypeAdapter(Adapter):
def __init__(self, adapted):
self.adapted = adapted

def dump(self, obj):
if isinstance(obj, ndarray):
obj = obj.tolist()
return StrDumper().dump(str(obj))

# Register the adapter
adapt(ndarray, MixedTypeAdapter)

# Global variable for the model
model = None
MODEL = None

def get_config_map_values(config_map_name = "config-map"):
"""
Expand Down Expand Up @@ -150,15 +165,15 @@ def process_dataframe(df):
return combined_features, df[textual_features]

def initialize_model():
global model
if model is None:
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
global MODEL
if MODEL is None:
MODEL = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

def encode_batch(batch):
global model
if model is None:
global MODEL
if MODEL is None:
initialize_model()
return model.encode(batch)
return MODEL.encode(batch)

def create_embeddings(textual_features, batch_size=1000, num_workers=3):
"""
Expand Down Expand Up @@ -214,7 +229,7 @@ async def is_transaction_anomaly(conn_pool, embeddings, df):
async with conn_pool.connection() as aconn:
async with aconn.cursor() as acur:
for embedding in embeddings:
await acur.execute(query, (embedding.tolist(),))
await acur.execute(query, (embedding,))
result = await acur.fetchone()
scores.append(result[0] if result else None)

Expand Down

0 comments on commit 6e7817d

Please sign in to comment.