Flink + Spring Boot: a time-of-day consumption strategy
Consumption strategy: during working hours (08:00-20:00) events are consumed one record at a time; outside that window they are consumed in batches.
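Events arrive over a socket as comma-separated lines in the form userId,subject,score (parsed by the main program at the end). A quick illustration of the two paths, with made-up values and times:

1,math,92.5 received at 10:30 -> working hours: looked up (state first, DB on a miss) and emitted immediately
1,math,92.5 received at 22:30 -> off hours: buffered until 5 events accumulate or 30000 ms elapse, then batch-loaded and emitted

The thresholds come from the flink section of application.yml below.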
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>press.huang</groupId>
<artifactId>flink-spring-boot-demo2</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-streaming-java</artifactId>
<version>1.17.2</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-clients</artifactId>
<version>1.17.2</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
<version>2.7.18</version>
</dependency>
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-boot-starter</artifactId>
<version>3.5.6</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.33</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.20</version>
<scope>provided</scope>
</dependency>
</dependencies>
<build>
<plugins>
<!-- pin the compiler to JDK 8 -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
resources/application.yml
spring:
  datasource:
    url: jdbc:mysql://localhost:3306/flink_etl?useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true
    username: root
    password: huangbo123
    driver-class-name: com.mysql.cj.jdbc.Driver
    hikari:
      minimum-idle: 5
      maximum-pool-size: 20
      # reclaim connections after 10 minutes idle
      idle-timeout: 600000
      # 30-minute max lifetime; the pool recreates the connection when it expires
      max-lifetime: 1800000
      # probe connection liveness every 5 minutes
      keepalive-time: 300000
      connection-timeout: 30000
mybatis-plus:
  mapper-locations: classpath*:mapper/*.xml
  configuration:
    map-underscore-to-camel-case: true
flink:
  batch-interval-ms: 30000
  batch-size: 5
  work-start: "08:00"
  work-end: "20:00"
  parallelism: 3
resources/mapper/SubjectScoreMapper.xml
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="press.huang.dev.mapper.SubjectScoreMapper">
</mapper>
resources/mapper/UserInfoMapper.xml
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="press.huang.dev.mapper.UserInfoMapper">
</mapper>
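Both mapper XML files are intentionally empty: the only CRUD methods this demo needs (selectById, selectList) are inherited from MyBatis-Plus's BaseMapper in the mapper interfaces below, so no handwritten SQL is required.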
DDL
CREATE TABLE `user_info` (
`id` bigint NOT NULL,
`name` varchar(100) DEFAULT NULL,
`age` int DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
CREATE TABLE `subject_score` (
`id` bigint NOT NULL,
`user_id` bigint DEFAULT NULL,
`subject` varchar(100) DEFAULT NULL,
`score` double DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
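A few seed rows for local testing; the ids, names, and scores are made up:
INSERT INTO `user_info` (`id`, `name`, `age`) VALUES (1, 'Alice', 20), (2, 'Bob', 22);
INSERT INTO `subject_score` (`id`, `user_id`, `subject`, `score`) VALUES (1001, 1, 'math', 88), (1002, 1, 'english', 75), (1003, 2, 'math', 90);
config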
package press.huang.dev.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import java.time.LocalTime;
@Data
@Configuration
@ConfigurationProperties(prefix = "flink")
public class FlinkProperties {
private long batchIntervalMs;
private int batchSize;
private String workStart;
private String workEnd;
private int parallelism;
// convert the configured strings to LocalTime
public LocalTime getWorkStartTime() {
return LocalTime.parse(workStart);
}
public LocalTime getWorkEndTime() {
return LocalTime.parse(workEnd);
}
}
entity
package press.huang.dev.entity;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class UserEvent {
private Long userId;
private String subject;
private Double score;
}
package press.huang.dev.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
@Data
@TableName("user_info")
public class UserInfo {
private Long id;
private String name;
private Integer age;
}
package press.huang.dev.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
@TableName("subject_score")
public class SubjectScore {
private Long id;
private Long userId;
private String subject;
private Double score;
}
mapper
package press.huang.dev.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import press.huang.dev.entity.UserInfo;
@Mapper
public interface UserInfoMapper extends BaseMapper<UserInfo> {
}
package press.huang.dev.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import press.huang.dev.entity.SubjectScore;
@Mapper
public interface SubjectScoreMapper extends BaseMapper<SubjectScore> {
}
utils
package press.huang.dev.utils;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
/**
* Lets Flink operators fetch Spring beans (the mappers) at runtime.
* This static-context trick works here because the demo runs Flink and Spring Boot in the same JVM
* (local mini-cluster); on a real cluster each TaskManager JVM would have to bootstrap Spring itself.
*/
@Component
public class SpringContextUtil implements ApplicationContextAware {
private static ApplicationContext CONTEXT;
@Override
public void setApplicationContext(ApplicationContext ctx) throws BeansException {
CONTEXT = ctx;
}
public static <T> T getBean(Class<T> type) {
return CONTEXT.getBean(type);
}
}
functions
package press.huang.dev.functions;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.IdWorker;
import org.apache.flink.api.common.state.*;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import press.huang.dev.entity.SubjectScore;
import press.huang.dev.entity.UserEvent;
import press.huang.dev.entity.UserInfo;
import press.huang.dev.mapper.SubjectScoreMapper;
import press.huang.dev.mapper.UserInfoMapper;
import press.huang.dev.utils.SpringContextUtil;
import java.time.Instant;
import java.time.LocalTime;
import java.time.ZoneId;
import java.util.*;
/**
* KeyedProcessFunction:
* - Working hours (08:00-20:00): consume one record at a time (state first, DB on a miss)
* - Off hours: accumulate (until batchSize or batchIntervalMs), batch-load both tables via LambdaQueryWrapper,
*   bulk-write the results into state, then consume the buffered records one by one (read from state, update a single subject)
* <p>
* State:
* - userInfoMapState: MapState<Long, UserInfo>
* - subjectScoreMapState: MapState<Long, Map<String, SubjectScore>> (inner map gives O(1) per-subject updates)
* - bufferState: ListState<UserEvent> (accumulation buffer)
* - timerState: ValueState<Long> (timestamp of the registered timer, so it survives checkpoint recovery)
* <p>
* Note: to keep the trigger semantics stable, lastBatchProcessTime is updated only after a batch completes (avoids oscillating triggers)
*/
public class SubjectScoreKeyedProcessFunction extends KeyedProcessFunction<String, UserEvent, String> {
// user info state: userId -> UserInfo
private transient MapState<Long, UserInfo> userInfoMapState;
// subject score state: userId -> (subject -> SubjectScore)
private transient MapState<Long, Map<String, SubjectScore>> subjectScoreMapState;
// buffer ListState: accumulates events during off hours
private transient ListState<UserEvent> bufferState;
// timestamp of the registered processing-time timer, restored across checkpoints
private transient ValueState<Long> timerState;
private transient UserInfoMapper userInfoMapper;
private transient SubjectScoreMapper subjectScoreMapper;
private final int batchSize;
private final long batchIntervalMs;
private final LocalTime workStartTime;
private final LocalTime workEndTime;
// processing time at which the last full batch completed
private transient ValueState<Long> lastBatchProcessTimeState;
public SubjectScoreKeyedProcessFunction(int batchSize, long batchIntervalMs, LocalTime workStartTime, LocalTime workEndTime) {
this.batchSize = batchSize;
this.batchIntervalMs = batchIntervalMs;
this.workStartTime = workStartTime;
this.workEndTime = workEndTime;
}
@Override
public void open(Configuration parameters) throws Exception {
// 30-minute state TTL
StateTtlConfig ttlConfig = StateTtlConfig
.newBuilder(Time.minutes(30))
.setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite) // refresh TTL on create and write
.setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired) // never return expired entries
.build();
// MapStateDescriptor needs explicit TypeInformation so the generics survive type erasure
MapStateDescriptor<Long, UserInfo> userInfoDesc =
new MapStateDescriptor<>(
"userInfoMapState",
TypeInformation.of(Long.class),
TypeInformation.of(new TypeHint<UserInfo>() {
})
);
userInfoDesc.enableTimeToLive(ttlConfig);
userInfoMapState = getRuntimeContext().getMapState(userInfoDesc);
MapStateDescriptor<Long, Map<String, SubjectScore>> subjectScoreDesc =
new MapStateDescriptor<>(
"subjectScoreMapState",
TypeInformation.of(Long.class),
TypeInformation.of(new TypeHint<Map<String, SubjectScore>>() {
})
);
subjectScoreDesc.enableTimeToLive(ttlConfig);
subjectScoreMapState = getRuntimeContext().getMapState(subjectScoreDesc);
ListStateDescriptor<UserEvent> bufferDesc =
new ListStateDescriptor<>(
"bufferState",
TypeInformation.of(new TypeHint<UserEvent>() {
})
);
bufferDesc.enableTimeToLive(ttlConfig);
bufferState = getRuntimeContext().getListState(bufferDesc);
ValueStateDescriptor<Long> timerDesc =
new ValueStateDescriptor<>("timerState", Long.class);
timerState = getRuntimeContext().getState(timerDesc);
ValueStateDescriptor<Long> lastBatchDesc = new ValueStateDescriptor<>("lastBatchProcessTime", Long.class);
lastBatchProcessTimeState = getRuntimeContext().getState(lastBatchDesc);
// fetch the mappers from the Spring context
userInfoMapper = SpringContextUtil.getBean(UserInfoMapper.class);
subjectScoreMapper = SpringContextUtil.getBean(SubjectScoreMapper.class);
}
@Override
public void processElement(UserEvent event, KeyedProcessFunction<String, UserEvent, String>.Context ctx, Collector<String> out) throws Exception {
int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
Long userId = event.getUserId();
String subject = event.getSubject();
Double score = event.getScore();
System.out.println("并行度[" + subtaskIndex + "] 单条处理 - userId: " + userId + ", subject: " + subject + ", score: " + score + " key: " + ctx.getCurrentKey());
// initialize lastBatchProcessTime on first use
if (lastBatchProcessTimeState.value() == null) {
lastBatchProcessTimeState.update(Instant.now().toEpochMilli());
}
// determine the time window
LocalTime now = LocalTime.now(ZoneId.of("Asia/Shanghai")); // Beijing time
boolean isWorkTime = !now.isBefore(workStartTime) && now.isBefore(workEndTime);
if (isWorkTime) {
// working hours: consume the record immediately (state first; on a miss hit the DB and cache)
processSingle(event, out);
return;
}
// off hours: accumulate into ListState (this path only runs outside working hours)
// 1) add to buffer
bufferState.add(event);
// 2) check whether a batch should fire (size-based or time-based)
// read the buffer once into a temporary List instead of iterating ListState repeatedly
List<UserEvent> buffered = new ArrayList<>();
for (UserEvent ue : bufferState.get()) {
buffered.add(ue);
}
boolean sizeReached = buffered.size() >= batchSize;
long lastBatchTs = lastBatchProcessTimeState.value() == null ? 0L : lastBatchProcessTimeState.value();
boolean timeExpired = (System.currentTimeMillis() - lastBatchTs) >= batchIntervalMs;
if (sizeReached) {
// size threshold reached: flush the batch synchronously
flushBufferAndProcess(buffered, out);
bufferState.clear();
// update lastBatchProcessTime only after the batch completes
lastBatchProcessTimeState.update(System.currentTimeMillis());
// delete the pending timer, if any
Long t = timerState.value();
if (t != null) {
ctx.timerService().deleteProcessingTimeTimer(t);
timerState.clear();
}
return;
}
// size not reached: if the time condition already holds, or no timer is registered yet, register one
Long registeredTimer = timerState.value();
long nowProcTime = ctx.timerService().currentProcessingTime();
if (timeExpired && registeredTimer == null) {
// if timeExpired is already true (e.g. after a checkpoint restore), register a near-immediate timer to guarantee firing
long fireTime = nowProcTime + 1L; // fire as soon as possible
ctx.timerService().registerProcessingTimeTimer(fireTime);
timerState.update(fireTime);
} else if (registeredTimer == null) {
// register the regular batch timer (fires when the interval elapses)
long fireTime = nowProcTime + batchIntervalMs;
ctx.timerService().registerProcessingTimeTimer(fireTime);
timerState.update(fireTime);
}
}
@Override
public void onTimer(long timestamp, OnTimerContext ctx, Collector<String> out) throws Exception {
// timer fired: drain the buffer and process it in one pass
List<UserEvent> buffered = new ArrayList<>();
for (UserEvent ue : bufferState.get()) {
buffered.add(ue);
}
if (!buffered.isEmpty()) {
flushBufferAndProcess(buffered, out);
lastBatchProcessTimeState.update(System.currentTimeMillis());
bufferState.clear();
}
// clear timerState
timerState.clear();
}
/**
* Batch-load the DB and refresh state, then process each buffered record from state and emit output
*/
private void flushBufferAndProcess(List<UserEvent> buffered, Collector<String> out) throws Exception {
if (buffered.isEmpty()) return;
// 1. collect the distinct userIds in this batch
Set<Long> userIdSet = new HashSet<>(buffered.size());
for (UserEvent e : buffered) {
userIdSet.add(e.getUserId());
}
List<Long> userIdList = new ArrayList<>(userIdSet);
// 2. batch-load user info (LambdaQueryWrapper IN)
LambdaQueryWrapper<UserInfo> userWrapper = new LambdaQueryWrapper<>();
userWrapper.in(UserInfo::getId, userIdList);
List<UserInfo> userInfos = userInfoMapper.selectList(userWrapper);
// write into userInfoMapState (one put per row)
for (UserInfo ui : userInfos) {
if (ui != null && ui.getId() != null) {
userInfoMapState.put(ui.getId(), ui);
}
}
// 3. batch-load subject scores (LambdaQueryWrapper IN)
LambdaQueryWrapper<SubjectScore> scoreWrapper = new LambdaQueryWrapper<>();
scoreWrapper.in(SubjectScore::getUserId, userIdList);
List<SubjectScore> scores = subjectScoreMapper.selectList(scoreWrapper);
// 4. build a subject -> SubjectScore map per userId and write it into subjectScoreMapState
// group by userId first
Map<Long, List<SubjectScore>> grouped = new HashMap<>(scores.size());
for (SubjectScore score : scores) {
if (score.getUserId() == null) continue;
grouped.computeIfAbsent(score.getUserId(), k -> new ArrayList<>()).add(score);
}
for (Long uid : userIdList) {
Map<String, SubjectScore> subjectMap = new HashMap<>();
List<SubjectScore> listForUser = grouped.getOrDefault(uid, Collections.emptyList());
for (SubjectScore s : listForUser) {
subjectMap.put(s.getSubject(), s);
}
subjectScoreMapState.put(uid, subjectMap);
}
// 5. with state refreshed, process each buffered record from state (every record sees the same batch snapshot)
for (UserEvent e : buffered) {
processSingleFromState(e, out);
}
// 6. batch done (the caller is responsible for updating lastBatchProcessTime)
}
/**
* Working hours: single-record path
* - read user info / score map from state first
* - on a state miss, query the DB and write the result into state
* - update or insert the single subject score, then write it back to state
*/
private void processSingle(UserEvent event, Collector<String> out) throws Exception {
Long userId = event.getUserId();
String subject = event.getSubject();
Double score = event.getScore();
// --- user info ---
UserInfo ui = userInfoMapState.get(userId);
if (ui == null) {
ui = userInfoMapper.selectById(userId);
if (ui != null) userInfoMapState.put(userId, ui);
}
// --- subject map ---
Map<String, SubjectScore> subjectMap = subjectScoreMapState.get(userId);
if (subjectMap == null) {
LambdaQueryWrapper<SubjectScore> qw = new LambdaQueryWrapper<>();
qw.eq(SubjectScore::getUserId, userId);
List<SubjectScore> list = subjectScoreMapper.selectList(qw);
subjectMap = new HashMap<>();
for (SubjectScore s : list) {
subjectMap.put(s.getSubject(), s);
}
subjectScoreMapState.put(userId, subjectMap);
}
SubjectScore existing = subjectMap.get(subject);
SubjectScore outScore;
if (existing != null) {
if (!Objects.equals(existing.getScore(), score)) {
outScore = new SubjectScore(existing.getId(), userId, subject, score);
} else {
// same as the existing score: no state write or output needed (adjust to taste)
outScore = null;
}
} else {
outScore = new SubjectScore(null, userId, subject, score);
}
if (outScore != null) {
// update state
Long id = outScore.getId();
String oper = "update";
if (id == null) {
id = IdWorker.getId();
outScore.setId(id);
oper = "create";
}
subjectMap.put(subject, outScore);
subjectScoreMapState.put(userId, subjectMap);
String name = (ui != null ? ui.getName() : "UNKNOWN");
out.collect("SINGLE -> userId=" + userId + ", name=" + name + ", " + oper + "=" + outScore);
}
}
/**
* Off hours: after the batch load, read state and update/emit for a single record
* Note: state has already been refreshed in flushBufferAndProcess
*/
private void processSingleFromState(UserEvent event, Collector<String> out) throws Exception {
Long userId = event.getUserId();
String subject = event.getSubject();
Double score = event.getScore();
UserInfo ui = userInfoMapState.get(userId);
Map<String, SubjectScore> subjectMap = subjectScoreMapState.get(userId);
if (subjectMap == null) {
subjectMap = new HashMap<>();
}
SubjectScore existing = subjectMap.get(subject);
SubjectScore outScore;
if (existing != null) {
if (!Objects.equals(existing.getScore(), score)) {
outScore = new SubjectScore(existing.getId(), userId, subject, score);
} else {
// same as the existing score: no state write or output needed (adjust to taste)
outScore = null;
}
} else {
outScore = new SubjectScore(null, userId, subject, score);
}
if (outScore != null) {
// update state
Long id = outScore.getId();
String oper = "update";
if (id == null) {
id = IdWorker.getId();
outScore.setId(id);
oper = "create";
}
subjectMap.put(subject, outScore);
subjectScoreMapState.put(userId, subjectMap);
String name = (ui != null ? ui.getName() : "UNKNOWN");
out.collect("BATCH -> userId=" + userId + ", name=" + name + ", " + oper + "=" + outScore);
}
}
}
Main program
package press.huang.dev;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import press.huang.dev.config.FlinkProperties;
import press.huang.dev.entity.UserEvent;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import press.huang.dev.functions.SubjectScoreKeyedProcessFunction;
import press.huang.dev.utils.SpringContextUtil;
import java.util.TimeZone;
/**
* Local test main (Spring Boot + Flink)
* Note: running locally requires the flink-clients dependency (see pom.xml)
*/
@SpringBootApplication
public class FlinkApplication {
public static void main(String[] args) throws Exception {
// set the JVM default time zone to Beijing time
TimeZone.setDefault(TimeZone.getTimeZone("Asia/Shanghai"));
SpringApplication.run(FlinkApplication.class, args);
FlinkProperties props = SpringContextUtil.getBean(FlinkProperties.class);
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(props.getParallelism()); // parallelism from application.yml
// simulated input
DataStreamSource<String> stream = env.socketTextStream("localhost", 7777).setParallelism(1);
SingleOutputStreamOperator<UserEvent> mapStream = stream.map(new MapFunction<String, UserEvent>() {
@Override
public UserEvent map(String line) throws Exception {
String[] fields = line.split(",");
return new UserEvent(Long.parseLong(fields[0]), fields[1], Double.parseDouble(fields[2]));
}
});
int parallelism = env.getParallelism();
KeyedStream<UserEvent, String> keyedStream = mapStream.keyBy(new KeySelector<UserEvent, String>() {
@Override
public String getKey(UserEvent event) throws Exception {
return "" + (event.getUserId() % parallelism);
}
});
keyedStream.process(new SubjectScoreKeyedProcessFunction(props.getBatchSize(), props.getBatchIntervalMs(), props.getWorkStartTime(), props.getWorkEndTime())).print();
env.execute("SubjectScore Hybrid Demo");
}
}
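Running it locally
The steps below mirror the code above; host, port, and sample values are illustrative. Create the flink_etl database and the two tables from the DDL section, optionally load the seed rows, and start a socket source first with nc -lk 7777. Then run FlinkApplication and type events such as 1,math,92.5. Between 08:00 and 20:00 each event is looked up and emitted immediately as a SINGLE -> ... line; outside that window events are buffered until batch-size (5) events arrive or batch-interval-ms (30000) elapses, after which the batch is loaded with two IN queries and emitted as BATCH -> ... lines.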