
Flink + Spring Boot: a time-segmented consumption strategy

Consumption strategy: during working hours (08:00-20:00) each event is processed individually; outside working hours events are accumulated and consumed in batches.

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>press.huang</groupId>
    <artifactId>flink-spring-boot-demo2</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-streaming-java</artifactId>
            <version>1.17.2</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-clients</artifactId>
            <version>1.17.2</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
            <version>2.7.18</version>
        </dependency>
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-boot-starter</artifactId>
            <version>3.5.6</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.33</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.20</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <!-- pin the compiler to JDK 8 -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>
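Note: this pom is enough to run FlinkApplication from the IDE. Packaging the job for submission to a standalone cluster would additionally need something like spring-boot-maven-plugin or a shade configuration, which is out of scope here.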

resources/application.yml

spring:
  datasource:
    url: jdbc:mysql://localhost:3306/flink_etl?useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true
    username: root
    password: huangbo123
    driver-class-name: com.mysql.cj.jdbc.Driver
    hikari:
      minimum-idle: 5
      maximum-pool-size: 20
      # reclaim connections after 10 minutes of inactivity
      idle-timeout: 600000
      # 30-minute maximum lifetime; rebuild the connection when it expires
      max-lifetime: 1800000
      # probe connection liveness every 5 minutes
      keepalive-time: 300000
      connection-timeout: 30000

mybatis-plus:
  mapper-locations: classpath*:mapper/*.xml
  configuration:
    map-underscore-to-camel-case: true

flink:
  batch-interval-ms: 30000
  batch-size: 5
  work-start: "08:00"
  work-end: "20:00"
  # consumed by FlinkProperties.parallelism and applied in the main program
  parallelism: 3
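The kebab-case keys bind to the camelCase fields of the FlinkProperties class below through Spring Boot's relaxed binding, so batch-interval-ms populates batchIntervalMs, and so on.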

resources/mapper/SubjectScoreMapper.xml

<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
        "http://mybatis.org/dtd/mybatis-3-mapper.dtd">

<mapper namespace="press.huang.dev.mapper.SubjectScoreMapper">

</mapper>

resources/mapper/UserInfoMapper.xml

<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
        "http://mybatis.org/dtd/mybatis-3-mapper.dtd">

<mapper namespace="press.huang.dev.mapper.UserInfoMapper">

</mapper>
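Both mapper XML files are empty placeholders: every query below goes through the CRUD methods generated by MyBatis-Plus's BaseMapper, so no hand-written SQL is required.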

DDL

CREATE TABLE `user_info` (
  `id` bigint NOT NULL,
  `name` varchar(100) DEFAULT NULL,
  `age` int DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `subject_score` (
  `id` bigint NOT NULL,
  `user_id` bigint DEFAULT NULL,
  `subject` varchar(100) DEFAULT NULL,
  `score` double DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
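A few hypothetical seed rows (sample data, not from the original post) so the working-hours lookup path has something to find:

INSERT INTO `user_info` (`id`, `name`, `age`) VALUES (1, 'alice', 20), (2, 'bob', 21);
INSERT INTO `subject_score` (`id`, `user_id`, `subject`, `score`) VALUES (1, 1, 'math', 80.0), (2, 2, 'english', 75.0);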

config

package press.huang.dev.config;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;

import java.time.LocalTime;

@Data
@Configuration
@ConfigurationProperties(prefix = "flink")
public class FlinkProperties {
    private long batchIntervalMs;
    private int batchSize;
    private String workStart;
    private String workEnd;
    private int parallelism;

    // parse the ISO "HH:mm" strings into LocalTime
    public LocalTime getWorkStartTime() {
        return LocalTime.parse(workStart);
    }

    public LocalTime getWorkEndTime() {
        return LocalTime.parse(workEnd);
    }
}

entity

package press.huang.dev.entity;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
@AllArgsConstructor
public class UserEvent {
    private Long userId;
    private String subject;
    private Double score;
}

package press.huang.dev.entity;

import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;

@Data
@TableName("user_info")
public class UserInfo {
    private Long id;
    private String name;
    private Integer age;
}

package press.huang.dev.entity;

import com.baomidou.mybatisplus.annotation.TableName;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
@AllArgsConstructor
@TableName("subject_score")
public class SubjectScore {
    private Long id;
    private Long userId;
    private String subject;
    private Double score;
}

mapper

package press.huang.dev.mapper;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import press.huang.dev.entity.UserInfo;

@Mapper
public interface UserInfoMapper extends BaseMapper<UserInfo> {
}

package press.huang.dev.mapper;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import press.huang.dev.entity.SubjectScore;

@Mapper
public interface SubjectScoreMapper extends BaseMapper<SubjectScore> {
}

utils

package press.huang.dev.utils;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * Lets Flink runtime code look up Spring beans (the mappers) at run time
 */
@Component
public class SpringContextUtil implements ApplicationContextAware {
    private static ApplicationContext CONTEXT;

    @Override
    public void setApplicationContext(ApplicationContext ctx) throws BeansException {
        CONTEXT = ctx;
    }

    public static <T> T getBean(Class<T> type) {
        return CONTEXT.getBean(type);
    }
}
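This static holder only works when the Spring context and the Flink tasks share a single JVM, which is the case for the local MiniCluster run below. On a real cluster the TaskManager JVMs never execute SpringApplication.run, so CONTEXT would stay null there. A minimal sketch of one common workaround, booting a non-web context lazily on first access (an assumption, not part of the original post):

package press.huang.dev.utils;

import org.springframework.boot.WebApplicationType;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.context.ApplicationContext;

/**
 * Hypothetical cluster-side variant: boot a bean-container-only Spring
 * context inside each TaskManager JVM the first time a bean is requested.
 */
public class LazySpringContext {
    private static ApplicationContext CONTEXT;

    public static synchronized <T> T getBean(Class<?> primarySource, Class<T> type) {
        if (CONTEXT == null) {
            CONTEXT = new SpringApplicationBuilder(primarySource)
                    .web(WebApplicationType.NONE) // no embedded web server
                    .run();
        }
        return CONTEXT.getBean(type);
    }
}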

functions

package press.huang.dev.functions;

import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.IdWorker;
import org.apache.flink.api.common.state.*;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import press.huang.dev.entity.SubjectScore;
import press.huang.dev.entity.UserEvent;
import press.huang.dev.entity.UserInfo;
import press.huang.dev.mapper.SubjectScoreMapper;
import press.huang.dev.mapper.UserInfoMapper;
import press.huang.dev.utils.SpringContextUtil;

import java.time.Instant;
import java.time.LocalTime;
import java.time.ZoneId;
import java.util.*;

/**
 * KeyedProcessFunction:
 * - Working hours (08:00-20:00): process each event individually (state first, DB on a miss)
 * - Off hours: accumulate events (until batchSize or batchIntervalMs is reached), bulk-load both
 *   tables with a LambdaQueryWrapper, write the results into state in bulk, then process the
 *   buffered events one by one against that state (each updating a single subject).
 *   For example, with batchSize=5 and batchIntervalMs=30000 a batch fires as soon as the fifth
 *   event arrives, or 30 seconds after the first buffered event, whichever comes first.
 * <p>
 * State:
 * - userInfoMapState: MapState<Long, UserInfo>
 * - subjectScoreMapState: MapState<Long, Map<String, SubjectScore>>  (inner map gives O(1) subject updates)
 * - bufferState: ListState<UserEvent> (accumulation buffer)
 * - timerState: ValueState<Long> (timestamp of the registered timer, so it survives checkpoint recovery)
 * <p>
 * Note: lastBatchProcessTime is updated only after a batch completes, which keeps the trigger
 * semantics stable (no oscillating triggers).
 */
public class SubjectScoreKeyedProcessFunction extends KeyedProcessFunction<String, UserEvent, String> {
    // user info state: userId -> UserInfo
    private transient MapState<Long, UserInfo> userInfoMapState;
    // subject score state: userId -> (subject -> SubjectScore)
    private transient MapState<Long, Map<String, SubjectScore>> subjectScoreMapState;
    // buffer ListState: accumulates events outside working hours
    private transient ListState<UserEvent> bufferState;

    // timestamp of the registered processing-time timer, kept in state for checkpoint recovery
    private transient ValueState<Long> timerState;

    private transient UserInfoMapper userInfoMapper;
    private transient SubjectScoreMapper subjectScoreMapper;

    private final int batchSize;
    private final long batchIntervalMs;
    private final LocalTime workStartTime;
    private final LocalTime workEndTime;

    // processing time at which the last complete batch finished
    private transient ValueState<Long> lastBatchProcessTimeState;

    public SubjectScoreKeyedProcessFunction(int batchSize, long batchIntervalMs, LocalTime workStartTime, LocalTime workEndTime) {
        this.batchSize = batchSize;
        this.batchIntervalMs = batchIntervalMs;
        this.workStartTime = workStartTime;
        this.workEndTime = workEndTime;
    }

    @Override
    public void open(Configuration parameters) throws Exception {
        // 30-minute state TTL
        StateTtlConfig ttlConfig = StateTtlConfig
                .newBuilder(Time.minutes(30))
                .setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite) // refresh TTL on create and write
                .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired) // never return expired entries
                .build();
        // give the MapStateDescriptor explicit TypeInformation so the generics survive type erasure
        MapStateDescriptor<Long, UserInfo> userInfoDesc =
                new MapStateDescriptor<>(
                        "userInfoMapState",
                        TypeInformation.of(Long.class),
                        TypeInformation.of(new TypeHint<UserInfo>() {
                        })
                );
        userInfoDesc.enableTimeToLive(ttlConfig);
        userInfoMapState = getRuntimeContext().getMapState(userInfoDesc);

        MapStateDescriptor<Long, Map<String, SubjectScore>> subjectScoreDesc =
                new MapStateDescriptor<>(
                        "subjectScoreMapState",
                        TypeInformation.of(Long.class),
                        TypeInformation.of(new TypeHint<Map<String, SubjectScore>>() {
                        })
                );
        subjectScoreDesc.enableTimeToLive(ttlConfig);
        subjectScoreMapState = getRuntimeContext().getMapState(subjectScoreDesc);

        ListStateDescriptor<UserEvent> bufferDesc =
                new ListStateDescriptor<>(
                        "bufferState",
                        TypeInformation.of(new TypeHint<UserEvent>() {
                        })
                );
        bufferDesc.enableTimeToLive(ttlConfig);
        bufferState = getRuntimeContext().getListState(bufferDesc);

        ValueStateDescriptor<Long> timerDesc =
                new ValueStateDescriptor<>("timerState", Long.class);
        timerState = getRuntimeContext().getState(timerDesc);

        ValueStateDescriptor<Long> lastBatchDesc = new ValueStateDescriptor<>("lastBatchProcessTime", Long.class);
        lastBatchProcessTimeState = getRuntimeContext().getState(lastBatchDesc);
        // fetch the mappers from the Spring context
        userInfoMapper = SpringContextUtil.getBean(UserInfoMapper.class);
        subjectScoreMapper = SpringContextUtil.getBean(SubjectScoreMapper.class);
    }

    @Override
    public void processElement(UserEvent event, KeyedProcessFunction<String, UserEvent, String>.Context ctx, Collector<String> out) throws Exception {
        int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
        Long userId = event.getUserId();
        String subject = event.getSubject();
        Double score = event.getScore();

        System.out.println("subtask[" + subtaskIndex + "] event - userId: " + userId + ", subject: " + subject + ", score: " + score + ", key: " + ctx.getCurrentKey());

        // initialize lastBatchProcessTime on first use
        if (lastBatchProcessTimeState.value() == null) {
            lastBatchProcessTimeState.update(Instant.now().toEpochMilli());
        }
        // decide which time segment we are in
        LocalTime now = LocalTime.now(ZoneId.of("Asia/Shanghai")); // Beijing time
        boolean isWorkTime = !now.isBefore(workStartTime) && now.isBefore(workEndTime);

        if (isWorkTime) {
            // working hours: process individually (state first; on a miss, query the DB and cache)
            processSingle(event, out);
            return;
        }
        // off hours: accumulate into the ListState
        // 1) add to buffer
        bufferState.add(event);

        // 2) check whether a batch should be triggered (size-based or time-based)
        // read the ListState once into a temporary List to avoid iterating it repeatedly
        List<UserEvent> buffered = new ArrayList<>();
        for (UserEvent ue : bufferState.get()) {
            buffered.add(ue);
        }

        boolean sizeReached = buffered.size() >= batchSize;
        long lastBatchTs = lastBatchProcessTimeState.value() == null ? 0L : lastBatchProcessTimeState.value();
        boolean timeExpired = (System.currentTimeMillis() - lastBatchTs) >= batchIntervalMs;

        if (sizeReached) {
            // size reached: flush the batch synchronously
            flushBufferAndProcess(buffered, out);
            bufferState.clear();
            // update lastBatchProcessTime only after the batch has completed
            lastBatchProcessTimeState.update(System.currentTimeMillis());
            // delete the pending timer, if any
            Long t = timerState.value();
            if (t != null) {
                ctx.timerService().deleteProcessingTimeTimer(t);
                timerState.clear();
            }
            return;
        }

        // size not reached: if the time condition is already met, trigger ASAP; otherwise ensure a timer is registered
        Long registeredTimer = timerState.value();
        long nowProcTime = ctx.timerService().currentProcessingTime();
        if (timeExpired && registeredTimer == null) {
            // after a checkpoint restore timeExpired may already be true; register a
            // near-immediate timer to guarantee the flush still happens
            long fireTime = nowProcTime + 1L; // fire as soon as possible
            ctx.timerService().registerProcessingTimeTimer(fireTime);
            timerState.update(fireTime);
        } else if (registeredTimer == null) {
            // register the regular batch timer (fires after batchIntervalMs)
            long fireTime = nowProcTime + batchIntervalMs;
            ctx.timerService().registerProcessingTimeTimer(fireTime);
            timerState.update(fireTime);
        }
    }

    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<String> out) throws Exception {
        // timer fired: drain the buffer and process it as one batch
        List<UserEvent> buffered = new ArrayList<>();
        for (UserEvent ue : bufferState.get()) {
            buffered.add(ue);
        }
        if (!buffered.isEmpty()) {
            flushBufferAndProcess(buffered, out);
            lastBatchProcessTimeState.update(System.currentTimeMillis());
            bufferState.clear();
        }
        // clear the stored timer handle
        timerState.clear();
    }

    /**
     * Bulk-load both tables for the batch, refresh state, then process and emit each buffered event from state
     */
    private void flushBufferAndProcess(List<UserEvent> buffered, Collector<String> out) throws Exception {
        if (buffered.isEmpty()) return;

        // 1. collect the distinct userIds of this batch
        Set<Long> userIdSet = new HashSet<>(buffered.size());
        for (UserEvent e : buffered) {
            userIdSet.add(e.getUserId());
        }
        List<Long> userIdList = new ArrayList<>(userIdSet);

        // 2. bulk-load user info (IN query via LambdaQueryWrapper)
        LambdaQueryWrapper<UserInfo> userWrapper = new LambdaQueryWrapper<>();
        userWrapper.in(UserInfo::getId, userIdList);
        List<UserInfo> userInfos = userInfoMapper.selectList(userWrapper);
        // put each row into userInfoMapState
        for (UserInfo ui : userInfos) {
            if (ui != null && ui.getId() != null) {
                userInfoMapState.put(ui.getId(), ui);
            }
        }

        // 3. bulk-load subject scores (IN query via LambdaQueryWrapper)
        LambdaQueryWrapper<SubjectScore> scoreWrapper = new LambdaQueryWrapper<>();
        scoreWrapper.in(SubjectScore::getUserId, userIdList);
        List<SubjectScore> scores = subjectScoreMapper.selectList(scoreWrapper);

        // 4. build a subject -> SubjectScore map per userId and write it into subjectScoreMapState
        // group by userId first
        Map<Long, List<SubjectScore>> grouped = new HashMap<>(scores.size());
        for (SubjectScore score : scores) {
            if (score.getUserId() == null) continue;
            grouped.computeIfAbsent(score.getUserId(), k -> new ArrayList<>()).add(score);
        }

        for (Long uid : userIdList) {
            Map<String, SubjectScore> subjectMap = new HashMap<>();
            List<SubjectScore> listForUser = grouped.getOrDefault(uid, Collections.emptyList());
            for (SubjectScore s : listForUser) {
                subjectMap.put(s.getSubject(), s);
            }
            subjectScoreMapState.put(uid, subjectMap);
        }

        // 5. with state refreshed, process each buffered event against the freshly loaded state
        for (UserEvent e : buffered) {
            processSingleFromState(e, out);
        }

        // 6. batch done (note: the caller updates lastBatchProcessTime)
    }

    /**
     * Working hours: single-event processing
     * - read user info / score map from state first
     * - on a state miss, query the DB and cache the result in state
     * - update or insert the single subject score and write it back to state
     */
    private void processSingle(UserEvent event, Collector<String> out) throws Exception {
        Long userId = event.getUserId();
        String subject = event.getSubject();
        Double score = event.getScore();

        // --- user info ---
        UserInfo ui = userInfoMapState.get(userId);
        if (ui == null) {
            ui = userInfoMapper.selectById(userId);
            if (ui != null) userInfoMapState.put(userId, ui);
        }

        // --- subject map ---
        Map<String, SubjectScore> subjectMap = subjectScoreMapState.get(userId);
        if (subjectMap == null) {
            LambdaQueryWrapper<SubjectScore> qw = new LambdaQueryWrapper<>();
            qw.eq(SubjectScore::getUserId, userId);
            List<SubjectScore> list = subjectScoreMapper.selectList(qw);
            subjectMap = new HashMap<>();
            for (SubjectScore s : list) {
                subjectMap.put(s.getSubject(), s);
            }
            subjectScoreMapState.put(userId, subjectMap);
        }

        SubjectScore existing = subjectMap.get(subject);
        SubjectScore outScore;
        if (existing != null) {
            if (!Objects.equals(existing.getScore(), score)) {
                outScore = new SubjectScore(existing.getId(), userId, subject, score);
            } else {
                // same score as before: no state write or output (adjust as needed)
                outScore = null;
            }
        } else {
            outScore = new SubjectScore(null, userId, subject, score);
        }

        if (outScore != null) {
            // write back to state (the demo never persists to the DB; see the sketch after this class)
            Long id = outScore.getId();
            String oper = "update";
            if (id == null) {
                id = IdWorker.getId();
                outScore.setId(id);
                oper = "create";
            }
            subjectMap.put(subject, outScore);
            subjectScoreMapState.put(userId, subjectMap);

            String name = (ui != null ? ui.getName() : "UNKNOWN");
            out.collect("SINGLE -> userId=" + userId + ", name=" + name + ", " + oper + "=" + outScore);
        }
    }

    /**
     * Off hours: after the bulk load, read from state and update/emit a single event.
     * Note: state was already refreshed in flushBufferAndProcess.
     */
    private void processSingleFromState(UserEvent event, Collector<String> out) throws Exception {
        Long userId = event.getUserId();
        String subject = event.getSubject();
        Double score = event.getScore();

        UserInfo ui = userInfoMapState.get(userId);
        Map<String, SubjectScore> subjectMap = subjectScoreMapState.get(userId);
        if (subjectMap == null) {
            subjectMap = new HashMap<>();
        }

        SubjectScore existing = subjectMap.get(subject);
        SubjectScore outScore;
        if (existing != null) {
            if (!Objects.equals(existing.getScore(), score)) {
                outScore = new SubjectScore(existing.getId(), userId, subject, score);
            } else {
                // same score as before: no state write or output (adjust as needed)
                outScore = null;
            }
        } else {
            outScore = new SubjectScore(null, userId, subject, score);
        }

        if (outScore != null) {
            // write back to state (again, no DB persistence in the demo)
            Long id = outScore.getId();
            String oper = "update";
            if (id == null) {
                id = IdWorker.getId();
                outScore.setId(id);
                oper = "create";
            }
            subjectMap.put(subject, outScore);
            subjectScoreMapState.put(userId, subjectMap);

            String name = (ui != null ? ui.getName() : "UNKNOWN");
            out.collect("BATCH -> userId=" + userId + ", name=" + name + ", " + oper + "=" + outScore);
        }
    }
}
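Both processing paths only update Flink state and emit a log string; nothing is written back to MySQL. A minimal sketch of what persisting outScore could look like at the point where it is built (an assumption, not part of the original post):

// hypothetical write-back, e.g. at the end of processSingle(...)
if ("create".equals(oper)) {
    subjectScoreMapper.insert(outScore);     // new row; id came from IdWorker (snowflake)
} else {
    subjectScoreMapper.updateById(outScore); // existing row; update by primary key
}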

Main program

package press.huang.dev;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import press.huang.dev.config.FlinkProperties;
import press.huang.dev.entity.UserEvent;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import press.huang.dev.functions.SubjectScoreKeyedProcessFunction;
import press.huang.dev.utils.SpringContextUtil;

import java.util.TimeZone;

/**
 * Local test main (Spring Boot + Flink)
 * Note: local runs need the flink-clients dependency on the classpath (see pom.xml)
 */
@SpringBootApplication
public class FlinkApplication {
    public static void main(String[] args) throws Exception {
        // set the JVM default time zone to Beijing time
        TimeZone.setDefault(TimeZone.getTimeZone("Asia/Shanghai"));

        SpringApplication.run(FlinkApplication.class, args);
        FlinkProperties props = SpringContextUtil.getBean(FlinkProperties.class);
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(props.getParallelism());
        // simulated input: one "userId,subject,score" line per event from a local socket
        DataStreamSource<String> stream = env.socketTextStream("localhost", 7777).setParallelism(1);
        SingleOutputStreamOperator<UserEvent> mapStream = stream.map(new MapFunction<String, UserEvent>() {
            @Override
            public UserEvent map(String line) throws Exception {
                String[] fields = line.split(",");
                return new UserEvent(Long.parseLong(fields[0]), fields[1], Double.parseDouble(fields[2]));
            }
        });

        int parallelism = env.getParallelism();
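        // keying by userId % parallelism caps the number of distinct keys at the parallelism;
        // note that Flink's key-group hashing may still place several of these keys on the same subtask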
        KeyedStream<UserEvent, String> keyedStream = mapStream.keyBy(new KeySelector<UserEvent, String>() {
            @Override
            public String getKey(UserEvent event) throws Exception {
                return "" + (event.getUserId() % parallelism);
            }
        });

        keyedStream.process(new SubjectScoreKeyedProcessFunction(
                        props.getBatchSize(),
                        props.getBatchIntervalMs(),
                        props.getWorkStartTime(),
                        props.getWorkEndTime()))
                .print();

        env.execute("SubjectScore Hybrid Demo");
    }
}
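To try the demo locally (assuming netcat is available): start the socket source with nc -lk 7777, run FlinkApplication, then type one event per line in the userId,subject,score format that the map function parses, e.g.:

1,math,90
1,english,85
2,math,70

Between 08:00 and 20:00 each line is processed immediately and printed as a SINGLE -> ... record; outside that window the lines accumulate and are printed as BATCH -> ... records once five events have buffered or the 30-second timer fires.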