大数据

SQLAlchemy + pydantic-settings 实现动态数据库连接

需求:根据用户指定的环境参数(dev/prod/test)访问对应环境主库中的 datasource 元数据表,然后根据指定的数据源 id 创建对应的数据库连接(conn),进而进行 CRUD 操作。

1、项目结构

multi_db_tool/
├── .env_dev                 # 开发环境配置
├── .env_prod                # 生产环境配置
├── config.py                # 环境配置加载(Pydantic)
├── db.py                    # 数据库连接与动态数据源管理
├── main.py                  # 命令行入口(支持 --env dev/prod/test)
└── main2.py                 # 交互式入口(用户选择 dev/prod/test)

2、文件代码

.env_dev

main_db_host=192.168.0.xxx
main_db_port=3306
main_db_user=xxx
main_db_password=xxx
main_db_database=xxx

.env_prod

main_db_host=192.168.0.xxx
main_db_port=3306
main_db_user=xxx
main_db_password=xxx
main_db_database=xxx

config.py

from typing import Literal
from urllib.parse import quote, quote_plus

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """
    Application settings: connection parameters of the main (metadata) database.

    The main database is the one that stores the ``datasource`` table. Values
    are loaded from an ``.env_{env}`` file chosen at runtime by ``from_env``.

    NOTE: with pydantic-settings, values from OS environment variables (matched
    case-insensitively, e.g. ``MAIN_DB_HOST``) take precedence over the dotenv
    file, so a variable exported in the shell overrides ``.env_{env}``.
    """

    # pydantic-settings v2 configuration (replaces the deprecated inner
    # ``class Config``). No ``env_file`` is fixed here: ``from_env`` supplies
    # it per call. ``extra="ignore"`` skips .env keys without a matching field.
    model_config = SettingsConfigDict(extra="ignore")

    # Name of the current environment ("dev" / "prod" / "test"); set by from_env.
    env_name: str

    # Connection parameters of the main metadata DB (holds the datasource table).
    main_db_host: str
    main_db_port: int = 3306
    main_db_user: str
    main_db_password: str
    main_db_database: str

    @property
    def main_db_url(self) -> str:
        """
        Build the SQLAlchemy ``mysql+pymysql`` URL for the main database.

        User name and password are percent-encoded so characters such as
        ``@``, ``#``, ``/``, ``:`` or spaces cannot break the URL. ``quote`` with
        ``safe=""`` is used instead of ``quote_plus`` because ``quote_plus``
        encodes a space as ``+``, which plain percent-decoding of the URL leaves
        as a literal ``+``.
        """
        user = quote(self.main_db_user, safe="")
        pwd = quote(self.main_db_password, safe="")
        return f"mysql+pymysql://{user}:{pwd}@{self.main_db_host}:{self.main_db_port}/{self.main_db_database}"

    @classmethod
    def from_env(cls, env: Literal["dev", "prod", "test"]) -> "Settings":
        """
        Load the settings for one environment from its ``.env_{env}`` file.

        Example: ``env='dev'`` loads ``.env_dev``. The path is relative, so it
        is resolved against the current working directory.

        Args:
            env: Environment name; also stored in ``env_name``.

        Returns:
            A validated ``Settings`` instance.
        """
        env_file = f".env_{env}"
        return cls(_env_file=env_file, env_name=env)

db.py

"""
动态连接多数据源核心逻辑:
1. 从主库(metadata DB)读取 datasource 表
2. 按 ID 动态创建目标数据库连接(带连接池)
3. 支持跨库查询与写入
"""
from sqlalchemy import create_engine, text
from sqlalchemy.pool import QueuePool
from contextlib import contextmanager
from urllib.parse import quote_plus
from typing import Dict, Any

# Module-level SQLAlchemy engine for the main (metadata) database that holds
# the ``datasource`` table. Stays None until init_main_engine() assigns it
# (main.py and main2.py call init_main_engine at startup).
main_engine = None


