# app.py
"""Flask application that renders pages and proxies requests to the backend.

The backend API root is ``app.config["BASE_URL"]``; user credentials are
kept in the session and replayed to the backend as HTTP Basic auth.
"""
import base64
import json
import os
import sys
from contextlib import closing
from typing import Any, Dict, Optional, Tuple, Union
from urllib.parse import urljoin, urlparse

import mysql.connector
import redis
import requests
import requests.auth
from flask import (
    Flask,
    Response,
    jsonify,
    redirect,
    render_template,
    request,
    send_file,
    send_from_directory,
    session,
    url_for,
)
from flask_caching import Cache
from pywebpush import WebPushException, webpush

from config import DevelopmentConfig, ProductionConfig
from lib.datetime import filter_accounts_expired, filter_accounts_next_30_days
from lib.reqs import (
    add_user_account,
    delete_user_account,
    get_stream_names,
    get_urls,
    get_user_accounts,
)

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

app = Flask(__name__)

if os.environ.get("FLASK_ENV") == "production":
    app.config.from_object(ProductionConfig)
else:
    app.config.from_object(DevelopmentConfig)

# Seconds to wait for backend HTTP calls made directly from this module;
# without a timeout a stalled backend would hang the worker indefinitely.
BACKEND_TIMEOUT = app.config.get("BACKEND_TIMEOUT", 10)

# Check for Redis availability and configure cache
redis_url = app.config["REDIS_URL"]
# Class-style backend name, consistent with the "SimpleCache" fallback below.
cache_config = {"CACHE_TYPE": "RedisCache", "CACHE_REDIS_URL": redis_url}
try:
    # Use a short timeout to prevent hanging
    r = redis.from_url(redis_url, socket_connect_timeout=1)
    r.ping()
except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError) as e:
    # TimeoutError is not a subclass of ConnectionError in redis-py; an
    # unreachable host that drops packets raises it on connect.
    print(
        f"WARNING: Redis connection failed: {e}. Falling back to SimpleCache. "
        "This is not recommended for production with multiple workers.",
        file=sys.stderr,
    )
    cache_config = {"CACHE_TYPE": "SimpleCache"}

cache = Cache(app, config=cache_config)

app.config["OCR_ENABLED"] = False
app.config["SESSION_COOKIE_SECURE"] = not app.config["DEBUG"]
app.config["SESSION_COOKIE_HTTPONLY"] = True
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"
app.config["PERMANENT_SESSION_LIFETIME"] = 60 * 60 * 24 * 365  # 1 year

# Request paths of cached views whose content depends on the user's account
# list. Keep in sync with the @app.route paths of user_accounts and home.
_ACCOUNT_VIEW_PATHS = ("/accounts", "/home")


