基于强化学习与图数据库构建 Kong 动态自适应限流插件的实践复盘


我们团队的 API 网关 Kong 面临的第一个棘手问题,源于一个经典的矛盾:业务增长带来的流量洪峰与后端服务脆弱的稳定性。最初采用的静态限流策略,例如 rate-limiting 插件配置的 60 requests/minute,在实践中显得极其僵化。它要么在流量低谷时浪费了服务能力,要么在双十一这类流量脉冲到来时,因为阈值过低而误伤大量正常用户,导致业务指标断崖式下跌。提高阈值?又会在遭遇恶意攻击或下游服务抖动时,无法有效保护系统,引发雪崩。

这是一个典型的控制论问题,而静态阈值是最原始的开环控制。在一次敏捷迭代的复盘会上,我们决定彻底废弃这种“一刀切”的方案。我们的目标是构建一个能够根据系统实时状态(如服务延迟、错误率、资源利用率)和流量模式(用户类型、请求路径)动态调整限流策略的智能系统。一个闭环的、自适应的控制器。

初步构想与技术选型

我们的初步构想是开发一个自定义的 Kong 插件。这个插件的核心将是一个决策引擎,它能实时决定是否允许某个请求通过。这个决策不再基于固定的计数器,而是基于一个更复杂的模型。

强化学习(Reinforcement Learning, RL)立刻进入了我们的视野。我们可以将限流问题建模为一个 RL 任务:

  • 环境 (Environment): 整个 API 调用链路,包括 Kong、后端服务和外部流量。
  • 智能体 (Agent): 我们的自定义插件。
  • 状态 (State): 描述环境当前状况的一组向量。例如:当前服务的 QPS、平均响应延迟、P99 延迟、5xx 错误率、CPU/内存使用率等。
  • 动作 (Action): Agent 可以采取的操作。例如:将当前窗口的限流阈值提高10%,降低10%,或保持不变。
  • 奖励 (Reward): Agent 执行一个动作后,环境反馈的标量值。我们希望最大化这个值。例如:成功处理一个请求奖励 +1,一个 5xx 错误惩罚 -100,响应延迟超过 SLO 惩罚 -10。

这个模型的核心是“状态”的表示。一个服务的状态不仅取决于其自身,还与它的上游依赖、下游调用者、甚至是部署在同一节点上的其他服务有关。这种复杂的、网状的依赖关系,用传统的关系型数据库或键值存储来描述会非常痛苦。这正是图数据库的用武之地。

最终的技术栈选型如下:

  1. Kong: 作为执行层,利用其强大的插件化架构。
  2. Python: 使用 kong-pdk 开发插件逻辑。Python 拥有成熟的机器学习生态(PyTorch, Stable Baselines3),便于我们实现 RL Agent。
  3. Dgraph: 作为状态存储中心。我们将整个微服务架构的拓扑、实时健康状况、以及服务间的调用关系建模成一个动态更新的图。Dgraph 的 GraphQL API 和高性能使其成为理想选择。
  4. 敏捷开发: 整个项目采用小步快跑的迭代方式。每个 Sprint 都致力于交付一个可测试、可集成的最小功能闭环。

架构设计

整个系统的核心数据流和控制流可以用下面的图来表示。