def init_main_engine(settings) -> None:
    """
    Create the module-level engine for the main (metadata) database.

    Args:
        settings: Object exposing ``main_db_url`` (the ``Settings`` instance
            from config.py).

    Notes:
        ``create_engine`` is lazy: no connection is opened here, so an
        unreachable host is only reported on first use (for example in
        ``get_datasource_config``). Calling this function again disposes the
        previous engine's connection pool before replacing it.
    """
    global main_engine
    new_engine = create_engine(
        settings.main_db_url,
        poolclass=QueuePool,
        pool_size=5,
        max_overflow=10,
        pool_pre_ping=True,  # test pooled connections before use; replace dead ones
        echo=False
    )
    # Release pooled connections of an earlier engine instead of leaving them
    # open until garbage collection. Done after create_engine succeeds, so a
    # failed re-init keeps the old engine usable.
    if main_engine is not None:
        main_engine.dispose()
    main_engine = new_engine


def get_datasource_config(ds_id: str) -> Dict[str, Any]:
    """
    Read one datasource's connection info from the main DB's ``datasource`` table.

    Args:
        ds_id: Value matched against ``datasource.id``.

    Returns:
        Dict with keys ``host``, ``port`` (converted to int), ``username``,
        ``password`` and ``database`` (renamed from ``default_database``).
        NOTE: ``username`` and ``password`` are returned percent-encoded, ready
        to be embedded in a SQLAlchemy URL; ``get_engine_by_id`` relies on this.

    Raises:
        RuntimeError: if ``init_main_engine`` has not been called yet.
        ValueError: if no row with the given id exists.
    """
    # quote(..., safe="") encodes every reserved character, including a space
    # as %20; quote_plus would emit "+" for a space, which plain percent-decoding
    # of the URL leaves as a literal "+".
    from urllib.parse import quote

    if main_engine is None:
        raise RuntimeError("Main engine is not initialized; call init_main_engine() first")

    with main_engine.connect() as conn:
        result = conn.execute(
            text("""
                SELECT host, port, username, password, default_database 
                FROM datasource 
                WHERE id = :id
            """),
            {"id": ds_id}
        ).mappings().fetchone()

    if not result:
        raise ValueError(f"Datasource ID {ds_id} not found in metadata DB")

    # RowMapping -> plain dict so keys can be renamed and values replaced.
    data = dict(result)

    # Rename default_database -> database for the URL-building code.
    data["database"] = data.pop("default_database")

    # The column value may not already be an int; the URL needs a number.
    data["port"] = int(data["port"])

    # Percent-encode credentials so @, #, /, :, spaces etc. cannot break the URL.
    data["username"] = quote(data["username"], safe="")
    data["password"] = quote(data["password"], safe="")

    return data


def get_engine_by_id(ds_id: str):
    """
    Build a fresh SQLAlchemy Engine for the datasource identified by ``ds_id``.

    The datasource row is re-read from the main DB on every call, so changes to
    it take effect immediately (nothing is cached). Each call returns a new
    engine with its own connection pool; the caller owns it and should call
    ``engine.dispose()`` when finished with it.
    """
    cfg = get_datasource_config(ds_id)

    # username/password come back already percent-encoded from get_datasource_config.
    credentials = f"{cfg['username']}:{cfg['password']}"
    location = f"{cfg['host']}:{cfg['port']}/{cfg['database']}"

    pool_options = dict(
        poolclass=QueuePool,
        pool_size=5,
        max_overflow=10,
        pool_pre_ping=True,
        echo=False,
    )
    return create_engine(f"mysql+pymysql://{credentials}@{location}", **pool_options)


@contextmanager
def get_connection(ds_id: str):
    """
    Context manager yielding an open connection to the datasource ``ds_id``.

    A dedicated engine is built for this call via ``get_engine_by_id``. On exit
    the connection is closed and that engine is disposed, so its pool does not
    keep DBAPI connections open after the ``with`` block.

    Raises:
        RuntimeError: if the datasource config cannot be loaded or the
            connection cannot be opened (the original exception is chained).
            Exceptions raised inside the caller's ``with`` body propagate
            unchanged instead of being reported as connection failures.
    """
    try:
        engine = get_engine_by_id(ds_id)
    except Exception as e:
        raise RuntimeError(f"Failed to connect to datasource {ds_id}: {e}") from e

    try:
        try:
            conn = engine.connect()
        except Exception as e:
            raise RuntimeError(f"Failed to connect to datasource {ds_id}: {e}") from e
        try:
            yield conn
        finally:
            conn.close()
    finally:
        # The engine exists only for this call; disposing it closes the pooled
        # connection that conn.close() just returned to the pool.
        engine.dispose()

