大数据

Kahn 算法对任务进行排序

将 Kahn 算法中的图结构用 “来源表 → 目标表” 的语义来表达,并使用了更贴近业务的命名(如 sourceWithTargetstargetRefsCnt),这确实让代码在 数据开发/数仓调度场景下更容易理解,尤其是对熟悉 ODS/DWD/DWS 分层的同学来说。

下面我从几个维度确认你的实现是正确且可投入生产使用的

✅ 1. 逻辑正确性

  • 图构建正确
    sourceTable → targetTable 表示 “执行 targetTable 前必须先完成 sourceTable”,这是标准的依赖方向。
  • 入度计算准确
    targetRefsCnt 实际就是每个节点的 入度(in-degree),初始化为 0,每有一个依赖就 +1。
  • 拓扑排序流程标准
    入度为 0 的节点入队 → 出队加入结果 → 减少其子节点入度 → 子节点入度为 0 则入队。
  • 循环依赖检测完善
    最终结果长度 ≠ 总节点数 ⇒ 有环,抛出异常,符合预期。

✅ 2. 命名与可读性(亮点)

你的变量名通用术语优势
sourceWithTargetsadjacency list (反向依赖图)一看就知道“一个源表被哪些目标表依赖”
targetRefsCntin-degree map清晰表达“目标表有多少个依赖来源”
allTablesall nodesallNodes 更贴合数仓语境

💡 这种命名方式在团队协作中能显著降低沟通成本,尤其适合你作为大数据开发人员在调度系统中使用。

✅ 3. 边界处理良好

  • 处理了 sourceTables == null 的情况;
  • 使用 LinkedHashSet 保持插入顺序(虽然 Kahn 算法本身不保证唯一顺序,但稳定顺序有助于调试);
  • 所有表(包括纯源表如 ods_1)都被纳入图中,确保完整调度。
import java.util.*;

public class TaskSorter {

    /**
     * 使用 Kahn 算法对任务进行拓扑排序
     *
     * @param taskRefs 任务依赖关系,key: 目标表, value: 依赖的源表列表
     *                 例如: {"dwd_1": ["ods_1", "ods_2"]}
     * @return 拓扑排序后的任务执行顺序
     */
    public static List<String> sortTasks(Map<String, List<String>> taskRefs) {
        // 构建完整的图:收集所有节点(包括源表)
        Set<String> allTables = new LinkedHashSet<>();
        // 存放 来源表 -> [目标表1, 标目表2]
        Map<String, List<String>> sourceWithTargets = new LinkedHashMap<>();
        // 存放 目标表 对应 来源表数
        Map<String, Integer> targetRefsCnt = new HashMap<>();

        // 获取所有目标表、来源表
        for (Map.Entry<String, List<String>> entry : taskRefs.entrySet()) {
            String targetTable = entry.getKey();
            List<String> sourceTables = entry.getValue();

            allTables.add(targetTable);
            if (sourceTables != null) {
                allTables.addAll(sourceTables);
            }
        }
        
        for (String table : allTables) {
            // 所有表对应来源表数初始化为0
            targetRefsCnt.put(table, 0);
            // 所有表对应目标表初始化[]
            sourceWithTargets.put(table, new ArrayList<>());
        }

        // 构建图:source -> target(因为 source 是 target 的依赖)
        // 即:target 依赖 source => source 完成后才能执行 target
        // 所以边是 source → target
        for (Map.Entry<String, List<String>> entry : taskRefs.entrySet()) {
            String targetTable = entry.getKey();
            List<String> sourceTables = entry.getValue();
            if (sourceTables == null) {
                continue;
            }

            for (String sourceTable : sourceTables) {
                // source → target
                sourceWithTargets.get(sourceTable).add(targetTable);
                targetRefsCnt.put(targetTable, targetRefsCnt.get(targetTable) + 1);
            }
        }

        // Kahn 算法:将入将没有来源表数为0的目标表加入队列
        Queue<String> queue = new LinkedList<>();
        for (String table : allTables) {
            if (targetRefsCnt.get(table) == 0) {
                queue.offer(table);
            }
        }

        List<String> result = new ArrayList<>();

        while (!queue.isEmpty()) {
            String currentTable = queue.poll();
            result.add(currentTable);

            // 遍历当前表对应 所有关联的 目标表
            for (String targetTable : sourceWithTargets.get(currentTable)) {
                // 让其目标表对应来源表的数量全部 - 1
                targetRefsCnt.put(targetTable, targetRefsCnt.get(targetTable) - 1);
                // 如果目标表 的来源表 数变成0 (依赖结束) 那么放入队列
                if (targetRefsCnt.get(targetTable) == 0) {
                    queue.offer(targetTable);
                }
            }
        }

        // 如果结果数量不等于总节点数,说明存在环
        if (result.size() != allTables.size()) {
            ArrayList<String> cyclicTables = new ArrayList<>();
            for (String table : allTables) {
                if (targetRefsCnt.get(table) > 0) {
                    cyclicTables.add(table);
                }
            }
            throw new RuntimeException( "存在循环依赖!以下表无法被调度(可能处于环中): " + cyclicTables);
        }

        return result;
    }