graph TD
    subgraph "请求处理链路 (Hot Path)"
        Client[客户端] --> Kong[Kong API 网关]
        Kong --> Plugin[RL 自适应限流插件]
        Plugin --> Dgraph_Read{读取服务状态图}
        Dgraph_Read --> Model[加载的RL模型]
        Model --> Decision{执行/拒绝}
        Decision -- 执行 --> Upstream[上游服务]
        Decision -- 拒绝 --> Client_429[返回 429 Too Many Requests]
        Upstream --> Kong
        Kong --> Client
    end

    subgraph "模型训练与状态更新 (Cold Path)"
        Prometheus[Prometheus] -- 指标 --> Metrics_Collector[指标收集器]
        Upstream_Logs[服务日志] -- 日志 --> Log_Processor[日志处理器]
        
        Metrics_Collector --> Dgraph_Write[更新Dgraph状态图]
        Log_Processor --> Dgraph_Write[更新Dgraph状态图]
        
        Dgraph_Write --> Dgraph[(Dgraph 图数据库)]

        subgraph "RL 训练循环"
            RL_Trainer[训练器] -- 批量读取 --> Dgraph
            RL_Trainer -- 训练 --> Trained_Model[训练好的模型]
            Trained_Model -- 推送 --> Model_Registry[模型仓库 S3/OSS]
        end

        Plugin_Updater[插件更新器] -- 拉取最新模型 --> Model_Registry
        Plugin_Updater -- 更新 --> Model
    end

    style Dgraph fill:#f9f,stroke:#333,stroke-width:2px
    style Plugin fill:#bbf,stroke:#333,stroke-width:2px

这个架构清晰地分离了热路径和冷路径。热路径上的插件逻辑必须是极致轻量和高效的,它只负责从 Dgraph 读取预计算好的状态,并用本地加载的模型进行快速推理。所有重计算,包括指标聚合和模型训练,都在冷路径上异步完成。

Sprint 1: Dgraph 状态图建模与服务发现

第一个迭代的目标是构建我们的“世界模型”——在 Dgraph 中描述我们的服务架构。我们不希望手动维护这个图,而是让它通过服务发现和监控数据自动生成和更新。

我们定义的 Dgraph Schema (schema.graphql) 如下:

# 定义服务节点
type Service {
    id: ID!
    name: String! @id @search(by: [hash])
    version: String! @search(by: [exact])
    ip: String!
    # 服务的实时健康状态
    health: HealthStatus @hasInverse(field: ofService)
    # 该服务依赖的其他服务
    dependsOn: [Service] @hasInverse(field: dependedBy)
    # 依赖该服务的其他服务
    dependedBy: [Service]
    # 该服务暴露的API路由
    routes: [Route] @hasInverse(field: forService)
}

# 定义API路由节点
type Route {
    id: ID!
    path: String! @id @search(by: [hash])
    # 路由关联的服务
    forService: Service!
    # 路由的实时流量指标
    traffic: TrafficMetrics @hasInverse(field: ofRoute)
}

# 服务的健康状态,每个时间点一条记录
type HealthStatus {
    id: ID!
    ofService: Service!
    timestamp: DateTime!
    cpuUsage: Float!
    memoryUsage: Float!
    latencyP99: Int! # in milliseconds
    errorRate5xx: Float! # 0.0 to 1.0
}

# 路由的流量指标
type TrafficMetrics {
    id: ID!
    ofRoute: Route!
    timestamp: DateTime!
    qps: Int!
    inboundBandwidth: Int! # in bps
}

接着,我们编写一个 Python 脚本,它定期从我们的服务注册中心(例如 Consul)和监控系统(Prometheus)拉取数据,然后通过 Dgraph 的 GraphQL API 更新图数据。

# data_updater.py
import os
import requests
import json
import time
from datetime import datetime, timezone

# 配置
DGRAPH_ENDPOINT = os.getenv("DGRAPH_ENDPOINT", "http://localhost:8080/graphql")
PROMETHEUS_ENDPOINT = os.getenv("PROMETHEUS_ENDPOINT", "http://localhost:9090")
SERVICE_REGISTRY_URL = os.getenv("SERVICE_REGISTRY_URL", "http://localhost:8500/v1/catalog/services")

def execute_gql_mutation(query, variables):
    """执行 Dgraph GraphQL 变更。"""
    headers = {"Content-Type": "application/json"}
    try:
        response = requests.post(DGRAPH_ENDPOINT, json={"query": query, "variables": variables}, headers=headers)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error executing GraphQL mutation: {e}")
        return None

