基于 Flink Checkpoint 实现管理 Python 微服务的两阶段提交 Sink 并集成 SkyWalking 链路追踪


我们面临一个棘手的场景:一个 Flink 流处理作业需要将结果原子性地写入两个独立的 Python 微服务。一个服务负责将聚合结果写入 PostgreSQL,另一个则需要调用外部合作方的 API 更新状态。任何一个服务失败,整个操作都必须回滚。这是一个典型的分布式事务问题,但挑战在于协调者是 Flink (JVM),而参与者是 Python 服务。直接使用 Flink 内置的 JdbcSink 无法覆盖非数据库的原子操作。

最初的构想是引入外部事务协调器,如 Seata,但这会增加架构的复杂性和运维成本。在真实项目中,我们倾向于尽可能利用现有技术栈的能力。Flink 的 Checkpoint 机制本质上就是一种两阶段提交协议。当 Flink 进行 Checkpoint 时,它会向所有 Source 发送 barrier,当 barrier 流经整个算子图到达 Sink 时,算子会快照自己的状态。Sink 在接收到 barrier 后,会执行 preCommit 操作,并将事务信息作为状态存入 Checkpoint。当 JobManager 收到所有 TaskManager 的 Checkpoint 完成通知后,会向所有算子广播 Checkpoint 已完成,此时 Sink 会执行 commit 操作。

这个过程给了我们启发:我们完全可以构建一个自定义的 TwoPhaseCommitSinkFunction,将 Flink Job 本身作为事务协调者,将外部的 Python 服务作为事务参与者,从而实现跨语言、跨进程的端到端 Exactly-Once 语义。

第一步:定义 Python 事务参与者接口

为了让 Python 微服务能够参与到 Flink 的 2PC 流程中,它必须遵循事务参与者的契约,即提供 preparecommitrollback 三个核心接口。我们使用 FastAPI 来实现这个服务,因为它轻量且性能出色。

假设这个服务负责将数据写入数据库。在 prepare 阶段,我们会将数据写入一个带有 pending 状态的记录,并获取一个本地事务ID。在 commit 阶段,我们将记录状态更新为 completed。在 rollback 阶段,则删除这条 pending 状态的记录。

这里的坑在于,prepare 阶段必须是持久化的。如果服务在 prepare 成功后、commit 前崩溃,重启后必须能根据事务ID找到那条 pending 记录并完成后续操作。

# file: participant_service/main.py

import os
import uuid
import logging
from contextlib import asynccontextmanager

import psycopg2
from psycopg2.extras import DictCursor
from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import BaseModel

# --- 配置 ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DB_HOST = os.getenv("DB_HOST", "localhost")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_USER = os.getenv("DB_USER", "user")
DB_PASSWORD = os.getenv("DB_PASSWORD", "password")
DB_NAME = os.getenv("DB_NAME", "transactions")

# --- 数据库连接管理 ---
db_conn = None

def get_db_connection():
    global db_conn
    if db_conn is None or db_conn.closed:
        try:
            db_conn = psycopg2.connect(
                host=DB_HOST,
                port=DB_PORT,
                user=DB_USER,
                password=DB_PASSWORD,
                dbname=DB_NAME
            )
            logger.info("Database connection established.")
        except psycopg2.OperationalError as e:
            logger.error(f"Could not connect to database: {e}")
            raise
    return db_conn

def setup_database():
    """初始化数据库表"""
    conn = get_db_connection()
    with conn.cursor() as cur:
        cur.execute("""
        CREATE TABLE IF NOT EXISTS transactional_data (
            global_tx_id VARCHAR(255) PRIMARY KEY,
            payload JSONB,
            status VARCHAR(50) CHECK (status IN ('pending', 'committed', 'aborted'))
        );
        """)
        conn.commit()
    logger.info("Database table 'transactional_data' initialized.")

@asynccontextmanager
async def lifespan(app: FastAPI):
    # 应用启动时
    setup_database()
    yield
    # 应用关闭时
    global db_conn
    if db_conn and not db_conn.closed:
        db_conn.close()
        logger.info("Database connection closed.")

app = FastAPI(lifespan=lifespan)

# --- 数据模型 ---
class PrepareRequest(BaseModel):
    global_tx_id: str
    data: dict

# --- 核心事务接口 ---

