diff --git a/facedb/db.py b/facedb/db.py index 9878d0e..eebb12d 100644 --- a/facedb/db.py +++ b/facedb/db.py @@ -5,10 +5,22 @@ import warnings import threading - -DeepFace = None -deepface_distance = None -face_recognition = None +# Attempt to import deepface and face_recognition at the top level +try: + from deepface import DeepFace + from deepface.commons import distance as deepface_distance + DEEPFACE_AVAILABLE = True +except ImportError: + DEEPFACE_AVAILABLE = False + DeepFace = None + deepface_distance = None + +try: + import face_recognition + FACE_RECOGNITION_AVAILABLE = True +except ImportError: + FACE_RECOGNITION_AVAILABLE = False + face_recognition = None from facedb.db_tools import ( get_embeddings, @@ -32,38 +44,17 @@ ) from facedb.db_models import FaceResults, PineconeDB, ChromaDB - from pathlib import Path -import_lock = threading.Lock() - -def load_module(module: Literal["deepface", "face_recognition"]): - with import_lock: - if module == "deepface": - global DeepFace - global deepface_distance - if DeepFace is None: - try: - from deepface import DeepFace - from deepface.commons import distance as deepface_distance - except ImportError: - raise ImportError( - "Please install `deepface` to use `deepface` module." - ) - elif module == "face_recognition": - global face_recognition - if face_recognition is None: - try: - import face_recognition - except ImportError: - raise ImportError( - "Please install `face_recognition` to use `face_recognition` module." - ) - else: - raise ValueError( - "Currently only `deepface` and `face_recognition` are supported." - ) +def load_module(self, module: Literal["deepface", "face_recognition"]): + """Function to check if a module is loaded and available.""" + if module == "deepface" and not DEEPFACE_AVAILABLE: + raise ImportError("Please install `deepface` to use `deepface` module.") + elif module == "face_recognition" and not FACE_RECOGNITION_AVAILABLE: + raise ImportError("Please install `face_recognition` to use `face_recognition` module.") + elif module not in ["deepface", "face_recognition"]: + raise ValueError("Currently only `deepface` and `face_recognition` are supported.") def create_deepface_embedding_func(