def update_service_topology():
    """从服务注册中心同步服务及其依赖关系。"""
    try:
        services_response = requests.get(SERVICE_REGISTRY_URL)
        services_response.raise_for_status()
        services = services_response.json()

        for service_name, tags in services.items():
            # 简化处理,假设tags包含版本和依赖信息
            # 在真实项目中,这部分会更复杂
            version = next((t.split(':')[1] for t in tags if t.startswith('version:')), 'v1.0.0')
            
            # 使用 Dgraph 的 "upsert" 逻辑
            mutation = """
            mutation upsertService($name: String!, $version: String!) {
              addData(input: [{
                name: $name,
                version: $version,
                # ip等信息也应在此处更新
              }], upsert: true) {
                data {
                  id
                }
              }
            }
            """
            execute_gql_mutation(mutation, {"name": service_name, "version": version})
            print(f"Upserted service: {service_name}")

    except requests.exceptions.RequestException as e:
        print(f"Failed to fetch services from registry: {e}")


def update_health_metrics():
    """从 Prometheus 拉取健康指标并更新到 Dgraph。"""
    # 示例: 查询 user-service 的 P99 延迟
    query = 'histogram_quantile(0.99, sum(rate(http_server_requests_seconds_bucket{service="user-service"}[5m])) by (le))'
    try:
        response = requests.get(f"{PROMETHEUS_ENDPOINT}/api/v1/query", params={"query": query})
        response.raise_for_status()
        results = response.json()['data']['result']
        if not results:
            return

        latency_p99_ms = int(float(results[0]['value'][1]) * 1000)
        
        # 实际项目中,会遍历所有服务和所有指标
        service_name = "user-service"
        
        mutation = """
        mutation updateHealth($serviceName: String!, $healthData: HealthStatusInput!) {
          updateService(input: {
            filter: { name: { eq: $serviceName } },
            set: {
              health: $healthData
            }
          }) {
            service {
              name
            }
          }
        }
        """
        variables = {
            "serviceName": service_name,
            "healthData": {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "latencyP99": latency_p99_ms,
                "cpuUsage": 0.5, # 示例数据
                "errorRate5xx": 0.01, # 示例数据
                "memoryUsage": 0.6, # 示例数据
            }
        }
        execute_gql_mutation(mutation, variables)
        print(f"Updated health for {service_name}: P99 latency {latency_p99_ms}ms")

    except requests.exceptions.RequestException as e:
        print(f"Failed to fetch metrics from Prometheus: {e}")

if __name__ == "__main__":
    while True:
        print("--- Running update cycle ---")
        update_service_topology()
        update_health_metrics()
        time.sleep(15) # 每 15 秒更新一次

这个脚本作为后台守护进程运行,持续地将外部世界的状态同步到我们的 Dgraph “大脑”中。

Sprint 2: Kong Python 插件骨架与 Dgraph 集成

第二个迭代的核心是打通插件与 Dgraph 的通信。我们使用 Kong 官方提供的 kong-pdk 来编写 Python 插件。

文件结构:

kong-plugins/
└── rl-rate-limiter/
    ├── kong/plugins/rl-rate-limiter/
    │   ├── handler.py
    │   └── schema.lua
    └── rockspecs/
        └── kong-plugin-rl-rate-limiter-0.1.0-1.rockspec

schema.lua 定义了插件的配置项:

-- schema.lua
local typedefs = require "kong.db.schema.typedefs"

return {
  name = "rl-rate-limiter",
  fields = {
    { consumer = typedefs.no_consumer },
    { route = typedefs.no_route },
    { service = typedefs.no_service },
    {
      config = {
        type = "record",
        fields = {
          { dgraph_endpoint = typedefs.url { required = true } },
          { model_path = typedefs.string { required = true, default = "/usr/local/kong/models/agent.pth" } },
          { fallback_rate_limit = typedefs.integer { default = 100 } },
          { fallback_rate_window = typedefs.integer { default = 60 } },
        },
      },
    },
  },
}