@app.post("/prepare")
async def prepare_transaction(request: PrepareRequest):
    """
    第一阶段:准备事务。
    将数据以 'pending' 状态写入数据库。这是幂等操作。
    """
    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=DictCursor) as cur:
            # 检查事务是否已存在
            cur.execute("SELECT status FROM transactional_data WHERE global_tx_id = %s FOR UPDATE;", (request.global_tx_id,))
            existing = cur.fetchone()

            if existing:
                if existing['status'] == 'pending':
                    logger.warning(f"TXID {request.global_tx_id} already in 'pending' state. Idempotent call.")
                    return {"status": "ok", "message": "Transaction already prepared."}
                else:
                    logger.error(f"TXID {request.global_tx_id} exists with status {existing['status']}. Conflict.")
                    raise HTTPException(status_code=409, detail="Transaction conflict")

            # 插入新记录
            cur.execute(
                "INSERT INTO transactional_data (global_tx_id, payload, status) VALUES (%s, %s, 'pending')",
                (request.global_tx_id, str(request.data))
            )
            conn.commit()
            logger.info(f"Prepared transaction: {request.global_tx_id}")
            return {"status": "ok", "tx_id": request.global_tx_id}
    except Exception as e:
        conn.rollback()
        logger.error(f"Error during prepare for TXID {request.global_tx_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/commit/{global_tx_id}")
