import mysql.connector.pooling
from flask import jsonify, request, current_app, Response
from ktvmanager.lib.checker import single_account_check
from ktvmanager.lib.encryption import encrypt_password, decrypt_password
from ktvmanager.lib.get_urls import get_latest_urls_from_dns
from typing import List, Dict, Any, Optional, Tuple

# Module-wide connection pool; stays None until initialize_db_pool() runs.
db_pool = None


def initialize_db_pool() -> None:
    """Initializes the database connection pool.

    Reads DBHOST, DBUSER, DBPASS, DATABASE and DBPORT from
    ``current_app.config``, builds a five-connection pool, stores it in the
    module-level ``db_pool`` and then makes sure the push_subscriptions
    table exists.
    """
    global db_pool
    config = current_app.config
    db_pool = mysql.connector.pooling.MySQLConnectionPool(
        pool_name="ktv_pool",
        pool_size=5,
        host=config["DBHOST"],
        user=config["DBUSER"],
        password=config["DBPASS"],
        database=config["DATABASE"],
        port=config["DBPORT"],
    )
    _create_push_subscriptions_table()


def _create_push_subscriptions_table() -> None:
    """Creates the push_subscriptions table if it doesn't exist."""
    query = """
        CREATE TABLE IF NOT EXISTS push_subscriptions (
            id INT AUTO_INCREMENT PRIMARY KEY,
            user_id INT NOT NULL,
            subscription_json TEXT NOT NULL,
            last_notified TIMESTAMP NULL,
            FOREIGN KEY (user_id) REFERENCES users(id)
        )
    """
    _execute_query(query)


def _execute_query(
    query: str, params: Optional[tuple] = None
) -> List[Dict[str, Any]] | Dict[str, int]:
    """Executes a SQL query and returns the result.

    Args:
        query: The SQL query to execute.
        params: The parameters to pass to the query.

    Returns:
        A list of dictionaries for SELECT queries, or a dictionary with the
        number of affected rows (key ``"affected_rows"``) for other queries.

    Raises:
        RuntimeError: If the connection pool has not been initialized.
    """
    if db_pool is None:
        raise RuntimeError(
            "Database pool is not initialized; call initialize_db_pool() first."
        )

    conn = db_pool.get_connection()
    try:
        # The cursor is created inside the outer try so that a failure here
        # still hands the connection back to the pool.
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query, params)
            if query.strip().upper().startswith("SELECT"):
                return cursor.fetchall()
            # Anything that is not a SELECT is treated as a write and committed.
            conn.commit()
            return {"affected_rows": cursor.rowcount}
        finally:
            cursor.close()
    finally:
        conn.close()


def get_user_id_from_username(username: str) -> Optional[int]:
    """Retrieves the user ID for a given username.

    Args:
        username: The username to look up.

    Returns:
        The user ID if found, otherwise None.
    """
    query = "SELECT id FROM users WHERE username = %s"
    rows = _execute_query(query, (username,))
    if rows:
        return rows[0]["id"]
    return None


def get_user_accounts(user_id: int) -> Response:
    """Retrieves all accounts for a given user ID.

    Each account's stored password is decrypted before it is returned. If
    decryption fails for an account, its password is replaced with the
    placeholder ``"DECRYPTION_FAILED"`` and the failure is logged.

    Args:
        user_id: The ID of the user.

    Returns:
        A Flask JSON response containing the user's accounts.
    """
    query = "SELECT * FROM userAccounts WHERE userID = %s"
    accounts = _execute_query(query, (user_id,))
    for account in accounts:
        try:
            account["password"] = decrypt_password(account["password"])
        except Exception as e:
            # Deliberately best-effort: one undecryptable row must not hide
            # the user's other accounts, so log the failure and carry on.
            current_app.logger.error(
                "Password decryption failed for account ID %s: %s",
                account.get("id", "N/A"),
                e,
            )
            account["password"] = "DECRYPTION_FAILED"
    return jsonify(accounts)


def get_stream_names() -> Response:
    """Retrieves all stream names from the database.

    Returns:
        A Flask JSON response containing a list of stream names.
    """
    query = "SELECT streamName FROM streams"
    rows = _execute_query(query)
    return jsonify([row["streamName"] for row in rows])