handler.py 是插件的核心逻辑所在。在这个 Sprint,我们只实现从 Dgraph 读取状态的部分。

# handler.py
import os
import json
import torch
import requests
import numpy as np
from kong_pdk.plugin import Plugin

# 假设我们有一个预训练的模型和状态规范化器
# model = torch.load(config['model_path'])
# scaler = ...

class RlRateLimiter(Plugin):
    def __init__(self, config):
        super().__init__(config)
        self.dgraph_endpoint = self.config["dgraph_endpoint"]
        self.session = requests.Session() # 使用会话以提高性能
        
    def get_service_state(self, service_name):
        """从Dgraph查询服务的当前状态向量。"""
        query = """
        query getServiceState($name: String!) {
          queryService(filter: { name: { eq: $name } }) {
            name
            health(order: { desc: timestamp }, first: 1) {
              latencyP99
              errorRate5xx
              cpuUsage
              memoryUsage
            }
            routes(first: 1) {
              traffic(order: { desc: timestamp }, first: 1) {
                qps
              }
            }
          }
        }
        """
        try:
            response = self.session.post(
                self.dgraph_endpoint,
                json={"query": query, "variables": {"name": service_name}},
                timeout=0.1 # 必须设置严格的超时
            )
            response.raise_for_status()
            data = response.json().get("data", {}).get("queryService", [])
            
            if not data:
                return None
            
            service_data = data[0]
            health = service_data.get('health', [{}])[0]
            traffic = service_data.get('routes', [{}])[0].get('traffic', [{}])[0]
            
            # 构建状态向量,顺序必须与模型训练时一致
            state_vector = np.array([
                health.get('latencyP99', 500),
                health.get('errorRate5xx', 0.5),
                health.get('cpuUsage', 0.8),
                health.get('memoryUsage', 0.8),
                traffic.get('qps', 1000)
            ], dtype=np.float32)

            return state_vector

        except requests.exceptions.RequestException as e:
            # 日志记录错误
            # kong.log.err(f"Failed to query Dgraph: {e}")
            return None

    async def access(self, kong):
        route = await kong.request.get_route()
        service = route["service"]
        service_name = service["name"]
        
        state = self.get_service_state(service_name)
        
        if state is None:
            # Dgraph查询失败或服务不存在,应用备用静态限流
            # 这是保证系统韧性的关键
            # ... 此处省略静态限流逻辑 ...
            await kong.log.warn(f"Dgraph state unavailable for {service_name}, using fallback limit.")
            return

        # 在下一个Sprint中,我们将把state喂给模型
        # state_tensor = torch.from_numpy(state).unsqueeze(0)
        # action = model.predict(state_tensor)
        # ...
        
        await kong.log.info(f"Retrieved state for {service_name}: {state.tolist()}")

这里的关键点在于容错处理。如果 Dgraph 查询失败或超时,插件必须能够优雅地降级 (graceful degradation),切换到一个保守的、预设的静态限流策略。这是生产级代码和 “hello world” 示例的核心区别。

Sprint 3: 强化学习 Agent 的训练

这是最核心的部分。我们选择 stable-baselines3 库,它封装了多种成熟的 RL 算法。我们需要先定义一个符合 gym 接口的自定义环境,这个环境可以从 Dgraph 中保存的历史数据中采样,用于离线训练。

# training/environment.py
import gym
from gym import spaces
import numpy as np
import pandas as pd # 用于处理从Dgraph导出的历史数据

