From b3054b3ddae9d889c32e1de97c49a3b862128402 Mon Sep 17 00:00:00 2001 From: Karl Date: Mon, 14 Jul 2025 18:32:48 +0100 Subject: [PATCH] database connection pooling --- ktvmanager/lib/database.py | 14 +++++++++----- ktvmanager/main.py | 4 ++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/ktvmanager/lib/database.py b/ktvmanager/lib/database.py index 6f216ee..a202ad6 100644 --- a/ktvmanager/lib/database.py +++ b/ktvmanager/lib/database.py @@ -1,10 +1,15 @@ -import mysql.connector +import mysql.connector.pooling from flask import jsonify, request, current_app from ktvmanager.lib.checker import single_account_check from ktvmanager.lib.encryption import encrypt_password, decrypt_password -def _create_connection(): - return mysql.connector.connect( +db_pool = None + +def initialize_db_pool(): + global db_pool + db_pool = mysql.connector.pooling.MySQLConnectionPool( + pool_name="ktv_pool", + pool_size=5, host=current_app.config["DBHOST"], user=current_app.config["DBUSER"], password=current_app.config["DBPASS"], @@ -13,7 +18,7 @@ def _create_connection(): ) def _execute_query(query, params=None): - conn = _create_connection() + conn = db_pool.get_connection() cursor = conn.cursor(dictionary=True) try: cursor.execute(query, params) @@ -33,7 +38,6 @@ def get_user_id_from_username(username): if result: return result[0]['id'] return None - def get_user_accounts(user_id): query = "SELECT * FROM userAccounts WHERE userID = %s" accounts = _execute_query(query, (user_id,)) diff --git a/ktvmanager/main.py b/ktvmanager/main.py index 0765d7d..d2b4f98 100644 --- a/ktvmanager/main.py +++ b/ktvmanager/main.py @@ -3,6 +3,7 @@ from flask import Flask, jsonify from dotenv import load_dotenv from ktvmanager.config import DevelopmentConfig, ProductionConfig from routes.api import api_blueprint +from ktvmanager.lib.database import initialize_db_pool def create_app(): app = Flask(__name__) @@ -13,6 +14,9 @@ def create_app(): else: app.config.from_object(DevelopmentConfig) + with app.app_context(): + initialize_db_pool() + # Register blueprints app.register_blueprint(api_blueprint)