def single_check() -> Response | Tuple[Response, int]:
    """Performs a check on a single account provided in the request JSON.

    The stream URLs to check against are read from the ``STREAM_URLS``
    application config entry.

    Returns:
        A Flask JSON response with the result of the check, or an error
        message with status 400 when the check yields nothing.
    """
    data = request.get_json()
    stream_urls = current_app.config["STREAM_URLS"]
    result = single_account_check(data, stream_urls)
    if result:
        # Here you would typically update the database with the new information
        return jsonify(result)
    return jsonify({"message": "All checks failed"}), 400


def add_account(user_id: int) -> Response | Tuple[Response, int]:
    """Adds a new account for a user.

    The account details come from the request form. The account is checked
    against the latest stream URLs first; the URL, expiry date and maximum
    connections stored are the ones reported by that check.

    Args:
        user_id: The ID of the user.

    Returns:
        A Flask JSON response confirming the account was added, or an error
        message with status 400 when the account check fails.
    """
    data = request.form
    res = single_account_check(data, get_latest_urls_from_dns())
    if not res:
        # Mirror single_check(): a failed check is a client error, not a
        # crash while indexing into an empty result.
        return jsonify({"message": "All checks failed"}), 400

    encrypted_password = encrypt_password(data["password"])
    user_info = res["data"]["user_info"]
    query = (
        "INSERT INTO userAccounts "
        "(username, stream, streamURL, expiaryDate, password, userID, maxConnections) "
        "VALUES (%s, %s, %s, %s, %s, %s, %s)"
    )
    params = (
        data["username"],
        data["stream"],
        res["url"],
        user_info["exp_date"],
        encrypted_password,
        user_id,
        user_info["max_connections"],
    )
    result = _execute_query(query, params)
    return jsonify(result)


def update_expiry_date(username: str, stream: str, expiry_date: str) -> None:
    """Updates the expiry date of an account.

    Args:
        username: The username of the account.
        stream: The stream of the account.
        expiry_date: The new expiry date.
    """
    query = "UPDATE userAccounts SET expiaryDate = %s WHERE username = %s AND stream = %s"
    _execute_query(query, (expiry_date, username, stream))


def update_max_connections(username: str, stream: str, max_connections: int) -> None:
    """Updates the max connections of an account.

    Args:
        username: The username of the account.
        stream: The stream of the account.
        max_connections: The new max connections value.
    """
    query = "UPDATE userAccounts SET maxConnections = %s WHERE username = %s AND stream = %s"
    _execute_query(query, (max_connections, username, stream))


def update_stream_url(new_stream: str, old_stream: str) -> None:
    """Replaces a stream URL on every account that currently uses it.

    Args:
        new_stream: The stream URL to store.
        old_stream: The stream URL currently stored; every row whose
            streamURL equals this value is updated.
    """
    query = "UPDATE userAccounts SET streamURL = %s WHERE streamURL = %s"
    _execute_query(query, (new_stream, old_stream))


def delete_account(user_id: int) -> Response:
    """Deletes an account for a user.

    The account is identified by the ``user`` and ``stream`` fields of the
    request form together with the given user ID.

    Args:
        user_id: The ID of the user.

    Returns:
        A Flask JSON response confirming the account was deleted.
    """
    data = request.form
    query = "DELETE FROM userAccounts WHERE username = %s AND stream = %s AND userId = %s"
    params = (data["user"], data["stream"], user_id)
    result = _execute_query(query, params)
    return jsonify(result)


def save_push_subscription(user_id: int, subscription_json: str) -> None:
    """Saves a push subscription to the database.

    Args:
        user_id: The ID of the user.
        subscription_json: The push subscription information as a JSON string.
    """
    query = "INSERT INTO push_subscriptions (user_id, subscription_json) VALUES (%s, %s)"
    _execute_query(query, (user_id, subscription_json))


def get_push_subscriptions(user_id: Optional[int] = None) -> List[Dict[str, Any]]:
    """Retrieves all push subscriptions for a given user ID, or all if no user_id is provided.

    Args:
        user_id: The ID of the user (optional).

    Returns:
        A list of push subscriptions.
    """
    # Compare against None explicitly so that a user ID of 0 still filters
    # instead of silently returning every user's subscriptions.
    if user_id is not None:
        query = "SELECT * FROM push_subscriptions WHERE user_id = %s"
        return _execute_query(query, (user_id,))
    query = "SELECT * FROM push_subscriptions"
    return _execute_query(query)