async def commit_transaction(global_tx_id: str):
    """
    第二阶段:提交事务。
    将 'pending' 状态更新为 'committed'。
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE transactional_data SET status = 'committed' WHERE global_tx_id = %s AND status = 'pending'",
                (global_tx_id,)
            )
            if cur.rowcount == 0:
                # 可能是重复提交或者事务不存在
                cur.execute("SELECT status FROM transactional_data WHERE global_tx_id = %s", (global_tx_id,))
                existing_status = cur.fetchone()
                if existing_status and existing_status[0] == 'committed':
                    logger.warning(f"TXID {global_tx_id} already committed. Idempotent call.")
                    return {"status": "ok", "message": "Transaction already committed."}
                logger.error(f"Commit failed: TXID {global_tx_id} not found in 'pending' state.")
                raise HTTPException(status_code=404, detail="Transaction not found in pending state")

            conn.commit()
            logger.info(f"Committed transaction: {global_tx_id}")
            return {"status": "ok"}
    except Exception as e:
        conn.rollback()
        logger.error(f"Error during commit for TXID {global_tx_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/rollback/{global_tx_id}")
async def rollback_transaction(global_tx_id: str):
    """
    回滚事务。
    删除 'pending' 状态的记录。
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            # 只有 'pending' 状态的事务才能被回滚
            cur.execute(
                "DELETE FROM transactional_data WHERE global_tx_id = %s AND status = 'pending'",
                (global_tx_id,)
            )
            if cur.rowcount == 0:
                logger.warning(f"Rollback ignored: TXID {global_tx_id} not found in 'pending' state.")

            conn.commit()
            logger.info(f"Rolled back transaction: {global_tx_id}")
            return {"status": "ok"}
    except Exception as e:
        conn.rollback()
        logger.error(f"Error during rollback for TXID {global_tx_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

这段代码的关键在于数据库操作的原子性和幂等性。例如,commit 接口即使被重复调用,也只会成功一次,这对于应对网络重试至关重要。

现在轮到 Flink 这边的核心逻辑了。我们需要创建一个 PythonMicroservice2PCSink,它继承自 Flink 的 TwoPhaseCommitSinkFunction。这个抽象类帮我们处理了大部分与 Checkpoint 交互的复杂逻辑,我们只需要实现四个关键方法:

  1. beginTransaction(): 在每个新事务开始时调用,通常用来创建一个临时资源或生成一个唯一的事务 ID。
  2. invoke(): 处理每条流入的数据。在这个方法里,我们会调用 Python 服务的 /prepare 接口。
  3. preCommit(): 在 Checkpoint 触发时调用。这个方法应该确保 invoke 中的所有操作都已刷新并准备好提交。此时,事务尚未被真正提交。
  4. commit(): 在 Checkpoint 成功后调用。在这里,我们调用 Python 服务的 /commit 接口。
  5. abort(): 在事务需要回滚时(例如作业失败)调用,对应 Python 服务的 /rollback 接口。
// file: flink-2pc-app/src/main/java/com/example/PythonMicroservice2PCSink.java
package com.example;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeutils.base.VoidSerializer;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
import org.apache.flink.streaming.api.functions.sink.TwoPhaseCommitSinkFunction;

import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.skywalking.apm.toolkit.trace.TraceContext;
import org.apache.skywalking.apm.toolkit.trace.Tags;
import org.apache.skywalking.apm.toolkit.trace.Tracer;
import org.apache.skywalking.apm.toolkit.trace.Trace;
import org.apache.skywalking.apm.toolkit.trace.ActiveSpan;

import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

// Flink 的 2PC Sink 需要状态,我们定义一个 TransactionState 类来保存每个事务的信息
class TransactionState {
    // 全局事务ID,由 Flink 端生成,传递给 Python 服务
    String globalTxId;

    public TransactionState(String globalTxId) {
        this.globalTxId = globalTxId;
    }

    @Override
    public String toString() {
        return "TransactionState{" + "globalTxId='" + globalTxId + '\'' + '}';
    }
}


public class PythonMicroservice2PCSink
        extends TwoPhaseCommitSinkFunction<String, TransactionState, Void> {

    private static final Logger LOG = LoggerFactory.getLogger(PythonMicroservice2PCSink.class);
    private final String endpointUrl;
    private transient CloseableHttpClient httpClient;
    private transient ObjectMapper objectMapper;

    public PythonMicroservice2PCSink(String endpointUrl) {
        // 传入Python参与者服务的地址
        super(new KryoSerializer<>(TransactionState.class, new ExecutionConfig()), VoidSerializer.INSTANCE);
        this.endpointUrl = endpointUrl;
    }

    @Override
    public void open(org.apache.flink.configuration.Configuration parameters) throws Exception {
        super.open(parameters);
        this.httpClient = HttpClients.createDefault();
        this.objectMapper = new ObjectMapper();
    }

    @Override
    public void close() throws Exception {
        super.close();
        if (httpClient != null) {
            httpClient.close();
        }
    }

    /**
     * 1. 开始一个新事务,生成全局唯一的事务ID
     */
    @Override
    @Trace(operationName = "2PC-BeginTransaction")
    protected TransactionState beginTransaction() throws Exception {
        String txId = UUID.randomUUID().toString();
        LOG.info("Beginning transaction: {}", txId);
        ActiveSpan.tag("tx.id", txId);
        return new TransactionState(txId);
    }

    /**
     * 2. 处理每条数据,调用 Python 服务的 /prepare 接口
     */
    @Override
    @Trace(operationName = "2PC-InvokeAndPrepare")
    protected void invoke(TransactionState transaction, String value, Context context) throws Exception {
        ActiveSpan.tag("tx.id", transaction.globalTxId);
        ActiveSpan.tag("record.value", value);

        HttpPost request = new HttpPost(endpointUrl + "/prepare");

        Map<String, Object> payload = new HashMap<>();
        payload.put("global_tx_id", transaction.globalTxId);
        payload.put("data", objectMapper.readTree(value));

        StringEntity entity = new StringEntity(objectMapper.writeValueAsString(payload));
        request.setEntity(entity);
        request.setHeader("Accept", "application/json");
        request.setHeader("Content-type", "application/json");

        // SkyWalking 核心:注入追踪上下文到 HTTP Header
        TraceContext.inject(request::setHeader);

        LOG.info("Preparing TXID {}: value={}", transaction.globalTxId, value);

        try (CloseableHttpResponse response = httpClient.execute(request)) {
            int statusCode = response.getStatusLine().getStatusCode();
            String responseBody = EntityUtils.toString(response.getEntity());
            if (statusCode < 200 || statusCode >= 300) {
                LOG.error("Prepare failed for TXID {}. Status: {}, Body: {}", transaction.globalTxId, statusCode, responseBody);
                // 抛出异常会让 Flink 作业失败并重启,最终触发 abort
                throw new RuntimeException("Prepare phase failed for transaction " + transaction.globalTxId);
            }
            LOG.info("Prepare successful for TXID {}", transaction.globalTxId);
        }
    }

    /**
     * 3. 预提交,在 Flink Checkpoint 时触发。
     * 对于我们的HTTP调用模式,invoke已经完成了所有准备工作,这里无需额外操作。
     */
    @Override
    @Trace(operationName = "2PC-PreCommit")
    protected void preCommit(TransactionState transaction) throws Exception {
        ActiveSpan.tag("tx.id", transaction.globalTxId);
        LOG.info("Pre-committing transaction: {}", transaction.globalTxId);
        // 在我们的设计中,准备阶段在 invoke() 中已完成,所以这里是空操作。
        // 如果是批量写入,可以在这里 flush 缓存。
    }

    /**
     * 4. 正式提交,在 Flink Checkpoint 成功后触发
     */
    @Override
    @Trace(operationName = "2PC-Commit")
    protected void commit(TransactionState transaction) {
        ActiveSpan.tag("tx.id", transaction.globalTxId);
        LOG.info("Committing transaction: {}", transaction.globalTxId);
        HttpPost request = new HttpPost(endpointUrl + "/commit/" + transaction.globalTxId);
        TraceContext.inject(request::setHeader);

        try (CloseableHttpResponse response = httpClient.execute(request)) {
            int statusCode = response.getStatusLine().getStatusCode();
            if (statusCode < 200 || statusCode >= 300) {
                 // 这里的错误处理非常关键。如果 commit 失败,Flink 没有内置的重试机制。
                 // 一个常见的错误是在这里抛异常,但这没用,因为 Checkpoint 已经完成了。
                 // 生产级系统需要一个重试队列或备用机制来保证最终提交成功。
                 LOG.error("FATAL: Commit failed for TXID {} and cannot be automatically recovered. Status: {}. Needs manual intervention.",
                        transaction.globalTxId, statusCode);
                 Tracer.activeSpan().error();
                 Tags.LOGIC_SPAN.set(true);
            } else {
                 LOG.info("Commit successful for TXID {}", transaction.globalTxId);
            }
        } catch (Exception e) {
            LOG.error("FATAL: Exception during commit for TXID " + transaction.globalTxId, e);
            Tracer.activeSpan().error(e);
            Tags.LOGIC_SPAN.set(true);
        }
    }

    /**
     * 5. 回滚事务
     */
    @Override
    @Trace(operationName = "2PC-Abort")
    protected void abort(TransactionState transaction) {
        ActiveSpan.tag("tx.id", transaction.globalTxId);
        LOG.info("Aborting transaction: {}", transaction.globalTxId);
        HttpPost request = new HttpPost(endpointUrl + "/rollback/" + transaction.globalTxId);
        TraceContext.inject(request::setHeader);

        try (CloseableHttpResponse response = httpClient.execute(request)) {
            int statusCode = response.getStatusLine().getStatusCode();
            if (statusCode < 200 || statusCode >= 300) {
                 LOG.error("Abort failed for TXID {}. Status: {}.", transaction.globalTxId, statusCode);
                 Tracer.activeSpan().error();
            } else {
                 LOG.info("Abort successful for TXID {}", transaction.globalTxId);
            }
        } catch (Exception e) {
            LOG.error("Exception during abort for TXID " + transaction.globalTxId, e);
            Tracer.activeSpan().error(e);
        }
    }
}

这段 Java 代码中最精髓的部分是与 SkyWalking 的集成。通过 TraceContext.inject(request::setHeader),我们将 Flink 端当前 Span 的追踪信息(如 traceId, spanId)注入到发往 Python 服务的 HTTP 请求头中。Python 端的 SkyWalking Agent 会自动识别这些请求头,从而将两个独立系统的调用串联成一个完整的分布式调用链。

第三步:可视化与验证

要将整个系统运行起来,我们需要一个 Flink 作业、Python 服务、PostgreSQL 数据库以及 SkyWalking 的后端和 UI。

// file: flink-2pc-app/src/main/java/com/example/DataStreamJob.java
package com.example;

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.CheckpointingMode;

public class DataStreamJob {
    public static void main(String[] args) throws Exception {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // 必须开启 Checkpoint,这是 2PC Sink 的基础
        env.enableCheckpointing(5000, CheckpointingMode.EXACTLY_ONCE);

        env.addSource(new SourceFunction<String>() {
                private volatile boolean isRunning = true;
                private int counter = 0;

                @Override
                public void run(SourceContext<String> ctx) throws Exception {
                    while (isRunning) {
                        synchronized (ctx.getCheckpointLock()) {
                            String json = String.format("{\"id\": %d, \"data\": \"message-%d\"}", counter++, counter);
                            ctx.collect(json);
                        }
                        Thread.sleep(1000);
                    }
                }

                @Override
                public void cancel() {
                    isRunning = false;
                }
            })
            .name("SimpleJsonSource")
            .addSink(new PythonMicroservice2PCSink("http://participant-service:8000"))
            .name("Python2PCSink")
            .setParallelism(1);

        env.execute("Flink 2PC to Python Example");
    }
}

系统架构与调用流程

我们可以使用 Mermaid 图来清晰地展示整个流程。

sequenceDiagram
    participant Flink JobManager
    participant Flink TaskManager (Sink)
    participant Python Service
    participant PostgreSQL

    Note over Flink TaskManager (Sink): Flink Checkpoint[N] starts
    Flink JobManager->>Flink TaskManager (Sink): Trigger Checkpoint[N]
    activate Flink TaskManager (Sink)

    Note over Flink TaskManager (Sink): sink.preCommit() for transaction TX_N
    Flink TaskManager (Sink)->>Flink JobManager: Acknowledge preCommit for TX_N
    deactivate Flink TaskManager (Sink)

    Note over Flink JobManager: All tasks acknowledged, Checkpoint[N] completed
    Flink JobManager->>Flink TaskManager (Sink): Notify Checkpoint[N] Completed
    activate Flink TaskManager (Sink)

    Note over Flink TaskManager (Sink): sink.commit() for transaction TX_N
    Flink TaskManager (Sink)->>Python Service: POST /commit/TX_N
    activate Python Service
    Python Service->>PostgreSQL: UPDATE ... SET status='committed'
    activate PostgreSQL
    PostgreSQL-->>Python Service: UPDATE successful
    deactivate PostgreSQL
    Python Service-->>Flink TaskManager (Sink): 200 OK
    deactivate Python Service
    deactivate Flink TaskManager (Sink)

    loop Every 1 second
        Note over Flink TaskManager (Sink): New record arrives, sink.invoke()
        activate Flink TaskManager (Sink)
        Note over Flink TaskManager (Sink): Flink Checkpoint[N+1] begins, new transaction TX_N+1
        Flink TaskManager (Sink)->>Python Service: POST /prepare (body contains TX_N+1)
        activate Python Service
        Python Service->>PostgreSQL: INSERT ... (status='pending')
        activate PostgreSQL
        PostgreSQL-->>Python Service: INSERT successful
        deactivate PostgreSQL
        Python Service-->>Flink TaskManager (Sink): 200 OK
        deactivate Python Service
        deactivate Flink TaskManager (Sink)
    end

当一切正常运行时,SkyWalking 的 UI 会展示出清晰的调用链。你会看到一个由 Flink 作业发起的根 Span,下面挂着多个子 Span,分别对应 2PC-BeginTransaction2PC-InvokeAndPrepare 等。2PC-InvokeAndPrepare Span 下面会有一个跨进程的 Exit Span,指向 Python 服务的 /prepare 端点。同样,2PC-Commit Span 下也会有一个指向 /commit 端点的跨进程调用。如果任何一个环节出错,例如 Python 服务返回 500 错误,对应的 Span 会被标记为红色,我们可以迅速定位到问题所在,无论是 Flink 端的逻辑错误还是 Python 参与者的实现缺陷。

方案的局限性与未来优化

这个方案虽然优雅地解决了跨语言分布式事务问题,但在生产环境中仍需考虑其边界。

首先,性能开销。每次 invoke 都伴随着一次同步的 HTTP 调用,这会显著增加数据处理的延迟。对于高吞吐量的场景,可以优化为在 invoke 中将数据批量缓存在内存,然后在 preCommit 时一次性批量发送 prepare 请求。但这会增加 Sink 实现的复杂性,需要处理好批量失败和重试的逻辑。

其次,commit 阶段的可靠性。如代码注释中所述,标准的 TwoPhaseCommitSinkFunctioncommit 失败后没有重试机制。这是 2PC 协议固有的“阻塞”问题。一旦参与者在 prepare 后、commit 前失联,事务状态会悬而不决。在真实项目中,对于 commit 失败,通常需要引入一个独立的补偿任务或死信队列,进行异步重试,直到最终成功为止。

最后,此方案高度依赖 Flink 的 Checkpoint 机制。Checkpoint 的频率、超时时间等参数需要根据业务的 RPO/RTO 要求和外部服务的性能进行精细调优。过高的 Checkpoint 频率会增加外部服务的压力,而过低的频率则会延长故障恢复的时间。


  目录