Django multi-data-source management: syncing table schemas across MySQL, PostgreSQL, and Doris
Multi-data-source management in Django, with table-schema synchronization between data sources. Supported databases: MySQL, PostgreSQL, and Doris.
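The overall flow: each DataSource row describes one target database and registers itself into Django's DATABASES on demand, while SchemaSyncService copies a table definition from one data source to another. A minimal usage sketch follows; the record ids and table names are made up for illustration, and both DataSource rows are assumed to already exist:

from datasource.models import DataSource
from datasource.services.schema_sync import SchemaSyncService

source_ds = DataSource.objects.get(id="mysql-prod")        # hypothetical id
target_ds = DataSource.objects.get(id="doris-warehouse")   # hypothetical id

result = SchemaSyncService.sync_table(
    source_ds=source_ds,
    target_ds=target_ds,
    source_table="t_order",       # hypothetical source table
    target_table="t_order_copy",  # optional; defaults to the source table name
)
if result["success"]:
    print(result["sql"])          # the DDL that was executed on the target data source
else:
    print("sync failed:", result["error"])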
models.py
import threading
import hashlib
from django.conf import settings
from django.db import models, connections
from datasource.utils.validators import validate_table_name
from datasource.services.schema_sync import SchemaSyncService
import pymysql
try:
import psycopg2
except ImportError:
psycopg2 = None
# 全局锁
_registry_lock = threading.Lock()
class DataSource(models.Model):
"""
数据源配置模型
对应数据库中的datasource表
每个实例代表一个可连接的目标数据库
"""
id = models.CharField(max_length=255, primary_key=True)
source_name = models.CharField(max_length=255)
host = models.CharField(max_length=255)
username = models.CharField(max_length=255)
password = models.CharField(max_length=255)
port = models.IntegerField()
default_database = models.CharField(max_length=255)
schema_name = models.CharField(max_length=255, null=True, blank=True)
db_alias = models.CharField(max_length=255, unique=True)
db_type = models.CharField(max_length=255) # mysql doris postgresql
delete_flag = models.CharField(max_length=2, default="0")
create_time = models.DateTimeField(auto_now_add=True)
update_time = models.DateTimeField(auto_now=True)
class Meta:
db_table = "datasource"
unique_together = [("db_alias", "db_type")]
def _get_connection_alias(self) -> str:
"""
为当前datasource实例生成唯一的django数据库别名
使用配置指纹避免重复注册
:return:
"""
# 生成配置指纹
config_tuple = (
self.host,
self.port,
self.username,
self.password,
self.default_database
)
fingerprint = hashlib.md5(repr(config_tuple).encode()).hexdigest()
alias = f"dynamic_{fingerprint[:8]}"
        # 线程安全地注册到 Django DATABASES
with _registry_lock:
if alias not in settings.DATABASES:
self._register_to_django(alias)
return alias
def _register_to_django(self, alias: str):
"""
将当前实例的配置注册到django的DATABASES设置中
:param alias:
:return:
"""
db_type = self.db_type.lower()
# 1. 以 default 配置为基础(包含所有必需字段)
new_config = settings.DATABASES["default"].copy()
# 2. 覆盖连接相关参数
new_config.update({
"NAME": self.default_database,
"USER": self.username,
"PASSWORD": self.password,
"HOST": self.host,
"PORT": self.port
})
# 3. 根据数据库类型调整 ENGINE 和 OPTIONS
if db_type in ("mysql", "doris"):
new_config['ENGINE'] = 'django.db.backends.mysql'
new_config.setdefault('OPTIONS', {}).update({
'charset': 'utf8mb4',
'init_command': "SET sql_mode='STRICT_TRANS_TABLES'",
})
elif db_type == "postgresql":
new_config['ENGINE'] = 'django.db.backends.postgresql'
new_config['OPTIONS'] = {}
else:
raise ValueError(f"不支持的数据库类型:{db_type}, 请使用 mysql/doris/postgresql")
        # 4. 注册到 DATABASES
settings.DATABASES[alias] = new_config
def remove_from_cache(self):
"""
从 Django 配置和连接池中移除该动态数据源
"""
alias = self._get_connection_alias()
if not alias.startswith("dynamic_"):
return
        with _registry_lock:
            # 先关闭真实连接,避免僵尸连接
            if alias in connections.databases:
                try:
                    connections[alias].close()
                except Exception:
                    pass
            settings.DATABASES.pop(alias, None)
            connections.databases.pop(alias, None)
def get_effective_schema(self):
if self.db_type.lower() == "postgresql":
return self.schema_name or "public"
return None # MySQL/Doris 不用
def _execute_with_retry(self, sql: str, params: tuple = None, fetch: bool = True, is_dml: bool = False):
"""
安全执行 SQL,自动重试一次连接失效问题
:param sql: SQL 语句
:param params: 参数(用于 execute(sql, params))
:param fetch: 是否获取结果
:param is_dml: 是否为 DML(用于返回 rowcount)
:return: 查询结果 或 rowcount(DML) 或 None
"""
alias = self._get_connection_alias()
for attempt in range(2):
try:
with connections[alias].cursor() as cursor:
cursor.execute(sql, params or ())
if fetch:
columns = [col[0] for col in cursor.description]
return [dict(zip(columns, row)) for row in cursor.fetchall()]
elif is_dml:
return cursor.rowcount
else:
return None
except Exception as e:
err_str = str(e).lower()
is_connection_error = (
isinstance(e, (pymysql.OperationalError, pymysql.InterfaceError)) or
(psycopg2 and isinstance(e, (psycopg2.OperationalError, psycopg2.InterfaceError))) or
"has gone away" in err_str or
"connection already closed" in err_str or
"server closed" in err_str or
"broken pipe" in err_str
)
if is_connection_error and attempt == 0:
print(f"DataSource {self.id} 连接失效,清理缓存并重试: {e}")
self.remove_from_cache()
alias = self._get_connection_alias()
continue
raise
return None
def execute_query(self, sql: str):
"""
执行 SELECT 查询,返回字典列表
:param sql: 原生 SQL 查询语句
:return: List[Dict[column, value]]
"""
if not sql or not sql.strip().upper().startswith("SELECT"):
raise ValueError("execute_query 仅支持以 SELECT 开头的查询语句")
return self._execute_with_retry(sql, fetch=True)
def execute_update(self, sql: str) -> int:
"""
执行 INSERT/UPDATE/DELETE 语句,返回影响行数
:param sql: 原生 SQL 更新语句
:return: 受影响的行数
"""
if not sql:
raise ValueError("SQL语句不能为空")
upper_sql = sql.strip().upper()
if not (upper_sql.startswith('INSERT') or
upper_sql.startswith('UPDATE') or
upper_sql.startswith('DELETE')):
raise ValueError("execute_update 仅支持 INSERT/UPDATE/DELETE 语句")
return self._execute_with_retry(sql, fetch=False, is_dml=True)
def execute_ddl(self, sql: str) -> None:
"""
执行 DDL 语句(如 CREATE/ALTER/DROP TABLE)
不返回行数,因为 DDL 通常不产生 rowcount(或值无意义)
:param sql: 原生 DDL SQL 语句
"""
if not sql:
raise ValueError("SQL语句不能为空")
# 按分号分割,去除空语句
statements = [stmt.strip() for stmt in sql.split(';') if stmt.strip()]
if not statements:
return
for stmt in statements:
upper_stmt = stmt.strip().upper()
if not (upper_stmt.startswith('CREATE') or
upper_stmt.startswith('ALTER') or
upper_stmt.startswith('DROP') or
upper_stmt.startswith('TRUNCATE') or
upper_stmt.startswith('COMMENT')):
raise ValueError(f"execute_ddl 仅支持 DDL 语句,发现非法语句: {stmt[:60]}...")
self._execute_with_retry(stmt, fetch=False)
def get_all_tables(self):
"""
获取当前数据源中所有用户表(不含视图)的表名列表
支持 mysql/doris 和 postgresql
:return: List[str] 表名列表
"""
db_type = self.db_type.lower()
if db_type in ("mysql", "doris"):
rows = self._execute_with_retry("SHOW TABLES", fetch=True)
return [row["Tables_in_" + self.default_database] for row in rows]
elif db_type == "postgresql":
schema = self.schema_name or "public"
rows = self._execute_with_retry("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s
AND table_type = 'BASE TABLE'
""", (schema,), fetch=True)
return [row["table_name"] for row in rows]
else:
raise NotImplementedError(f"暂不支持 {self.db_type} 的表列表查询")
def get_table_columns(self, table_name: str):
"""
获取指定表的字段元数据,**关键增强:返回 is_primary_key 字段**
支持:
- MySQL / Doris:通过 information_schema.COLUMNS.COLUMN_KEY 判断主键
- PostgreSQL:通过联合查询 key_column_usage 获取主键信息
返回字段说明:
- name: 字段名
- type: 原始类型(如 'varchar(255)', 'int')
- nullable: bool
- default: 默认值(字符串形式,None 表示无默认值)
- comment: 字段注释
- extra: 额外信息(如 auto_increment)
- is_primary_key: bool(关键!用于 Doris 同步时字段重排)
:param table_name: 表名(已校验安全)
:return: List[Dict]
"""
table_name = validate_table_name(table_name)
if not table_name or not isinstance(table_name, str):
raise ValueError("表名不能为空且必须为字符串")
clean_name = table_name.replace("_", "").replace("-", "").replace(".", "")
if not clean_name.isalnum():
raise ValueError("表名包含非法字符")
db_type = self.db_type.lower()
if db_type in ("mysql", "doris"):
# 查询字段信息 + 主键标识(COLUMN_KEY = 'PRI' 表示主键)
rows = self._execute_with_retry("""
SELECT
COLUMN_NAME,
COLUMN_TYPE,
IS_NULLABLE,
COLUMN_DEFAULT,
COLUMN_COMMENT,
EXTRA,
COLUMN_KEY -- 用于判断是否为主键
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
ORDER BY ORDINAL_POSITION
""", (self.default_database, table_name), fetch=True)
return [{
"name": row["COLUMN_NAME"],
"type": row["COLUMN_TYPE"],
"nullable": row["IS_NULLABLE"] == "YES",
"default": row["COLUMN_DEFAULT"],
"comment": row["COLUMN_COMMENT"] or "",
"extra": row["EXTRA"] or "",
"is_primary_key": row["COLUMN_KEY"] == "PRI" # 关键新增字段
} for row in rows]
elif db_type == "postgresql":
schema = self.schema_name or "public"
# PostgreSQL 需要联合查询主键信息
rows = self._execute_with_retry("""
SELECT
c.column_name,
c.data_type,
c.character_maximum_length,
c.numeric_precision,
c.numeric_scale,
c.is_nullable,
c.column_default,
pg_catalog.col_description(
(SELECT oid FROM pg_class WHERE relname = %s AND relnamespace =
(SELECT oid FROM pg_namespace WHERE nspname = %s)
), c.ordinal_position
) AS column_comment,
-- 判断是否为主键
CASE
WHEN pk_cols.column_name IS NOT NULL THEN TRUE
ELSE FALSE
END AS is_primary_key
FROM information_schema.columns c
LEFT JOIN (
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
) pk_cols ON c.column_name = pk_cols.column_name
WHERE c.table_schema = %s AND c.table_name = %s
ORDER BY c.ordinal_position
""", (
table_name, schema, # for col_description
schema, table_name, # for pk subquery
schema, table_name # for main where
), fetch=True)
columns = []
for row in rows:
col_name = row["column_name"]
data_type = row["data_type"]
char_len = row["character_maximum_length"]
num_prec = row["numeric_precision"]
num_scale = row["numeric_scale"]
is_nullable = row["is_nullable"]
col_default = row["column_default"]
comment = row["column_comment"] or ""
is_pk = row["is_primary_key"]
# 类型标准化(与 MySQL/Doris 风格对齐)
if data_type == 'character varying':
display_type = f"varchar({char_len})" if char_len else "varchar"
elif data_type == 'character':
display_type = f"char({char_len})" if char_len else "char"
elif data_type == 'numeric':
if num_prec is not None and num_scale is not None:
display_type = f"numeric({num_prec},{num_scale})"
else:
display_type = "numeric"
elif data_type == 'integer':
display_type = "int"
elif data_type == 'bigint':
display_type = "bigint"
elif data_type == 'timestamp without time zone':
display_type = "timestamp"
elif data_type == 'double precision':
display_type = "double"
else:
display_type = data_type
columns.append({
"name": col_name,
"type": display_type,
"nullable": is_nullable == "YES",
"default": col_default,
"comment": comment,
"extra": "",
"is_primary_key": is_pk # 关键新增字段
})
return columns
else:
raise NotImplementedError(f"暂不支持 {self.db_type} 的字段元数据查询")
def get_create_table_sql(self, table_name: str) -> str:
"""
获取完整建表脚本(含注释、主键等)
- MySQL/Doris: 使用原生 SHOW CREATE TABLE(最准确)
- PostgreSQL: 通过字段元数据重建,返回多语句脚本(含 COMMENT ON)
"""
table_name = validate_table_name(table_name)
if not table_name or not isinstance(table_name, str):
raise ValueError("表名不能为空且必须为字符串")
clean_name = table_name.replace("_", "").replace("-", "").replace(".", "")
if not clean_name.isalnum():
            raise ValueError("表名包含非法字符:必须以字母或下划线开头,仅允许字母、数字、下划线")
db_type = self.db_type.lower()
if db_type in ("mysql", "doris"):
rows = self._execute_with_retry(f"SHOW CREATE TABLE `{table_name}`", fetch=True)
if not rows:
raise ValueError(f"表 {table_name} 不存在")
return rows[0]["Create Table"]
elif db_type == "postgresql":
# 不再手动拼接,而是复用 SchemaSyncService 的通用生成器
columns = self.get_table_columns(table_name)
table_comment = self.get_table_comment(table_name)
schema = self.get_effective_schema()
            return SchemaSyncService.build_create_table_sql(
                columns=columns,
                source_db_type="postgresql",  # 源即 PostgreSQL,补齐必填参数
                target_db_type="postgresql",
                target_table=table_name,
                target_schema=schema,
                table_comment=table_comment,
                include_comments=True  # 明确要求包含 COMMENT ON 语句
            )
else:
raise NotImplementedError(f"暂不支持 {self.db_type} 的建表语句查询")
def get_table_comment(self, table_name: str) -> str | None:
table_name = validate_table_name(table_name)
db_type = self.db_type.lower()
if db_type in ("mysql", "doris"):
rows = self._execute_with_retry("""
SELECT TABLE_COMMENT
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
""", (self.default_database, table_name), fetch=True)
return rows[0]["TABLE_COMMENT"] if rows else None
elif db_type == "postgresql":
schema = self.schema_name or "public"
rows = self._execute_with_retry("""
SELECT obj_description(c.oid, 'pg_class') AS table_comment
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE c.relname = %s AND n.nspname = %s;
""", (table_name, schema), fetch=True)
return rows[0]["table_comment"] if rows else None
        return None

services/schema_sync.py
import re
from datasource.utils.type_mapper import map_column_type
from datasource.utils.validators import validate_table_name
from typing import Dict, Any, List, Optional
"""
跨数据源表结构同步服务
协调两个 DataSource 实例,将源表结构转换并创建到目标数据源中
"""
class SchemaSyncService:
"""
表结构同步服务类
职责:
- 同类型数据库同步(MySQL→MySQL, PG→PG, Doris→Doris):
直接复用源库的完整建表脚本(通过 get_create_table_sql),替换表名/schema 后执行
- 异构数据库同步(如 MySQL→PG):
解析源表字段元数据,映射为目标数据库类型,重建 DDL
"""
@staticmethod
def sync_table(
source_ds,
target_ds,
source_table: str,
target_table: str = None
) -> Dict[str, Any]:
"""
同步表结构:从源数据源的表 → 目标数据源的表
:param source_ds: 源数据源实例
:param target_ds: 目标数据源实例
:param source_table: 源表名
:param target_table: 目标表名(默认与源表同名)
:return: 包含 success / sql / error / message 的结果字典
"""
# 先处理默认值
if not source_table or not source_table.strip():
raise ValueError("源表名不能为空")
source_table = source_table.strip()
target_table = (target_table or source_table).strip()
# 防止自同步
if source_ds.id == target_ds.id and source_table == target_table:
raise ValueError("禁止将表同步到自身")
# 校验并标准化表名
source_table = validate_table_name(source_table)
target_table = validate_table_name(target_table)
# 目标表不能已存在
if target_table in target_ds.get_all_tables():
raise ValueError(f"目标表 {target_table} 已存在")
source_type = source_ds.db_type.lower()
target_type = target_ds.db_type.lower()
# 【同类型同步】直接复用完整建表脚本
if source_type == target_type:
try:
# 获取源库完整建表脚本(PG 包含 COMMENT ON 语句)
create_sql = source_ds.get_create_table_sql(source_table)
if not create_sql:
raise ValueError(f"无法获取源表 {source_table} 的建表语句")
# 替换表名和 schema(仅 PostgreSQL 需要处理 schema)
final_sql = create_sql
if source_type == "postgresql":
src_schema = source_ds.get_effective_schema()
tgt_schema = target_ds.get_effective_schema()
# 安全替换所有 "schema"."table" 引用(包括 CREATE 和 COMMENT ON 中)
pattern = re.compile(rf'"{re.escape(src_schema)}"\s*\.\s*"{re.escape(source_table)}"')
final_sql = pattern.sub(f'"{tgt_schema}"."{target_table}"', create_sql)
else:
# MySQL/Doris: 替换反引号中的表名(覆盖 CREATE TABLE 和可能的 COMMENT)
final_sql = create_sql.replace(f"`{source_table}`", f"`{target_table}`")
# 执行完整脚本(execute_ddl 已支持多语句)
target_ds.execute_ddl(final_sql)
return {
"success": True,
"sql": final_sql,
"message": f"表 {target_table} 已通过原生SQL同步创建"
}
except Exception as e:
return {
"success": False,
"error": str(e),
"sql": create_sql if 'create_sql' in locals() else None
}
# 【异构同步】走字段解析 + 重建路径(如 MySQL → PostgreSQL)
try:
# 步骤1:从源库获取字段元数据(含 is_primary_key、注释等)
columns = source_ds.get_table_columns(source_table)
if not columns:
raise ValueError(f"源表 '{source_table}' 不存在或无任何字段")
table_comment = source_ds.get_table_comment(source_table)
# 步骤2:生成目标数据库的建表 SQL
create_sql = SchemaSyncService.build_create_table_sql(
columns=columns,
source_db_type=source_type,
target_db_type=target_type,
target_table=target_table,
target_schema=target_ds.get_effective_schema(),
table_comment=table_comment,
include_comments=False # 异构路径由 sync_table 自己追加注释
)
# 步骤3:在目标库执行 DDL
target_ds.execute_ddl(create_sql)
# PostgreSQL 字段注释需额外执行(因异构路径未包含 COMMENT ON)
if target_type == "postgresql":
comment_statements = []
target_schema = target_ds.get_effective_schema()
# 表注释
if table_comment:
safe_table_comment = table_comment.replace("'", "''")
comment_statements.append(
f'COMMENT ON TABLE "{target_schema}"."{target_table}" IS \'{safe_table_comment}\';'
)
# 字段注释
for col in columns:
comment = col.get("comment", "") or ""
if comment:
safe_comment = comment.replace("'", "''")
stmt = f'COMMENT ON COLUMN "{target_schema}"."{target_table}"."{col["name"]}" IS \'{safe_comment}\';'
comment_statements.append(stmt)
if comment_statements:
comment_script = "\n".join(comment_statements)
target_ds.execute_ddl(comment_script)
# 将注释语句追加到返回的 SQL 中,便于日志/调试
create_sql += "\n\n-- PostgreSQL 字段注释:\n" + comment_script
return {
"success": True,
"sql": create_sql,
"message": f"表 {target_table} 已成功创建于目标数据源"
}
except Exception as e:
return {
"success": False,
"error": str(e),
"sql": create_sql if 'create_sql' in locals() else None
}
@staticmethod
def build_create_table_sql(
columns: List[Dict],
source_db_type: str, # 源数据库类型(如 'mysql')
target_db_type: str, # 目标数据库类型(如 'postgresql')
target_table: str,
target_schema: Optional[str] = None,
table_comment: Optional[str] = None,
include_comments: bool = False, # 新增参数:是否包含 COMMENT ON 语句
) -> str:
"""
【通用建表语句生成器】
根据字段元数据生成目标数据库的 CREATE TABLE 语句
:param columns: 字段元数据列表,每个元素包含 name/type/nullable/default/comment/is_primary_key
:param source_db_type: 源数据库类型(用于类型映射和默认值转换)
:param target_db_type: 目标数据库类型
:param target_table: 目标表名
:param target_schema: 目标 schema(仅 PostgreSQL 使用)
:param table_comment: 表注释
:param include_comments: 是否在返回值中包含 COMMENT ON 语句(仅 PostgreSQL 有效)
:return: 完整的 CREATE TABLE SQL 字符串(可能包含多条语句)
"""
target_db_type = target_db_type.lower()
target_table = target_table.strip()
# 强制主键字段为 NOT NULL(安全兜底)
for col in columns:
if col.get("is_primary_key"):
col["nullable"] = False
# 提取主键字段名(保持源表顺序)
pk_columns = [col["name"] for col in columns if col.get("is_primary_key", False)]
# ========== MySQL / PostgreSQL ==========
if target_db_type in ("mysql", "postgresql"):
# 构建每列定义(正确传递 source_db_type 和 target_db_type)
column_defs = []
for col in columns:
col_def = SchemaSyncService._build_column_def(
col=col,
source_db_type=source_db_type,
target_type=target_db_type,
is_key=False # 主键约束在表级定义,非列级
)
column_defs.append(col_def)
# 拼接字段定义
if target_db_type == "mysql":
# MySQL: 反引号包裹表名和字段名
sql = f"CREATE TABLE `{target_table}` (\n " + ",\n ".join(column_defs)
# 添加主键约束
if pk_columns:
pk_clause = ", ".join(f"`{name}`" for name in pk_columns)
sql += f",\n PRIMARY KEY ({pk_clause})"
sql += "\n)"
# 表注释(内联)
if table_comment:
escaped_comment = table_comment.replace("'", "\\'")
sql += f" COMMENT='{escaped_comment}'"
sql += ";"
return sql
else: # postgresql
schema = target_schema or "public"
# 双引号包裹 schema 和表名
sql = f'CREATE TABLE "{schema}"."{target_table}" (\n ' + ",\n ".join(column_defs)
if pk_columns:
pk_clause = ", ".join(f'"{name}"' for name in pk_columns)
sql += f',\n PRIMARY KEY ({pk_clause})'
sql += "\n);"
# 是否附加 COMMENT ON 语句?
if include_comments:
statements = [sql]
# 表注释
if table_comment:
escaped = table_comment.replace("'", "''")
statements.append(f'COMMENT ON TABLE "{schema}"."{target_table}" IS \'{escaped}\';')
# 字段注释
for col in columns:
comment = col.get("comment", "") or ""
if comment:
escaped = comment.replace("'", "''")
statements.append(
f'COMMENT ON COLUMN "{schema}"."{target_table}"."{col["name"]}" IS \'{escaped}\';'
)
return "\n".join(statements)
else:
return sql
# ========== Doris 特殊处理 ==========
if not columns:
raise ValueError("Doris 表结构不能为空")
has_explicit_pk = any(col.get("is_primary_key", False) for col in columns)
if has_explicit_pk:
key_columns = [col for col in columns if col.get("is_primary_key", False)]
value_columns = [col for col in columns if not col.get("is_primary_key", False)]
else:
# 无主键:取第一个字段作为 Key(DUPLICATE KEY 模型)
key_columns = [columns[0]]
value_columns = columns[1:]
# 重排字段顺序:Key 在前,Value 在后
reordered_columns = key_columns + value_columns
key_col_names = [col["name"] for col in key_columns]
# 生成列定义(Key 列强制 NOT NULL)
column_defs = []
for col in reordered_columns:
is_key = col["name"] in key_col_names
col_def = SchemaSyncService._build_column_def(
col=col,
source_db_type=source_db_type,
target_type="doris",
is_key=is_key
)
column_defs.append(col_def)
# 拼接 CREATE TABLE 主体
sql = f"CREATE TABLE `{target_table}` (\n " + ",\n ".join(column_defs) + "\n)"
# 添加 Doris 模型定义
key_clause = ", ".join(f"`{name}`" for name in key_col_names)
first_key = f"`{key_col_names[0]}`"
if has_explicit_pk:
sql += f"\nUNIQUE KEY({key_clause})"
else:
sql += f"\nDUPLICATE KEY({first_key})"
# 表注释紧跟 KEY 子句
if table_comment:
escaped_comment = table_comment.replace("'", "\\'")
sql += f" COMMENT \"{escaped_comment}\""
# 再跟 DISTRIBUTED BY 和 PROPERTIES
sql += (
f"\nDISTRIBUTED BY HASH({first_key}) BUCKETS 32\n"
f"PROPERTIES(\"replication_num\" = \"1\");"
)
return sql
@staticmethod
def _map_default_value(
default_value: Optional[str],
source_db_type: str,
target_db_type: str,
column_type: Optional[str] = None
) -> Optional[str]:
"""
智能映射字段默认值,确保目标数据库兼容性
:param default_value: 原始默认值字符串(来自 information_schema)
:param source_db_type: 源数据库类型
:param target_db_type: 目标数据库类型
:param column_type: 目标字段类型(用于判断是否为字符串类型,仅 Doris 需要)
:return: 映射后的默认值字符串(如 "'0'", "CURRENT_TIMESTAMP"),或 None(表示无默认值)
"""
if default_value is None:
return None
src = source_db_type.lower()
tgt = target_db_type.lower()
dv = str(default_value).strip()
dv_lower = dv.lower()
# 统一处理 NULL
if dv_lower == "null":
return "NULL"
# 时间函数映射
if dv_lower in ("now()", "current_timestamp"):
if tgt in ("mysql", "postgresql", "doris"):
return "CURRENT_TIMESTAMP"
# Doris 限制:只支持常量
if tgt == "doris":
# 判断是否为字符串类型
if column_type:
lower_col_type = column_type.lower()
is_string_type = any(t in lower_col_type for t in ("char", "text", "varchar", "string"))
else:
is_string_type = True # 安全兜底
if is_string_type:
if dv.startswith("'") and dv.endswith("'"):
return dv
else:
return f"'{dv}'"
else:
# 数值类型:只允许纯数字
if dv.replace('.', '', 1).replace('-', '', 1).isdigit():
return dv
else:
return None
# 默认:原样返回(适用于同类型或安全常量)
return dv
@staticmethod
def _build_column_def(col: Dict, source_db_type: str, target_type: str, is_key: bool = False) -> str:
"""
构建单个字段的列定义语句
:param col: 字段元数据字典(含 name/type/nullable/default/comment)
:param source_db_type: 源数据库类型
:param target_type: 目标数据库类型
:param is_key: 是否为 Doris 的 Key 列(影响 NOT NULL 和默认值处理)
:return: 单列定义字符串,如 "`id` INT NOT NULL COMMENT '用户ID'"
"""
col_name = col["name"]
src_type_str = col["type"]
nullable = col["nullable"]
default = col.get("default")
comment = col.get("comment", "") or ""
# 1. 类型映射
mapped_type = map_column_type(src_type_str, source_db_type, target_type)
# 2. 字段名引号处理
if target_type in ("mysql", "doris"):
col_def = f"`{col_name}` {mapped_type}"
else: # postgresql
col_def = f'"{col_name}" {mapped_type}'
# 3. 默认值处理
mapped_default = SchemaSyncService._map_default_value(
default_value=default,
source_db_type=source_db_type,
target_db_type=target_type,
column_type=mapped_type
)
if mapped_default is not None:
if mapped_default == "NULL":
if target_type in ("mysql", "doris") and not is_key:
col_def += " DEFAULT NULL"
else:
col_def += f" DEFAULT {mapped_default}"
# 4. 非空约束
if is_key:
# Doris/MySQL/PG: 主键必须 NOT NULL
col_def += " NOT NULL"
elif not nullable:
# Doris 特殊规则:有默认值的非主键字段不能显式写 NOT NULL
if target_type == "doris" and mapped_default is not None:
pass
else:
col_def += " NOT NULL"
# 5. 字段注释(仅 MySQL/Doris 支持内联 COMMENT)
if target_type in ("mysql", "doris") and comment:
escaped_comment = comment.replace("'", "\\'")
col_def += f" COMMENT '{escaped_comment}'"
        return col_def

utils/type_mapper.py
import re
from typing import Optional, Tuple
"""
数据库字段类型映射工具
功能:
将源数据库的字段类型(如 MySQL 的 `varchar(255)` 或 `int unsigned`)
转换为目标数据库兼容的类型(如 Doris 的 `varchar(65533)` 或 `bigint`)
设计原则:
- 规则表驱动:新增类型只需在 _TYPE_MAPPING 中添加一行
- 保持源表语义:如 MySQL 的 unsigned int → Doris 的 bigint(避免溢出)
- 兜底安全:未知类型统一转为 string(Doris)或 text(其他)
使用示例:
map_column_type("varchar(255)", "mysql", "doris") → "varchar(255)"
map_column_type("int unsigned", "mysql", "doris") → "bigint"
map_column_type("json", "mysql", "doris") → "string"
"""
def parse_column_type(col_type: str, src_db_type: str) -> Tuple[str, Optional[int], Optional[str]]:
"""
解析字段类型字符串,提取基础类型、长度、精度等信息。
特别处理:
- PostgreSQL 的 "character varying(100)" 或标准化后的 "varchar(100)" → ("varchar", 100, None)
- MySQL 的 "int unsigned" → ("int_unsigned", None, None)
- decimal(10,2) → ("decimal", None, "10,2")
:param col_type: 原始类型字符串,如 "varchar(255)", "int unsigned"
:param src_db_type: 源数据库类型("mysql", "postgresql", "doris")
:return: (base_type, length, prec_scale)
- base_type: 标准化基础类型(可能含 _unsigned 后缀)
- length: 字符长度(仅 char/varchar)
- prec_scale: 精度字符串,如 "10,2"
"""
col_type = col_type.strip().lower()
src_db_type = src_db_type.lower()
# ========== PostgreSQL 特殊类型处理 ==========
if src_db_type == "postgresql":
# 处理 numeric / decimal
if col_type.startswith(("numeric", "decimal")):
match = re.search(r'\((\d+),\s*(\d+)\)', col_type)
prec_scale = f"{match.group(1)},{match.group(2)}" if match else None
return "decimal", None, prec_scale
# 统一处理 varchar:支持 "character varying(N)" 和 "varchar(N)"
varchar_match = re.match(r"^(?:character varying|varchar)\s*\((\d+)\)", col_type)
if varchar_match:
length = int(varchar_match.group(1))
return "varchar", length, None
# 无长度的 varchar / character varying
if col_type in ("character varying", "varchar"):
return "varchar", None, None
# 统一处理 char:支持 "character(N)" 和 "char(N)"
char_match = re.match(r"^(?:character|char)\s*\((\d+)\)", col_type)
if char_match:
length = int(char_match.group(1))
return "char", length, None
# 无长度的 char / character
if col_type in ("character", "char"):
return "char", None, None
# 标准化其他常见类型
pg_type_map = {
"integer": "int",
"bigint": "bigint",
"double precision": "double",
"timestamp without time zone": "timestamp",
"timestamp with time zone": "timestamptz",
"boolean": "boolean",
"text": "text",
}
base_type = pg_type_map.get(col_type, col_type)
return base_type, None, None
# ========== MySQL / Doris 类型处理(语法高度兼容) ==========
# 提取括号前的基础类型(如 "int unsigned" → base_part="int unsigned")
if "(" in col_type:
base_part = col_type.split("(", 1)[0].strip()
inner = col_type.split("(", 1)[1].rstrip(")")
# 解析括号内内容
if "," in inner:
prec_scale = inner # 如 "10,2"
length = None
else:
try:
length = int(inner)
prec_scale = None
except ValueError:
length = None
prec_scale = None
else:
base_part = col_type
length = None
prec_scale = None
# 标准化基础类型(取第一个词)
base_type = base_part.split()[0] # 如 "int" from "int unsigned"
# ========== 特别处理 MySQL 的 unsigned 类型 ==========
if src_db_type == "mysql" and "unsigned" in col_type:
if base_type == "tinyint":
base_type = "smallint" # tinyint unsigned → smallint
elif base_type == "smallint":
base_type = "int" # smallint unsigned → int
elif base_type == "int":
base_type = "int_unsigned"
elif base_type == "bigint":
base_type = "bigint_unsigned"
elif base_type == "mediumint":
base_type = "int"
return base_type, length, prec_scale
# ========== 类型映射规则表 ==========
# 结构:{ 标准化基础类型: { 目标数据库: 映射函数 } }
# 映射函数签名:(length: Optional[int], prec_scale: Optional[str]) -> str
_TYPE_MAPPING = {
# 字符串类型
"varchar": {
"mysql": lambda length, _: f"varchar({length})" if length else "varchar(255)",
"doris": lambda length, _: f"varchar({min(length or 65533, 65533)})",
"postgresql": lambda length, _: f"varchar({length})" if length else "varchar",
},
"char": {
"mysql": lambda length, _: f"char({length or 1})",
"doris": lambda length, _: f"char({length or 1})",
"postgresql": lambda length, _: f"char({length or 1})",
},
# 整数类型
"int": {
"mysql": lambda _, __: "int",
"doris": lambda _, __: "int",
"postgresql": lambda _, __: "int",
},
"int_unsigned": { # 来自 MySQL 的 int unsigned
"mysql": lambda _, __: "int unsigned",
"doris": lambda _, __: "bigint", # Doris 无 unsigned,升为 bigint 防溢出
"postgresql": lambda _, __: "bigint",
},
"bigint": {
"mysql": lambda _, __: "bigint",
"doris": lambda _, __: "bigint",
"postgresql": lambda _, __: "bigint",
},
"bigint_unsigned": {
"mysql": lambda _, __: "bigint unsigned",
"doris": lambda _, __: "bigint", # Doris bigint 范围足够(-2^63 ~ 2^63-1)
"postgresql": lambda _, __: "numeric(20,0)", # 安全起见用 numeric
},
"smallint": {
"mysql": lambda _, __: "smallint",
"doris": lambda _, __: "smallint",
"postgresql": lambda _, __: "smallint",
},
"tinyint": {
"mysql": lambda _, __: "tinyint(1)",
"doris": lambda _, __: "boolean", # Doris 常用 boolean 表示 0/1
"postgresql": lambda _, __: "boolean",
},
# 浮点/定点类型
"double": {
"mysql": lambda _, __: "double",
"doris": lambda _, __: "double",
"postgresql": lambda _, __: "double precision",
},
"float": {
"mysql": lambda _, __: "float",
"doris": lambda _, __: "float",
"postgresql": lambda _, __: "real",
},
"decimal": {
"mysql": lambda _, prec_scale: f"decimal({prec_scale or '10,0'})",
"doris": lambda _, prec_scale: f"decimal({prec_scale or '10,0'})",
"postgresql": lambda _, prec_scale: f"numeric({prec_scale or '10,0'})",
},
# 时间类型
"timestamp": {
"mysql": lambda _, __: "timestamp",
"doris": lambda _, __: "datetime", # Doris 无 timestamp,统一用 datetime
"postgresql": lambda _, __: "timestamp without time zone",
},
"datetime": {
"mysql": lambda _, __: "datetime",
"doris": lambda _, __: "datetime",
"postgresql": lambda _, __: "timestamp without time zone",
},
"date": {
"mysql": lambda _, __: "date",
"doris": lambda _, __: "date",
"postgresql": lambda _, __: "date",
},
# 大对象/特殊类型
"text": {
"mysql": lambda _, __: "text",
"doris": lambda _, __: "string", # Doris 推荐用 string 替代 text
"postgresql": lambda _, __: "text",
},
"mediumtext": {
"mysql": lambda _, __: "mediumtext",
"doris": lambda _, __: "string",
"postgresql": lambda _, __: "text",
},
"longtext": {
"mysql": lambda _, __: "longtext",
"doris": lambda _, __: "string",
"postgresql": lambda _, __: "text",
},
"json": {
"mysql": lambda _, __: "json",
"doris": lambda _, __: "string", # Doris 用 string 存 JSON
"postgresql": lambda _, __: "jsonb", # PostgreSQL 推荐 jsonb
},
"enum": {
"mysql": lambda _, __: "varchar(255)",
"doris": lambda _, __: "varchar(65533)",
"postgresql": lambda _, __: "varchar(255)",
},
"set": {
"mysql": lambda _, __: "varchar(255)",
"doris": lambda _, __: "varchar(65533)",
"postgresql": lambda _, __: "varchar(255)",
},
"boolean": {
"mysql": lambda _, __: "tinyint(1)",
"doris": lambda _, __: "boolean",
"postgresql": lambda _, __: "boolean",
},
}
def map_column_type(
src_type_str: str,
src_db_type: str,
target_db_type: str
) -> str:
"""
将源数据库字段类型映射为目标数据库兼容类型。
流程:
1. 解析源类型 → (base_type, length, prec_scale)
2. 查找 _TYPE_MAPPING[base_type][target_db]
3. 若找不到,兜底返回安全类型(Doris 用 string,其他用 text)
:param src_type_str: 源字段类型,如 "varchar(255)" 或 "int unsigned"
:param src_db_type: 源数据库类型("mysql", "postgresql", "doris")
:param target_db_type: 目标数据库类型
:return: 目标数据库的字段类型定义
"""
src_db_type = src_db_type.lower()
target_db_type = target_db_type.lower()
# 1. 解析源类型
base_type, length, prec_scale = parse_column_type(src_type_str, src_db_type)
# 2. 查找映射规则
if base_type in _TYPE_MAPPING:
mapper = _TYPE_MAPPING[base_type].get(target_db_type)
if mapper:
return mapper(length, prec_scale)
# 3. 无匹配规则:安全兜底
if target_db_type == "doris":
return "string" # Doris 最安全的通用类型
else:
        return "text"  # MySQL/PostgreSQL 通用大文本类型

utils/validators.py
import re
def validate_table_name(name: str) -> str:
"""
校验表名是否合法(仅允许字母、数字、下划线,且以字母或下划线开头)
:param name: 表名字符串
:return: 校验通过的表名(已 strip)
:raises ValueError: 表名不合法
"""
if not isinstance(name, str):
raise ValueError("表名必须为字符串")
name = name.strip()
if not name:
raise ValueError("表名不能为空")
if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", name):
raise ValueError(
f"表名 '{name}' 不合法:必须以字母或下划线开头,仅包含字母、数字、下划线"
)
return name
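A quick look at what the validator accepts and rejects (illustrative values only):

validate_table_name("  user_profile ")   # returns "user_profile" (stripped)
validate_table_name("_tmp_2024")         # OK: may start with an underscore
validate_table_name("order-2024")        # raises ValueError (dash not allowed)
validate_table_name("users; DROP x")     # raises ValueError, so table names cannot smuggle SQL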
Changes to settings.py:
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.mysql',
'NAME': "db_name",
"USER": "root",
"PASSWORD": "root",
"HOST": "127.0.0.1",
"PORT": 3306,
"OPTIONS": {
"charset": "utf8mb4"
}
}
}
# 允许运行时动态添加数据库
DATABASES = DATABASES.copy()
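To close, a rough illustration of the Doris DDL that SchemaSyncService.build_create_table_sql generates for a small MySQL source table. The column metadata here is hand-written in the shape returned by DataSource.get_table_columns(); in real use it comes from the source data source:

from datasource.services.schema_sync import SchemaSyncService

columns = [
    {"name": "id", "type": "bigint unsigned", "nullable": False, "default": None,
     "comment": "primary key", "extra": "auto_increment", "is_primary_key": True},
    {"name": "name", "type": "varchar(255)", "nullable": True, "default": None,
     "comment": "display name", "extra": "", "is_primary_key": False},
]

print(SchemaSyncService.build_create_table_sql(
    columns=columns,
    source_db_type="mysql",
    target_db_type="doris",
    target_table="t_demo",
    table_comment="demo table",
))
# Expected output (roughly):
# CREATE TABLE `t_demo` (
#   `id` bigint NOT NULL COMMENT 'primary key',
#   `name` varchar(255) COMMENT 'display name'
# )
# UNIQUE KEY(`id`) COMMENT "demo table"
# DISTRIBUTED BY HASH(`id`) BUCKETS 32
# PROPERTIES("replication_num" = "1");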