def get_version() -> str:
    """Retrieves the application version from the VERSION file.

    The file is resolved relative to the current working directory.

    Returns:
        The version string, or 'dev' if the file is not found.
    """
    try:
        with open("VERSION", "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        return "dev"


@app.context_processor
def inject_version() -> Dict[str, Any]:
    """Injects the version (plus config and session) into all templates."""
    return dict(version=get_version(), config=app.config, session=session)


def _view_cache_key(path: str) -> str:
    """Build the per-user cache key for the view served at *path*."""
    username = session.get("username", "anonymous")
    return f"view/{username}/{path}"


def make_cache_key(*args, **kwargs) -> str:
    """Generate a cache key based on the user's session and request path.

    Used as the callable ``key_prefix`` of ``@cache.cached`` views, so each
    user gets their own cached copy of each page.
    """
    return _view_cache_key(request.path)


def _invalidate_account_views() -> None:
    """Evict the current user's cached pages that show account data.

    ``cache.delete_memoized`` only applies to ``@cache.memoize`` functions;
    views decorated with ``@cache.cached`` have to be evicted by key.
    Eviction is best-effort: a cache outage must not fail the request, the
    stale entry simply expires after its timeout.
    """
    for path in _ACCOUNT_VIEW_PATHS:
        try:
            cache.delete(_view_cache_key(path))
        except redis.exceptions.RedisError as e:
            print(f"WARNING: cache eviction for {path} failed: {e}", file=sys.stderr)


def _backend_auth() -> requests.auth.HTTPBasicAuth:
    """Rebuild the backend Basic-auth credentials stored in the session.

    Raises:
        KeyError: if the session has no ``auth_credentials`` entry.

    NOTE(review): the credentials are only base64-encoded, and with Flask's
    default cookie session they travel in the (signed, not encrypted)
    cookie. Consider a server-side session store or a backend token.
    """
    credentials = base64.b64decode(session["auth_credentials"]).decode()
    username, password = credentials.split(":", 1)
    return requests.auth.HTTPBasicAuth(username, password)


def _relay_backend_response(response: requests.Response) -> Response:
    """Mirror a backend response's body, status code and Content-Type.

    ``content_type`` is used instead of ``mimetype`` so a header that already
    carries parameters (e.g. ``; charset=utf-8``) is passed through verbatim,
    and a missing header falls back to Flask's default instead of raising.
    """
    return Response(
        response.content,
        status=response.status_code,
        content_type=response.headers.get("Content-Type"),
    )


def _safe_next_url(target: Optional[str]) -> Optional[str]:
    """Return *target* as an absolute URL if it points at this host, else None.

    Guards the post-login ``next`` redirect against open-redirect abuse while
    still accepting the absolute ``request.url`` values that ``share`` and
    ``add_account`` place in ``next``.
    """
    if not target:
        return None
    absolute = urljoin(request.host_url, target)
    parsed = urlparse(absolute)
    own_host = urlparse(request.host_url).netloc
    if parsed.scheme in ("http", "https") and parsed.netloc == own_host:
        return absolute
    return None


@app.before_request
def make_session_permanent() -> None:
    """Makes the user session permanent."""
    session.permanent = True


@app.route("/site.webmanifest")
def serve_manifest() -> Response:
    """Serves the site manifest file."""
    return send_from_directory(
        os.path.join(app.root_path, "static"),
        "site.webmanifest",
        mimetype="application/manifest+json",
    )


@app.route("/favicon.ico")
def favicon() -> Response:
    """Serves the favicon."""
    return send_from_directory(
        os.path.join(app.root_path, "static"),
        "favicon.ico",
        mimetype="image/vnd.microsoft.icon",
    )


@app.route("/")
def index() -> Union[Response, str]:
    """Renders the index page or redirects to home if logged in."""
    if session.get("logged_in"):
        return redirect(url_for("home"))
    return render_template("index.html")


@app.route("/vapid-public-key", methods=["GET"])
def proxy_vapid_public_key():
    """Proxies the request for the VAPID public key to the backend.

    Returns the backend response as-is, or a 502 JSON error if the backend
    cannot be reached.
    """
    backend_url = f"{app.config['BASE_URL']}/vapid-public-key"
    try:
        response = requests.get(backend_url, timeout=BACKEND_TIMEOUT)
    except requests.exceptions.RequestException as e:
        return jsonify({"error": str(e)}), 502
    return _relay_backend_response(response)


@app.route("/save-subscription", methods=["POST"])
def proxy_save_subscription():
    """Proxies the request to save a push subscription to the backend.

    Returns 401 if the user is not logged in, the backend response on
    success, or a 502 JSON error if the backend cannot be reached.
    """
    if not session.get("logged_in"):
        return jsonify({"error": "Unauthorized"}), 401
    backend_url = f"{app.config['BASE_URL']}/save-subscription"
    auth = _backend_auth()
    try:
        response = requests.post(
            backend_url,
            auth=auth,
            json=request.get_json(),
            timeout=BACKEND_TIMEOUT,
        )
    except requests.exceptions.RequestException as e:
        return jsonify({"error": str(e)}), 502
    return _relay_backend_response(response)


def get_db_connection():
    """Open a new MySQL connection from the DB* settings in ``app.config``."""
    # This is a simplified version for demonstration.
    # In a real application, you would use a connection pool.
    return mysql.connector.connect(
        host=app.config["DBHOST"],
        user=app.config["DBUSER"],
        password=app.config["DBPASS"],
        database=app.config["DATABASE"],
        port=app.config["DBPORT"],
    )


def get_push_subscriptions():
    """Return every row of ``push_subscriptions`` as a list of dicts.

    The cursor and connection are closed even if the query raises.
    """
    with closing(get_db_connection()) as conn, closing(
        conn.cursor(dictionary=True)
    ) as cursor:
        cursor.execute("SELECT * FROM push_subscriptions")
        return cursor.fetchall()


def send_notification(subscription_info, message_body):
    """Send one Web Push message; push errors are logged, not raised.

    Args:
        subscription_info: decoded subscription object passed to ``webpush``.
        message_body: payload string to deliver.
    """
    try:
        webpush(
            subscription_info=subscription_info,
            data=message_body,
            vapid_private_key=app.config["VAPID_PRIVATE_KEY"],
            vapid_claims={"sub": app.config["VAPID_CLAIM_EMAIL"]},
        )
    except WebPushException as ex:
        print(f"Web push error: {ex}")
        # A requests.Response is falsy for any status >= 400, so the check
        # must be an explicit ``is not None`` or the 410 branch never runs.
        if ex.response is not None and ex.response.status_code == 410:
            print("Subscription is no longer valid, removing from DB.")
            # TODO: remove the subscription from the database.


@app.route("/send-test-notification", methods=["POST"])
def send_test_notification():
    """Sends a test push notification to all users.

    NOTE(review): any logged-in user can broadcast to every subscription;
    confirm this endpoint should not be restricted further.
    """
    if not session.get("logged_in"):
        return jsonify({"error": "Unauthorized"}), 401
    subscriptions = get_push_subscriptions()
    if not subscriptions:
        return jsonify({"message": "No push subscriptions found."}), 404
    message_body = json.dumps({"title": "KTVManager", "body": "Ktv Test"})
    for sub in subscriptions:
        try:
            send_notification(json.loads(sub["subscription_json"]), message_body)
        except Exception as e:
            # Keep going: one malformed subscription must not block the rest.
            print(
                f"Error sending notification to subscription ID "
                f"{sub.get('id', 'N/A')}: {e}"
            )
    return jsonify(
        {"message": f"Test notification sent to {len(subscriptions)} subscription(s)."}
    )


@app.route("/home")
@cache.cached(timeout=60, key_prefix=make_cache_key)
def home() -> str:
    """Renders the home page with account statistics."""
    if session.get("logged_in"):
        base_url = app.config["BASE_URL"]
        all_accounts = get_user_accounts(base_url, session["auth_credentials"])
        return render_template(
            "home.html",
            username=session["username"],
            accounts=len(all_accounts),
            current_month_accounts=filter_accounts_next_30_days(all_accounts),
            expired_accounts=filter_accounts_expired(all_accounts),
        )
    return render_template("index.html")


@app.route("/login", methods=["POST"])
def login() -> Union[Response, str]:
    """Handles user login.

    Verifies the submitted credentials against the backend ``/Login``
    endpoint. On success, stores them in the session and redirects to the
    ``next`` URL (only if it points at this host) or to the home page;
    otherwise re-renders the index page with an error.
    """
    username = request.form["username"]
    password = request.form["password"]
    credentials = f"{username}:{password}"
    encoded_credentials = base64.b64encode(credentials.encode()).decode()
    login_url = f"{app.config['BASE_URL']}/Login"
    try:
        response = requests.get(
            login_url,
            auth=requests.auth.HTTPBasicAuth(username, password),
            timeout=BACKEND_TIMEOUT,
        )
        response.raise_for_status()
        authenticated = response.json().get("auth") == "Success"
    except (requests.exceptions.RequestException, ValueError):
        # ValueError: non-JSON body on requests versions whose JSON decode
        # error is not a RequestException. Fall through to the error page.
        authenticated = False

    if authenticated:
        session["logged_in"] = True
        session["username"] = username
        session["auth_credentials"] = encoded_credentials
        next_url = _safe_next_url(request.args.get("next"))
        if next_url:
            return redirect(next_url)
        return redirect(url_for("home", loggedin=True))

    error = "Invalid username or password. Please try again."
    return render_template("index.html", error=error)


@app.route("/urls", methods=["GET"])
@cache.cached(timeout=300, key_prefix=make_cache_key)
def urls() -> Union[Response, str]:
    """Renders the URLs page."""
    if not session.get("logged_in"):
        return redirect(url_for("home"))
    base_url = app.config["BASE_URL"]
    return render_template(
        "urls.html", urls=get_urls(base_url, session["auth_credentials"])
    )


@app.route("/accounts", methods=["GET"])
@cache.cached(timeout=60, key_prefix=make_cache_key)
def user_accounts() -> Union[Response, str]:
    """Renders the user accounts page."""
    if not session.get("logged_in"):
        return redirect(url_for("home"))
    base_url = app.config["BASE_URL"]
    user_accounts_data = get_user_accounts(base_url, session["auth_credentials"])
    # NOTE(review): ``auth`` exposes the encoded credentials to the template;
    # confirm the page actually needs them client-side.
    return render_template(
        "user_accounts.html",
        username=session["username"],
        user_accounts=user_accounts_data,
        auth=session["auth_credentials"],
    )


@app.route("/share", methods=["GET"])
def share() -> Response:
    """Handles shared text from PWA."""
    if not session.get("logged_in"):
        return redirect(url_for("index", next=request.url))
    shared_text = request.args.get("text")
    return redirect(url_for("add_account", shared_text=shared_text))


@app.route("/accounts/add", methods=["GET", "POST"])
def add_account() -> Union[Response, str]:
    """Handles adding a new user account.

    GET renders the form (optionally prefilled with ``shared_text``); POST
    forwards the new account to the backend and, on success, refreshes the
    cached account pages and redirects to the accounts list.
    """
    if not session.get("logged_in"):
        return redirect(url_for("index", next=request.url))
    base_url = app.config["BASE_URL"]
    shared_text = request.args.get("shared_text")

    if request.method == "POST":
        username = request.form["username"]
        password = request.form["password"]
        stream = request.form["stream"]
        if add_user_account(
            base_url, session["auth_credentials"], username, password, stream
        ):
            _invalidate_account_views()
            return redirect(url_for("user_accounts"))

    return render_template(
        "add_account.html",
        text_input_enabled=app.config.get("TEXT_INPUT_ENABLED"),
        shared_text=shared_text,
    )


@app.route("/accounts/delete", methods=["POST"])
def delete_account() -> Response:
    """Handles deleting a user account, then shows the refreshed list."""
    if not session.get("logged_in"):
        return redirect(url_for("home"))
    stream = request.form.get("stream")
    username = request.form.get("username")
    base_url = app.config["BASE_URL"]
    delete_user_account(base_url, session["auth_credentials"], stream, username)
    _invalidate_account_views()
    return redirect(url_for("user_accounts"))


@app.route("/validateAccount", methods=["POST"])
def validate_account() -> Tuple[Response, int]:
    """Forwards account validation requests to the backend.

    Returns 401 if the user is not logged in, the backend JSON and status on
    success, or a 500 JSON error if the backend call fails.
    """
    if not session.get("logged_in"):
        return jsonify({"error": "Unauthorized"}), 401
    validate_url = f"{app.config['BASE_URL']}/validateAccount"
    auth = _backend_auth()
    try:
        response = requests.post(
            validate_url,
            auth=auth,
            json=request.get_json(),
            timeout=BACKEND_TIMEOUT,
        )
        response.raise_for_status()
        response_data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        return jsonify({"error": str(e)}), 500

    if response_data.get("message") == "Account is valid and updated":
        _invalidate_account_views()
    return jsonify(response_data), response.status_code


@app.route("/get_stream_names", methods=["GET"])
def stream_names() -> Union[Response, str]:
    """Fetches and returns stream names as JSON."""
    if not session.get("logged_in"):
        return redirect(url_for("home"))
    base_url = app.config["BASE_URL"]
    return jsonify(get_stream_names(base_url, session["auth_credentials"]))


if __name__ == "__main__":
    app.run(
        debug=app.config["DEBUG"], host=app.config["HOST"], port=app.config["PORT"]
    )