main.py

# coding: utf-8
# author: 黄波
# file  : main.py
# time  : 2025-10-13 10:34
"""
命令行工具入口:支持 --env dev / --env prod 切换环境。
用于手动测试多数据源连接功能。
"""
import argparse
from config import Settings
from db import init_main_engine, get_connection, get_datasource_config
from sqlalchemy import text


def main() -> int:
    """
    Run the connection checks for the environment/datasource given on the CLI.

    Steps: load ``.env_{env}``, initialize the main-DB engine, read the
    ``datasource`` row for ``--ds-id``, then connect to that target DB and run
    ``SELECT VERSION()``.

    Returns:
        0 when every step passes, 1 at the first failing step (usable as the
        process exit status).
    """
    # Command-line arguments.
    parser = argparse.ArgumentParser(
        description="Multi-datasource database tool with environment support."
    )
    parser.add_argument(
        "--env",
        choices=["dev", "prod", "test"],
        default="dev",
        help="Specify environment: dev, prod, or test (default: dev)"
    )

    parser.add_argument(
        "--ds-id",
        type=str,
        default="1001",
        help="Datasource ID to test connection (default: 1001)"
    )

    args = parser.parse_args()

    print(f"🚀 Starting in '{args.env}' mode...")

    # 1. Load the settings of the selected environment.
    try:
        settings = Settings.from_env(args.env)
    except Exception as e:
        print(f"❌ Failed to load config from .env_{args.env}: {e}")
        return 1

    print(f"   Main DB: {settings.main_db_host}:{settings.main_db_port}/{settings.main_db_database}")

    # 2. Create the main-DB engine. init_main_engine only calls create_engine,
    #    which does not connect; an unreachable main DB surfaces in step 3.
    try:
        init_main_engine(settings)
    except Exception as e:
        print(f"❌ Failed to initialize main metadata DB engine: {e}")
        return 1

    # 3. Read the datasource row (the first real query against the main DB).
    try:
        config = get_datasource_config(args.ds_id)
        print(f"✅ Loaded datasource {args.ds_id}: {config['host']}:{config['port']}/{config['database']}")
    except Exception as e:
        print(f"❌ Failed to load datasource {args.ds_id}: {e}")
        return 1

    # 4. Connect to the target DB and query its server version.
    try:
        with get_connection(args.ds_id) as conn:
            version = conn.execute(text("SELECT VERSION()")).fetchone()[0]
            print(f"✅ Successfully connected to target DB. MySQL version: {version}")
    except Exception as e:
        print(f"❌ Failed to connect to target DB (ID={args.ds_id}): {e}")
        return 1

    print("🎉 All checks passed!")
    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status (None -> 0).
    raise SystemExit(main())

main2.py

# coding: utf-8
# author: 黄波
# file  : main2.py
# time  : 2025-10-13 10:46
"""
交互式多数据源数据库客户端。
启动后提示用户选择环境(dev / prod / test),
自动加载配置、连接主库、读取 datasource 并测试目标库连接。
"""
from sqlalchemy import text
from config import Settings
from db import init_main_engine, get_connection, get_datasource_config