class RateLimitingEnv(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self, historical_data_path):
        super(RateLimitingEnv, self).__init__()
        
        # 加载历史数据
        self.df = pd.read_csv(historical_data_path)
        self.current_step = 0

        # 定义动作空间: 0=降低阈值, 1=保持, 2=提高阈值
        self.action_space = spaces.Discrete(3)
        
        # 定义状态空间: [latencyP99, errorRate5xx, cpuUsage, memoryUsage, qps]
        # 注意要进行归一化,所以是 Box(0, 1, ...)
        self.observation_space = spaces.Box(low=0, high=1, shape=(5,), dtype=np.float32)
        
        self.current_rate_limit_factor = 1.0 # 初始限流因子

    def _get_obs(self):
        # 从DataFrame获取当前时间步的状态,并进行归一化
        # 这里的归一化参数需要在整个数据集上计算
        raw_state = self.df.iloc[self.current_step][['latency', 'error_rate', 'cpu', 'mem', 'qps']].values
        # normalized_state = scaler.transform(raw_state)
        return raw_state.astype(np.float32)

    def reset(self):
        self.current_step = 0
        self.current_rate_limit_factor = 1.0
        return self._get_obs()

    def step(self, action):
        state = self._get_obs()
        qps = state[4]
        
        # 根据动作调整限流因子
        if action == 0:
            self.current_rate_limit_factor *= 0.9
        elif action == 2:
            self.current_rate_limit_factor *= 1.1
        
        # 假设我们有一个基线QPS
        baseline_qps = 100
        current_limit = baseline_qps * self.current_rate_limit_factor
        
        # 计算奖励
        reward = 0
        latency = state[0]
        error_rate = state[1]
        
        # 核心奖励函数设计
        if qps <= current_limit:
            # 成功处理的请求,奖励与QPS正相关
            reward += qps / 10
            # 对高延迟进行惩罚
            if latency > 200: # SLO=200ms
                reward -= (latency - 200) / 10
        else:
            # 超过限流,模拟拒绝请求,无奖励
            pass
            
        # 对5xx错误进行重罚
        reward -= error_rate * 1000
        
        self.current_step += 1
        done = self.current_step >= len(self.df) - 1
        
        return self._get_obs(), reward, done, {}

然后是训练脚本:

# training/train.py
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from environment import RateLimitingEnv

# 1. 准备数据
# 你需要一个脚本从Dgraph导出历史数据并存为CSV
historical_data_file = "dgraph_export.csv"

# 2. 创建环境
env = DummyVecEnv([lambda: RateLimitingEnv(historical_data_file)])

# 3. 选择模型并训练
# PPO (Proximal Policy Optimization) 是一种稳健的算法
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./ppo_ratelimit_tensorboard/")
# 训练10万个时间步
model.learn(total_timesteps=100000)

# 4. 保存模型
model.save("rl_rate_limiter_agent")
print("Model saved!")

训练过程会生成一个 rl_rate_limiter_agent.zip 文件,我们将它解压后得到的 .pth 文件部署到 Kong 插件可以访问的路径。

Sprint 4: 整合与部署

最后一个迭代是将训练好的模型集成到插件中,并完成端到端逻辑。

更新 handler.py:

# handler.py (final version)
import os
import json
import torch
import numpy as np
import requests
from kong_pdk.plugin import Plugin
from stable_baselines3 import PPO # 需要将stable-baselines3打包到插件环境中

# 全局加载模型,避免每次请求都加载
# 这是一个关键的性能优化
MODEL = None