    public static void main(String[] args) {
        Map<String, List<String>> data = new LinkedHashMap<>();
        data.put("dws_1", Arrays.asList("dwd_1", "dwd_2"));
        data.put("dws_2", Arrays.asList("ods_4", "dwd_3"));
        data.put("dws_3", Arrays.asList("dwd_4"));
        data.put("dwd_1", Arrays.asList("ods_1", "ods_2"));
        data.put("dwd_2", Arrays.asList("ods_3", "ods_4"));
        data.put("dwd_3", Arrays.asList("ods_5"));
        data.put("dwd_4", Arrays.asList("ods_6", "dwd_3"));

        System.out.println("正常情况:");
        System.out.println(sortTasks(data));

        // 测试循环依赖
        System.out.println("\n测试循环依赖:");
        Map<String, List<String>> cyclicData = new LinkedHashMap<>();
        cyclicData.put("A", Arrays.asList("B"));
        cyclicData.put("B", Arrays.asList("C"));
        cyclicData.put("C", Arrays.asList("A")); // A ← B ← C ← A

        try {
            sortTasks(cyclicData);
        } catch (RuntimeException e) {
            System.err.println(e.getMessage());
        }
    }
}

分层输出,每层可并发执行

import java.util.*;

public class TaskSorterWithLevels {
    /**
     * 返回按执行层级分组的任务列表,每层内任务可并行执行
     */
    public static List<List<String>> sortTasksIntoLevels(Map<String, List<String>> taskRefs) {
        Set<String> allTables = new LinkedHashSet<>();
        Map<String, List<String>> sourceToTargets = new LinkedHashMap<>();
        Map<String, Integer> inDegree = new HashMap<>();

        // 收集所有表
        for (Map.Entry<String, List<String>> entry : taskRefs.entrySet()) {
            String target = entry.getKey();
            List<String> sources = entry.getValue();
            allTables.add(target);
            if (sources != null) {
                allTables.addAll(sources);
            }
        }

        // 初始化
        for (String table : allTables) {
            inDegree.put(table, 0);
            sourceToTargets.put(table, new ArrayList<>());
        }

        // 构建图
        for (Map.Entry<String, List<String>> entry : taskRefs.entrySet()) {
            String target = entry.getKey();
            List<String> sources = entry.getValue();
            if (sources == null) {
                continue;
            }
            for (String source : sources) {
                sourceToTargets.get(source).add(target);
                inDegree.put(target, inDegree.get(target) + 1);
            }
        }

        // Kahn 分层
        Queue<String> queue = new LinkedList<>();
        for (String table : allTables) {
            if (inDegree.get(table) == 0) {
                queue.offer(table);
            }
        }

        List<List<String>> levels = new ArrayList<>();

        while (!queue.isEmpty()) { // 便利 第一批 0度, 第二批0度 对应level_0, level1
            int size = queue.size();
            List<String> currentLevel = new ArrayList<>();

            // 一次性处理当前层所有节点(关键!)
            for (int i = 0; i < size; i++) {
                String current = queue.poll();
                currentLevel.add(current);

                for (String child : sourceToTargets.get(current)) {
                    inDegree.put(child, inDegree.get(child) - 1);
                    if (inDegree.get(child) == 0) {
                        queue.offer(child);
                    }
                }
            }

            levels.add(currentLevel);
        }

        // 检查循环依赖
        int totalScheduled = levels.stream().mapToInt(List::size).sum();
        if (totalScheduled != allTables.size()) {
            List<String> cyclic = new ArrayList<>();
            for (String table : allTables) {
                if (inDegree.get(table) > 0) {
                    cyclic.add(table);
                }
            }
            throw new RuntimeException("存在循环依赖,无法调度的表: " + cyclic);
        }

        return levels;
    }

    // 工具方法:将分层结果展平为串行列表(兼容旧接口)
    public static List<String> sortTasks(Map<String, List<String>> taskRefs) {
        List<List<String>> levels = sortTasksIntoLevels(taskRefs);
        List<String> flat = new ArrayList<>();
        for (List<String> level : levels) {
            flat.addAll(level);
        }
        return flat;
    }

    // 测试
    public static void main(String[] args) {
        Map<String, List<String>> data = new LinkedHashMap<>();
        data.put("dws_1", Arrays.asList("dwd_1", "dwd_2"));
        data.put("dws_2", Arrays.asList("ods_4", "dwd_3"));
        data.put("dws_3", Arrays.asList("dwd_4"));
        data.put("dwd_1", Arrays.asList("ods_1", "ods_2"));
        data.put("dwd_2", Arrays.asList("ods_3", "ods_4"));
        data.put("dwd_3", Arrays.asList("ods_5"));
        data.put("dwd_4", Arrays.asList("ods_6", "dwd_3"));

        List<List<String>> levels = sortTasksIntoLevels(data);
        System.out.println("分层执行计划(每层可并行):");
        for (int i = 0; i < levels.size(); i++) {
            System.out.println("Level " + i + ": " + levels.get(i));
        }

        // 兼容旧用法
        System.out.println("\n串行顺序(展平):");
        System.out.println(sortTasks(data));
    }
}