def get_user_env_choice() -> str:
    """
    Ask the user which environment to run against, re-prompting until valid.

    Accepts an environment name (``dev``/``prod``/``test``) or its menu number
    (1/2/3), case-insensitively; empty input selects ``dev``.

    Returns:
        One of ``"dev"``, ``"prod"`` or ``"test"``.
    """
    menu_lines = (
        "\n请选择运行环境:",
        "  [1] dev   - 开发环境",
        "  [2] prod  - 生产环境",
        "  [3] test  - 测试环境",
    )
    # Every accepted answer mapped to the environment it selects.
    answer_to_env = {
        "": "dev",
        "dev": "dev",
        "prod": "prod",
        "test": "test",
        "1": "dev",
        "2": "prod",
        "3": "test",
    }

    while True:
        for line in menu_lines:
            print(line)
        answer = input("\n请输入选项编号或环境名称 (默认: dev): ").strip().lower()

        selected = answer_to_env.get(answer)
        if selected is not None:
            return selected

        print("❌ 无效输入,请输入 'dev'、'prod'、'test' 或对应编号 1/2/3。")


def get_datasource_id() -> str:
    """
    Ask the user for the datasource ID to test.

    Returns:
        The stripped input, or ``"1001"`` when the user just presses Enter.
        The value is returned as a string with no numeric validation, matching
        the ``str`` ID accepted by ``get_datasource_config``.

    Raises:
        Whatever ``input()`` raises (e.g. EOFError on end of input) propagates
        to the caller instead of being retried in a loop.
    """
    ds_input = input("请输入要测试的 datasource ID (默认: 1001): ").strip()
    return ds_input or "1001"


def main() -> int:
    """
    Interactive connection test: pick an environment, then a datasource ID.

    Steps: choose the environment, load ``.env_{env}``, initialize the main-DB
    engine, read the chosen ``datasource`` row, then connect to that target DB
    and run ``SELECT VERSION()``.

    Returns:
        0 when every step passes, 1 at the first failing step (usable as the
        process exit status).
    """
    print("🔧 多数据源数据库连接测试工具")
    print("=" * 50)

    # 1. Ask the user for the environment.
    env = get_user_env_choice()
    print(f"\n✅ 已选择环境: {env}")

    # 2. Load that environment's settings.
    try:
        settings = Settings.from_env(env)
        print(f"   主库地址: {settings.main_db_host}:{settings.main_db_port}/{settings.main_db_database}")
    except Exception as e:
        print(f"❌ 加载 .env_{env} 配置失败: {e}")
        return 1

    # 3. Create the main-DB engine. init_main_engine only calls create_engine,
    #    which does not open a connection, so this step cannot confirm the main
    #    DB is reachable; that is first checked when step 5 queries it.
    try:
        init_main_engine(settings)
        print("✅ 主库(metadata DB)引擎初始化完成")
    except Exception as e:
        print(f"❌ 无法初始化主库引擎: {e}")
        return 1

    # 4. Ask the user for the datasource ID.
    ds_id = get_datasource_id()
    print(f"\n🔍 正在加载 datasource ID: {ds_id}")

    # 5. Read the datasource row (first real query against the main DB).
    try:
        config = get_datasource_config(ds_id)
        print(f"✅ 成功读取配置: {config['host']}:{config['port']}/{config['database']}")
    except Exception as e:
        print(f"❌ 无法从主库读取 datasource {ds_id}: {e}")
        return 1

    # 6. Connect to the target DB and query its server version.
    print(f"\n📡 正在测试连接目标数据库 (ID={ds_id})...")
    try:
        with get_connection(ds_id) as conn:
            version = conn.execute(text("SELECT VERSION()")).fetchone()[0]
            print(f"✅ 目标库连接成功!MySQL 版本: {version}")
    except Exception as e:
        print(f"❌ 无法连接目标库 (ID={ds_id}): {e}")
        return 1

    print("\n🎉 所有步骤完成!连接测试通过。")
    return 0


if __name__ == "__main__":
    import traceback

    try:
        # Use main()'s return value as the exit status (None -> 0). SystemExit
        # derives from BaseException, so neither handler below catches it.
        raise SystemExit(main())
    except KeyboardInterrupt:
        print("\n\n👋 用户中断,程序退出。")
        # 130 = 128 + SIGINT, the conventional status for an interrupted program.
        raise SystemExit(130)
    except Exception as e:
        print(f"\n💥 未预期错误: {e}")
        # Only unexpected errors reach here; keep the traceback for diagnosis.
        traceback.print_exc()
        raise SystemExit(1)