def load_model(path):
    global MODEL
    if MODEL is None:
        try:
            # 注意: PPO.load() 返回一个完整的 PPO 对象,我们只需要策略网络
            # 在生产中,最好是只保存和加载策略网络的状态字典
            # model_policy = PPO.load(path).policy
            # For simplicity here, we assume a simple torch model is saved.
            # Let's mock a simple policy network
            class MockPolicy(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.net = torch.nn.Linear(5, 3)
                def forward(self, x):
                    return self.net(x)
                def predict(self, obs, deterministic=True):
                    obs = torch.from_numpy(obs).float()
                    logits = self.forward(obs)
                    if deterministic:
                        action = torch.argmax(logits, dim=1).numpy()
                    else:
                        # ...
                        pass
                    return action, None
            MODEL = MockPolicy()
            # In real scenario: MODEL.load_state_dict(torch.load(path))
        except Exception as e:
            # kong.log.err(...)
            print(f"Failed to load model: {e}")
            MODEL = None

class RlRateLimiter(Plugin):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.session = requests.Session()
        load_model(self.config['model_path'])

    # get_service_state(...) 方法同上

    async def access(self, kong):
        if MODEL is None:
            # 如果模型加载失败,必须降级
            await kong.log.err("RL Model is not loaded, rate limiting is disabled or in fallback mode.")
            return

        route = await kong.request.get_route()
        service = route["service"]
        service_name = service["name"]

        state = self.get_service_state(service_name)
        
        if state is None:
            # ... 降级逻辑 ...
            await kong.log.warn(f"State for {service_name} unavailable, using fallback.")
            # kong.response.exit(429, ...)
            return

        # 模型推理
        # 归一化状态 state_normalized = scaler.transform(state)
        action, _ = MODEL.predict(state.reshape(1, -1), deterministic=True)
        action = action[0] # predict returns a tuple

        # redis_key = f"rl_rate_limit:{service_name}"
        # 获取当前限流阈值
        # current_limit = kong.cache.get(redis_key)
        
        # 根据 action 调整并设置新的阈值 (这里简化了逻辑)
        # 实际需要一个原子操作来更新阈值
        if action == 0: # 降低
            # new_limit = max(current_limit * 0.9, config.min_limit)
            await kong.log.info(f"Action for {service_name}: DECREASE limit.")
        elif action == 2: # 提高
            # new_limit = min(current_limit * 1.1, config.max_limit)
            await kong.log.info(f"Action for {service_name}: INCREASE limit.")
        else: # 保持
            await kong.log.info(f"Action for {service_name}: MAINTAIN limit.")
            
        # kong.cache.set(redis_key, new_limit, ttl=10)
        
        # ... 后续再结合 Kong 自带的限流插件(或在插件内实现计数)来执行决策 ...
        # 这是一个复杂点:我们的插件是决策者,还需要一个执行者。
        # 最简单的实现是动态修改 rate-limiting 插件的配置,但这有性能问题。
        # 更好的方法是在本插件内直接使用 kong.shared.dict 或 redis 进行计数和限流。

这个版本展示了将模型推理集成到请求处理流程中的思路。在真实项目中,决策的执行(即如何精确地调整并实施限流)是另一个需要细致设计的环节,通常会借助 Redis 或 lua-resty-lock 来实现分布式环境下的原子计数和锁定。

局限性与未来迭代方向

经过几个敏捷迭代,我们成功上线了这套系统。它确实比静态限流策略表现得更加智能和有弹性,尤其是在应对突发流量和服务降级场景时。但这个方案并非银弹,它也引入了新的复杂性:

  1. 奖励函数设计的挑战: 奖励函数直接决定了 Agent 的行为。一个设计不佳的奖励函数可能导致 Agent 学会一些意想不到的、甚至有害的策略(例如,为了避免延迟惩罚而过早地、过度地限流)。
  2. 模型的泛化能力: 离线训练的模型可能无法很好地适应线上从未见过的流量模式。这要求我们建立一套持续学习 (Continual Learning) 的机制,定期用新的线上数据重新训练或微调模型。
  3. 状态表示的完备性: 我们目前的状态向量还比较简单。更复杂的模型可能需要考虑用户画像、API 的业务重要性、甚至是整个调用链路的端到端状态,这将进一步加大图建模和状态提取的难度。
  4. 可解释性: 强化学习模型通常像一个黑盒。当它做出一个非预期的限流决策时,我们很难快速定位原因。为 Agent 的决策提供可解释性报告,是未来重要的优化方向。

未来的迭代计划将围绕这些问题展开,例如探索在线学习(Online RL)框架,引入更复杂的图神经网络(GNN)来从 Dgraph 中自动学习状态表征,以及研究模型可解释性工具在这一场景的应用。


  目录