大模型接入和管理
深入理解 Dify 的 Model Runtime 架构,掌握如何统一管理数百个大模型
📖 内容概述
本文将深入分析 Dify 的 Model Runtime 模块,理解如何通过统一的抽象层管理 OpenAI、Anthropic、国产模型等数百个大模型,掌握模型调用流程、参数标准化、流式响应、错误重试等核心机制。
🎯 学习目标
- 理解 Model Runtime 的三层架构设计
- 掌握模型供应商(Provider)抽象
- 理解模型类型分类和能力定义
- 分析 OpenAI 和 Anthropic 的适配差异
- 掌握模型调用流程和参数标准化
- 理解流式响应的实现原理
- 学习错误处理和重试机制
- 掌握模型接入的最佳实践
📂 源码路径
api/core/model_runtime/
├── README.md # Model Runtime 说明文档 ⭐
├── __init__.py
│
├── entities/ # 实体定义
│ ├── model_entities.py # 模型实体
│ ├── message_entities.py # 消息实体
│ ├── llm_entities.py # LLM 相关实体
│ └── provider_entities.py # Provider 实体
│
├── model_providers/ # 模型供应商实现 ⭐
│ ├── __base/ # 基类
│ │ ├── ai_model.py # AI 模型基类
│ │ ├── large_language_model.py # LLM 基类 ⭐
│ │ ├── text_embedding_model.py # Embedding 基类
│ │ ├── rerank_model.py # Rerank 基类
│ │ ├── tts_model.py # TTS 基类
│ │ └── speech2text_model.py # STT 基类
│ │
│ └── model_provider_factory.py # 工厂类 ⭐
│
├── callbacks/ # 回调处理
│ ├── base_callback.py
│ └── logging_callback.py
│
├── errors/ # 错误定义
│ ├── invoke.py
│ └── validate.py
│
└── schema_validators/ # 配置验证
    ├── provider_credential_schema_validator.py
    └── model_credential_schema_validator.py
一、Model Runtime 架构设计
1.1 三层架构
Model Runtime 采用经典的三层架构,从上到下分别是:
┌─────────────────────────────────────────────────┐
│ Factory Layer(工厂层) │
│ │
│ ModelProviderFactory │
│ - 获取所有 Provider │
│ - 获取所有模型列表 │
│ - 获取 Provider 实例 │
│ - 验证 Provider/模型凭据 │
└────────────────┬────────────────────────────────┘
│
┌────────────────▼────────────────────────────────┐
│ Provider Layer(供应商层) │
│ │
│ 各个模型供应商的实现(通过插件加载) │
│ - OpenAI Provider │
│ - Anthropic Provider │
│ - 国产模型 Provider │
│ - 本地模型 Provider │
│ ... │
└────────────────┬────────────────────────────────┘
│
┌────────────────▼────────────────────────────────┐
│ Model Layer(模型层) │
│ │
│ 各类型模型的具体实现: │
│ - LargeLanguageModel(LLM) │
│ - TextEmbeddingModel(Embedding) │
│ - RerankModel(重排序) │
│ - TTSModel(文本转语音) │
│ - Speech2TextModel(语音转文本) │
│ - ModerationModel(内容审核) │
└─────────────────────────────────────────────────┘
设计优势:
- ✅ 统一接口:所有模型供应商使用相同的接口(见下方示意)
- ✅ 横向扩展:添加新供应商不影响现有代码
- ✅ 类型安全:通过抽象类强制实现必要方法
- ✅ 插件化:通过插件系统动态加载模型
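为了更直观地体现"统一接口"与"横向扩展"的价值,下面给出一段极简的示意代码:调用方的业务代码不随供应商变化,切换供应商时只需更换 provider 名称与凭据(示例中的模型名与凭据字段均为演示用的假设值):
python
# 极简示意:同一套调用代码适配不同的模型供应商(模型名与凭据为假设值)
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

def ask(provider: str, model: str, credentials: dict, question: str) -> str:
    """与具体供应商无关的调用封装"""
    factory = ModelProviderFactory(tenant_id="tenant_123")
    llm = factory.get_model_type_instance(provider=provider, model_type=ModelType.LLM)
    result = llm.invoke(
        model=model,
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content=question)],
        stream=False,
    )
    return result.message.content

# 新增或切换供应商时,这里的业务代码完全不需要修改
ask("openai", "gpt-4", {"openai_api_key": "sk-xxx"}, "你好")
ask("anthropic", "claude-3-5-sonnet", {"anthropic_api_key": "sk-ant-xxx"}, "你好")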
1.2 模型类型分类
Dify 支持 6 种模型类型,每种类型有不同的能力:
python
# api/core/model_runtime/entities/model_entities.py
from enum import Enum
class ModelType(Enum):
"""模型类型枚举"""
LLM = "llm" # 大语言模型
TEXT_EMBEDDING = "text-embedding" # 文本 Embedding
RERANK = "rerank" # 重排序
SPEECH2TEXT = "speech2text" # 语音转文本
TTS = "tts" # 文本转语音
    MODERATION = "moderation"          # 内容审核
各类型的能力:
| 模型类型 | 主要能力 | 典型模型 |
|---|---|---|
| LLM | 文本生成、对话、推理 | GPT-4, Claude, 文心一言 |
| Text Embedding | 文本向量化 | text-embedding-3, m3e |
| Rerank | 文档重排序 | Cohere Rerank, bge-reranker |
| Speech2Text | 语音识别 | Whisper, 讯飞语音 |
| TTS | 语音合成 | Azure TTS, 讯飞TTS |
| Moderation | 内容审核 | OpenAI Moderation |
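以 Text Embedding 为例,结合下文 1.3 的 ModelProviderFactory 即可获取并调用对应类型的模型实例。以下为简化示意,invoke 的参数与返回字段以实际基类定义为准:
python
# 简化示意:获取 Embedding 模型并向量化文本(参数与返回字段以实际基类为准)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

factory = ModelProviderFactory(tenant_id="tenant_123")
embedding_model = factory.get_model_type_instance(
    provider="openai",
    model_type=ModelType.TEXT_EMBEDDING,
)
result = embedding_model.invoke(
    model="text-embedding-3-small",
    credentials={"openai_api_key": "sk-xxx"},
    texts=["Dify 是什么?", "Model Runtime 的作用是什么?"],
)
print(len(result.embeddings))     # 向量条数,与输入文本一一对应
print(result.usage.total_tokens)  # 本次调用消耗的 Token 数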
1.3 工厂模式实现
ModelProviderFactory 是整个 Model Runtime 的入口:
python
# api/core/model_runtime/model_providers/model_provider_factory.py
class ModelProviderFactory:
"""
模型供应商工厂类
负责获取和管理所有模型供应商
"""
def __init__(self, tenant_id: str):
"""
初始化工厂
Args:
tenant_id: 租户 ID(多租户隔离)
"""
from core.plugin.impl.model import PluginModelClient
self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelClient()
def get_providers(self) -> Sequence[ProviderEntity]:
"""
获取所有模型供应商
Returns:
所有供应商的列表
"""
# 从插件服务器获取所有 Provider
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
"""
获取所有插件模型供应商
使用上下文缓存避免重复获取
"""
# 获取或初始化上下文
try:
contexts.plugin_model_providers.get()
except LookupError:
contexts.plugin_model_providers.set(None)
contexts.plugin_model_providers_lock.set(Lock())
# 使用锁保证线程安全
with contexts.plugin_model_providers_lock.get():
plugin_model_providers = contexts.plugin_model_providers.get()
if plugin_model_providers is not None:
return plugin_model_providers
# 从插件管理器获取 Provider
plugin_providers = self.plugin_model_manager.fetch_model_providers(
self.tenant_id
)
# 缓存结果
plugin_model_providers = []
for provider in plugin_providers:
provider.declaration.provider = (
provider.plugin_id + "/" + provider.declaration.provider
)
plugin_model_providers.append(provider)
contexts.plugin_model_providers.set(plugin_model_providers)
return plugin_model_providers
def get_model_type_instance(
self,
provider: str,
model_type: ModelType
) -> AIModel:
"""
获取模型类型实例
Args:
provider: 供应商名称(如 "openai")
model_type: 模型类型
Returns:
对应类型的模型实例
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(
provider
)
# 初始化参数
init_params = {
"tenant_id": self.tenant_id,
"plugin_id": plugin_id,
"provider_name": provider_name,
"plugin_model_provider": self.get_plugin_model_provider(provider),
}
# 根据类型创建实例
if model_type == ModelType.LLM:
return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK:
return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION:
return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS:
return TTSModel.model_validate(init_params)
def provider_credentials_validate(
self,
*,
provider: str,
credentials: dict
):
"""
验证供应商凭据
Args:
provider: 供应商名称
credentials: 凭据字典
Raises:
ValueError: 凭据验证失败
"""
# 获取供应商
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
# 获取凭据 Schema
provider_credential_schema = (
plugin_model_provider_entity.declaration.provider_credential_schema
)
if not provider_credential_schema:
raise ValueError(f"Provider {provider} does not have credential schema")
# 验证凭据格式
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)
# 调用插件验证凭据
self.plugin_model_manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=plugin_model_provider_entity.provider,
credentials=filtered_credentials,
)
        return filtered_credentials
1.4 使用示例
python
# 获取模型实例并调用
# 1. 创建工厂
factory = ModelProviderFactory(tenant_id="tenant_123")
# 2. 获取 LLM 实例
llm = factory.get_model_type_instance(
provider="openai",
model_type=ModelType.LLM
)
# 3. 调用模型
result = llm.invoke(
model="gpt-4",
credentials={
"openai_api_key": "sk-xxx"
},
prompt_messages=[
{
"role": "user",
"content": "你好"
}
],
stream=True
)
# 4. 处理流式响应
for chunk in result:
    print(chunk.delta.message.content)
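上面的示例演示的是流式调用;若传入 stream=False,invoke 会一次性返回 LLMResult,可以直接读取完整回复与用量信息(以下沿用上例中的 llm 实例):
python
# 非流式调用:一次性返回完整结果(沿用上例中的 llm 实例)
result = llm.invoke(
    model="gpt-4",
    credentials={"openai_api_key": "sk-xxx"},
    prompt_messages=[{"role": "user", "content": "你好"}],
    stream=False,
)
print(result.message.content)          # 完整回复文本
print(result.usage.prompt_tokens)      # 输入 Token 数
print(result.usage.completion_tokens)  # 输出 Token 数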
二、LLM 模型抽象
2.1 LargeLanguageModel 基类
所有 LLM 都必须继承这个基类并实现核心方法:
python
# api/core/model_runtime/model_providers/__base/large_language_model.py
from abc import abstractmethod
from collections.abc import Generator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage
from .ai_model import AIModel
class LargeLanguageModel(AIModel):
"""
大语言模型基类
所有 LLM 供应商必须继承此类
"""
model_type: ModelType = ModelType.LLM
def invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator[LLMResultChunk, None, None]:
"""
调用 LLM 模型
Args:
model: 模型名称(如 "gpt-4")
credentials: 凭据字典(如 API Key)
prompt_messages: Prompt 消息列表
model_parameters: 模型参数(温度、top_p 等)
tools: 工具列表(Function Calling)
stop: 停止词列表
stream: 是否流式返回
user: 用户标识
Returns:
非流式:LLMResult
流式:Generator[LLMResultChunk]
"""
# 1. 验证凭据
self._validate_credentials(model, credentials)
# 2. 标准化参数
model_parameters = self._standardize_parameters(model_parameters)
# 3. 调用具体实现
if stream:
return self._invoke_stream(
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
user
)
else:
return self._invoke(
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
user
)
@abstractmethod
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
) -> LLMResult:
"""
非流式调用(子类必须实现)
"""
raise NotImplementedError
@abstractmethod
def _invoke_stream(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
) -> Generator[LLMResultChunk, None, None]:
"""
流式调用(子类必须实现)
"""
raise NotImplementedError
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
"""
计算 Token 数量
用于费用估算和限流
"""
raise NotImplementedError
def validate_credentials(
self,
model: str,
credentials: dict
) -> None:
"""
验证凭据是否有效
通过实际调用模型来验证
"""
try:
self.invoke(
model=model,
credentials=credentials,
prompt_messages=[
UserPromptMessage(content="ping")
],
model_parameters={"max_tokens": 5},
stream=False,
)
except Exception as e:
            raise CredentialsValidateFailedError(str(e))
2.2 消息实体
Dify 定义了统一的消息格式:
python
# api/core/model_runtime/entities/message_entities.py
from enum import Enum
from pydantic import BaseModel
class PromptMessageRole(Enum):
"""消息角色"""
SYSTEM = "system" # 系统消息
USER = "user" # 用户消息
ASSISTANT = "assistant" # 助手消息
TOOL = "tool" # 工具消息
class PromptMessage(BaseModel):
"""Prompt 消息基类"""
role: PromptMessageRole
content: str | list
class SystemPromptMessage(PromptMessage):
"""系统消息"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
class UserPromptMessage(PromptMessage):
"""用户消息"""
role: PromptMessageRole = PromptMessageRole.USER
class ToolCallFunction(BaseModel):
    """工具调用的函数信息"""
    name: str
    arguments: str
class ToolCall(BaseModel):
    """工具调用"""
    id: str
    type: str = "function"
    function: ToolCallFunction
class AssistantPromptMessage(PromptMessage):
    """助手消息"""
    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    # Function Calling 相关(ToolCall 需先于本类定义,否则注解求值会报 NameError)
    tool_calls: list[ToolCall] = []
# 使用示例
messages = [
SystemPromptMessage(content="你是一个友好的助手"),
UserPromptMessage(content="你好"),
AssistantPromptMessage(content="你好!有什么我可以帮助你的吗?"),
UserPromptMessage(content="介绍一下北京"),
]
2.3 参数标准化
不同模型供应商的参数格式不同,需要标准化:
python
# 标准参数映射
STANDARD_PARAMETERS = {
"temperature": {
"type": "float",
"range": [0, 2],
"default": 1.0
},
"top_p": {
"type": "float",
"range": [0, 1],
"default": 1.0
},
"max_tokens": {
"type": "int",
"range": [1, None],
"default": None
},
"presence_penalty": {
"type": "float",
"range": [-2, 2],
"default": 0
},
"frequency_penalty": {
"type": "float",
"range": [-2, 2],
"default": 0
},
}
def _standardize_parameters(model_parameters: dict) -> dict:
"""
标准化模型参数
"""
standardized = {}
for key, value in model_parameters.items():
if key not in STANDARD_PARAMETERS:
continue
param_config = STANDARD_PARAMETERS[key]
# 类型转换
if param_config["type"] == "float":
value = float(value)
elif param_config["type"] == "int":
value = int(value)
# 范围检查
min_val, max_val = param_config["range"]
if min_val is not None and value < min_val:
value = min_val
if max_val is not None and value > max_val:
value = max_val
standardized[key] = value
    return standardized
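一个简单的效果示意:越界的参数会被收敛到合法范围,类型会被转换,未知参数则被直接丢弃:
python
# _standardize_parameters 的效果示意
params = _standardize_parameters({
    "temperature": 3.5,      # 超出 [0, 2],被截断为上限 2
    "max_tokens": "512",     # 字符串被转换为 int
    "custom_flag": True,     # 不在标准参数表中,被丢弃
})
print(params)  # {'temperature': 2, 'max_tokens': 512}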
三、OpenAI 适配实现
3.1 OpenAI LLM 实现(简化示例)
python
# 基于插件的实现思路
from openai import OpenAI
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class OpenAILargeLanguageModel(LargeLanguageModel):
"""
OpenAI LLM 实现
"""
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list | None = None,
stop: list[str] | None = None,
user: str | None = None,
) -> LLMResult:
"""
非流式调用 OpenAI
"""
# 1. 创建 OpenAI 客户端
client = OpenAI(api_key=credentials["openai_api_key"])
# 2. 转换消息格式
openai_messages = self._convert_messages(prompt_messages)
# 3. 调用 OpenAI API
response = client.chat.completions.create(
model=model,
messages=openai_messages,
temperature=model_parameters.get("temperature", 1.0),
top_p=model_parameters.get("top_p", 1.0),
max_tokens=model_parameters.get("max_tokens"),
stop=stop,
user=user,
tools=self._convert_tools(tools) if tools else None,
)
# 4. 转换响应格式
return self._convert_response(response)
def _invoke_stream(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list | None = None,
stop: list[str] | None = None,
user: str | None = None,
) -> Generator[LLMResultChunk, None, None]:
"""
流式调用 OpenAI
"""
# 1. 创建 OpenAI 客户端
client = OpenAI(api_key=credentials["openai_api_key"])
# 2. 转换消息格式
openai_messages = self._convert_messages(prompt_messages)
# 3. 流式调用 OpenAI API
stream = client.chat.completions.create(
model=model,
messages=openai_messages,
temperature=model_parameters.get("temperature", 1.0),
top_p=model_parameters.get("top_p", 1.0),
max_tokens=model_parameters.get("max_tokens"),
stop=stop,
user=user,
tools=self._convert_tools(tools) if tools else None,
stream=True, # 启用流式
)
# 4. 逐块返回
for chunk in stream:
yield self._convert_stream_chunk(chunk)
def _convert_messages(
self,
messages: list[PromptMessage]
) -> list[dict]:
"""
转换消息格式到 OpenAI 格式
"""
openai_messages = []
for message in messages:
openai_message = {
"role": message.role.value,
"content": message.content,
}
# 处理 Function Calling
if isinstance(message, AssistantPromptMessage) and message.tool_calls:
openai_message["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
}
}
for tc in message.tool_calls
]
openai_messages.append(openai_message)
return openai_messages
def _convert_response(self, response) -> LLMResult:
"""
转换 OpenAI 响应到统一格式
"""
choice = response.choices[0]
return LLMResult(
model=response.model,
prompt_messages=[], # 省略
message=AssistantPromptMessage(
content=choice.message.content or "",
tool_calls=self._extract_tool_calls(choice.message),
),
usage=LLMUsage(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
),
)
def _convert_stream_chunk(self, chunk) -> LLMResultChunk:
"""
转换流式响应块
"""
delta = chunk.choices[0].delta
return LLMResultChunk(
model=chunk.model,
delta=LLMResultChunkDelta(
index=chunk.choices[0].index,
message=AssistantPromptMessage(
content=delta.content or "",
tool_calls=self._extract_tool_calls(delta) if delta.tool_calls else [],
),
finish_reason=chunk.choices[0].finish_reason,
),
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list | None = None,
) -> int:
"""
计算 Token 数量
使用 tiktoken 库
"""
import tiktoken
# 获取对应的编码器
encoding = tiktoken.encoding_for_model(model)
# 计算消息 Token
num_tokens = 0
for message in prompt_messages:
# 每条消息的格式开销
num_tokens += 4 # <|start|>role<|message|>content<|end|>
# 内容 Token
num_tokens += len(encoding.encode(message.content))
# 工具 Token
if tools:
for tool in tools:
num_tokens += len(encoding.encode(str(tool)))
        return num_tokens
四、Anthropic 适配对比
4.1 与 OpenAI 的主要差异
| 维度 | OpenAI | Anthropic |
|---|---|---|
| API 库 | openai | anthropic |
| 消息格式 | messages | messages(类似但有差异) |
| 系统消息 | 在 messages 中 | 单独的 system 参数 |
| 流式响应 | stream=True | stream=True |
| Token 计数 | tiktoken | anthropic.count_tokens() |
| Function Calling | tools 参数 | tools 参数(格式不同) |
4.2 Anthropic 适配示例
python
# 核心差异点示例
from anthropic import Anthropic
class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
Anthropic (Claude) LLM 实现
"""
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
**kwargs
) -> LLMResult:
# 1. 创建客户端
client = Anthropic(api_key=credentials["anthropic_api_key"])
# 2. 提取系统消息(与 OpenAI 不同)
system_message = ""
user_messages = []
for msg in prompt_messages:
if msg.role == PromptMessageRole.SYSTEM:
system_message = msg.content
else:
user_messages.append(msg)
# 3. 转换消息格式
anthropic_messages = self._convert_messages(user_messages)
# 4. 调用 Anthropic API
response = client.messages.create(
model=model,
system=system_message, # 系统消息单独传递 ⭐
messages=anthropic_messages,
max_tokens=model_parameters.get("max_tokens", 1024),
temperature=model_parameters.get("temperature", 1.0),
)
return self._convert_response(response)
def _invoke_stream(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
**kwargs
) -> Generator[LLMResultChunk, None, None]:
client = Anthropic(api_key=credentials["anthropic_api_key"])
# 提取系统消息
system_message, user_messages = self._extract_system(prompt_messages)
# 流式调用
with client.messages.stream(
model=model,
system=system_message,
messages=self._convert_messages(user_messages),
max_tokens=model_parameters.get("max_tokens", 1024),
) as stream:
for event in stream:
if event.type == "content_block_delta":
yield LLMResultChunk(
model=model,
delta=LLMResultChunkDelta(
message=AssistantPromptMessage(
content=event.delta.text
),
),
)
def _convert_messages(
self,
messages: list[PromptMessage]
) -> list[dict]:
"""
转换到 Anthropic 格式
"""
anthropic_messages = []
for message in messages:
# Anthropic 的消息格式略有不同
anthropic_message = {
"role": "user" if message.role == PromptMessageRole.USER else "assistant",
"content": message.content,
}
anthropic_messages.append(anthropic_message)
        return anthropic_messages
五、国产模型适配
5.1 国产模型特点
| 模型 | 供应商 | 特点 |
|---|---|---|
| 文心一言 | 百度 | API 格式接近 OpenAI |
| 通义千问 | 阿里 | 支持多模态 |
| 智谱AI | 智谱 | ChatGLM 系列 |
| 讯飞星火 | 科大讯飞 | WebSocket 流式 |
| MiniMax | MiniMax | 角色扮演能力强 |
5.2 适配挑战
python
# 以讯飞星火为例,使用 WebSocket 而非 HTTP
import websocket
import json
class SparkLLM(LargeLanguageModel):
"""
讯飞星火大模型适配
特点:使用 WebSocket 协议
"""
def _invoke_stream(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
**kwargs
) -> Generator[LLMResultChunk, None, None]:
# 1. 构建 WebSocket URL(需要签名)
ws_url = self._build_ws_url(credentials)
# 2. 构建请求数据
request_data = {
"header": {
"app_id": credentials["app_id"],
"uid": "user123"
},
"parameter": {
"chat": {
"domain": model,
"temperature": model_parameters.get("temperature", 0.5),
"max_tokens": model_parameters.get("max_tokens", 2048),
}
},
"payload": {
"message": {
"text": self._convert_messages(prompt_messages)
}
}
}
# 3. 建立 WebSocket 连接
ws = websocket.create_connection(ws_url)
try:
# 4. 发送请求
ws.send(json.dumps(request_data))
# 5. 接收流式响应
while True:
response = ws.recv()
data = json.loads(response)
# 解析响应
if data["header"]["code"] != 0:
raise Exception(data["header"]["message"])
# 提取内容
content = data["payload"]["choices"]["text"][0]["content"]
yield LLMResultChunk(
model=model,
delta=LLMResultChunkDelta(
message=AssistantPromptMessage(content=content),
),
)
# 检查是否结束
if data["header"]["status"] == 2:
break
finally:
            ws.close()
六、流式响应实现
6.1 流式响应原理
Client Server LLM
│ │ │
│ 1. HTTP Request │ │
│──────────────────────> │ │
│ │ 2. Invoke LLM │
│ │───────────────────────>│
│ │ │
│ │ 3. Stream Chunk 1 │
│ 4. SSE Event 1 │<───────────────────────│
│<────────────────────── │ │
│ │ │
│ │ 5. Stream Chunk 2 │
│ 6. SSE Event 2 │<───────────────────────│
│<────────────────────── │ │
│ │ │
│ │ 7. Stream Chunk N │
│ 8. SSE Event N │<───────────────────────│
│<────────────────────── │ │
│ │ │
│ 9. Connection Close │ │
│<────────────────────── │ │
6.2 SSE (Server-Sent Events) 实现
python
# api/controllers/console/app/completion.py
import json
import time
from flask import Response, stream_with_context
@app.route('/completion-messages', methods=['POST'])
def completion():
"""
流式对话接口
"""
# 1. 调用 LLM(流式)
result_generator = llm.invoke(
model="gpt-4",
credentials=credentials,
prompt_messages=messages,
stream=True, # 启用流式
)
# 2. 定义 SSE 生成器
def generate():
"""SSE 事件生成器"""
try:
for chunk in result_generator:
# 构建 SSE 事件
event_data = {
"event": "message",
"conversation_id": conversation_id,
"message_id": message_id,
"answer": chunk.delta.message.content,
"created_at": int(time.time()),
}
# 发送 SSE 格式数据
yield f"data: {json.dumps(event_data)}\n\n"
# 发送结束事件
yield f"data: {json.dumps({'event': 'message_end'})}\n\n"
except Exception as e:
# 发送错误事件
error_data = {
"event": "error",
"message": str(e),
}
yield f"data: {json.dumps(error_data)}\n\n"
# 3. 返回 SSE 响应
return Response(
stream_with_context(generate()),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'X-Accel-Buffering': 'no', # 禁用 Nginx 缓冲
}
    )
前端接收 SSE:
typescript
// web/service/completion.ts(简化示意)
// 注意:EventSource 仅支持 GET 请求;Dify 实际前端使用基于 fetch 的流式读取来消费 POST 接口返回的 SSE,
// 此处用 EventSource 仅为演示事件的解析与处理逻辑
function streamCompletion(params: CompletionParams) {
const eventSource = new EventSource('/api/completion-messages')
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data)
switch (data.event) {
case 'message':
// 更新 UI,显示新内容
appendMessage(data.answer)
break
case 'message_end':
// 对话结束
eventSource.close()
onComplete()
break
case 'error':
// 处理错误
eventSource.close()
onError(data.message)
break
}
}
eventSource.onerror = () => {
eventSource.close()
onError('Connection failed')
}
return eventSource
}
七、错误处理和重试
7.1 错误类型定义
python
# api/core/model_runtime/errors/invoke.py
class InvokeError(Exception):
"""模型调用错误基类"""
pass
class InvokeAuthorizationError(InvokeError):
"""授权错误(API Key 无效)"""
description = "Invalid API key or insufficient permissions"
class InvokeRateLimitError(InvokeError):
"""速率限制错误"""
description = "Rate limit exceeded"
class InvokeServerUnavailableError(InvokeError):
"""服务器不可用"""
description = "Model server unavailable"
class InvokeConnectionError(InvokeError):
"""连接错误"""
description = "Failed to connect to model server"
class InvokeBadRequestError(InvokeError):
"""请求参数错误"""
    description = "Invalid request parameters"
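除了在每个实现里逐个 try/except,另一种常见做法是维护一张"供应商 SDK 异常 → 统一异常"的映射表,在基类中集中转换,调用方只需面对统一的 InvokeError 体系。下面是一个简化示意(以 OpenAI SDK 的异常为例,并非 Dify 源码原文):
python
# 简化示意:把供应商 SDK 的异常集中映射为统一的 InvokeError 子类
import openai

INVOKE_ERROR_MAPPING: dict[type[InvokeError], list[type[Exception]]] = {
    InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
    InvokeRateLimitError: [openai.RateLimitError],
    InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
    InvokeServerUnavailableError: [openai.InternalServerError],
    InvokeBadRequestError: [openai.BadRequestError],
}

def transform_invoke_error(error: Exception) -> InvokeError:
    """将原始异常转换为统一异常,调用方只需捕获 InvokeError 体系"""
    for invoke_error, source_errors in INVOKE_ERROR_MAPPING.items():
        if isinstance(error, tuple(source_errors)):
            return invoke_error(str(error))
    return InvokeError(str(error))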
7.2 重试策略
python
import logging
import time
from functools import wraps

import openai
from openai import OpenAI

logger = logging.getLogger(__name__)
def retry_on_error(
max_retries: int = 3,
delay: float = 1.0,
backoff: float = 2.0,
exceptions: tuple = (InvokeServerUnavailableError, InvokeConnectionError)
):
"""
错误重试装饰器
Args:
max_retries: 最大重试次数
delay: 初始延迟(秒)
backoff: 延迟倍数
exceptions: 需要重试的异常类型
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
current_delay = delay
for attempt in range(max_retries + 1):
try:
return func(*args, **kwargs)
except exceptions as e:
if attempt == max_retries:
# 达到最大重试次数,抛出异常
raise
# 记录重试日志
logger.warning(
f"Attempt {attempt + 1} failed: {e}. "
f"Retrying in {current_delay}s..."
)
# 等待后重试
time.sleep(current_delay)
current_delay *= backoff
except Exception as e:
# 不可重试的异常直接抛出
raise
return wrapper
return decorator
# 使用示例
class OpenAILLM(LargeLanguageModel):
@retry_on_error(max_retries=3, delay=1.0, backoff=2.0)
def _invoke(self, model, credentials, prompt_messages, **kwargs):
"""调用时自动重试"""
client = OpenAI(api_key=credentials["openai_api_key"])
try:
response = client.chat.completions.create(...)
return self._convert_response(response)
except openai.RateLimitError as e:
raise InvokeRateLimitError(str(e))
except openai.APIConnectionError as e:
raise InvokeConnectionError(str(e))
except openai.APIError as e:
            raise InvokeServerUnavailableError(str(e))
八、令牌计数和配额管理
8.1 Token 计数
python
# Token 计数对于成本控制至关重要
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list | None = None,
) -> int:
"""
计算 Token 数量
"""
    import json
    import tiktoken
# 获取模型对应的编码器
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
# 计算消息 Token
for message in prompt_messages:
# 消息格式开销
num_tokens += 4 # <|start|>role<|message|>content<|end|>
# 内容 Token
if isinstance(message.content, str):
num_tokens += len(encoding.encode(message.content))
elif isinstance(message.content, list):
# 多模态内容
for item in message.content:
if item.get("type") == "text":
num_tokens += len(encoding.encode(item["text"]))
elif item.get("type") == "image_url":
# 图片固定 Token 数
num_tokens += 85
# 计算工具 Token
if tools:
tools_text = json.dumps([tool.dict() for tool in tools])
num_tokens += len(encoding.encode(tools_text))
# 响应格式开销
num_tokens += 3 # <|start|>assistant<|message|>
    return num_tokens
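借助 get_num_tokens 可以在真正调用模型之前做"预检",例如判断是否超出上下文窗口,或粗略估算成本。以下为简化示意,上下文窗口与单价均为假设值,沿用前文示例中的 llm、credentials、messages:
python
# 简化示意:调用前用 Token 数做预检(窗口大小与单价均为假设值)
CONTEXT_WINDOW = 8192           # 假设的模型上下文窗口
PRICE_PER_1K_TOKENS = 0.03      # 假设的每千输入 Token 单价(美元)

prompt_tokens = llm.get_num_tokens(
    model="gpt-4",
    credentials=credentials,
    prompt_messages=messages,
)
if prompt_tokens > CONTEXT_WINDOW:
    raise ValueError(f"Prompt 过长:{prompt_tokens} tokens,超过上下文窗口 {CONTEXT_WINDOW}")

estimated_cost = prompt_tokens / 1000 * PRICE_PER_1K_TOKENS
print(f"预计输入成本约 ${estimated_cost:.4f}")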
8.2 配额管理
python
# api/services/model_provider_service.py
from datetime import datetime

class ModelProviderService:
@staticmethod
def check_quota(
tenant_id: str,
provider: str,
model_type: ModelType,
tokens: int
):
"""
检查配额
Args:
tenant_id: 租户 ID
provider: 供应商
model_type: 模型类型
tokens: 需要的 Token 数
Raises:
QuotaExceededError: 配额不足
"""
from extensions.ext_redis import redis_client
# 1. 获取租户配额配置
quota_key = f"quota:{tenant_id}:{provider}:{model_type.value}"
quota_config = redis_client.hgetall(quota_key)
if not quota_config:
# 没有配额限制
return
# 2. 检查月度配额
monthly_limit = int(quota_config.get("monthly_limit", 0))
if monthly_limit > 0:
current_month = datetime.now().strftime("%Y-%m")
usage_key = f"usage:{tenant_id}:{provider}:{current_month}"
current_usage = int(redis_client.get(usage_key) or 0)
if current_usage + tokens > monthly_limit:
raise QuotaExceededError(
f"Monthly quota exceeded. "
f"Limit: {monthly_limit}, "
f"Used: {current_usage}, "
f"Requested: {tokens}"
)
# 3. 检查速率限制(RPM - Requests Per Minute)
rpm_limit = int(quota_config.get("rpm_limit", 0))
if rpm_limit > 0:
rpm_key = f"rpm:{tenant_id}:{provider}:{datetime.now().minute}"
current_rpm = redis_client.incr(rpm_key)
redis_client.expire(rpm_key, 60) # 1分钟后过期
if current_rpm > rpm_limit:
raise RateLimitError(
f"Rate limit exceeded. "
f"Limit: {rpm_limit} requests/min"
)
@staticmethod
def record_usage(
tenant_id: str,
provider: str,
model: str,
tokens: int,
        cost: float | None = None
):
"""
记录使用量
"""
from extensions.ext_redis import redis_client
from extensions.ext_database import db
from models.provider import ProviderModelUsage
# 1. 更新 Redis 缓存(快速)
current_month = datetime.now().strftime("%Y-%m")
usage_key = f"usage:{tenant_id}:{provider}:{current_month}"
redis_client.incrby(usage_key, tokens)
redis_client.expire(usage_key, 60 * 60 * 24 * 35) # 35天过期
# 2. 记录到数据库(持久化)
usage = ProviderModelUsage(
tenant_id=tenant_id,
provider=provider,
model=model,
tokens=tokens,
cost=cost,
created_at=datetime.utcnow()
)
db.session.add(usage)
        db.session.commit()
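这两个方法通常围绕一次模型调用配合使用:调用前检查配额与速率限制,调用成功后按实际用量记账。以下为简化示意,沿用前文示例中的 llm、credentials、messages:
python
# 简化示意:配额检查与用量记录在一次调用中的位置
tokens_needed = llm.get_num_tokens(
    model="gpt-4",
    credentials=credentials,
    prompt_messages=messages,
)

# 1. 调用前:检查月度配额与 RPM 速率限制
ModelProviderService.check_quota(
    tenant_id="tenant_123",
    provider="openai",
    model_type=ModelType.LLM,
    tokens=tokens_needed,
)

# 2. 实际调用模型
result = llm.invoke(
    model="gpt-4",
    credentials=credentials,
    prompt_messages=messages,
    stream=False,
)

# 3. 调用后:按实际消耗记录用量
ModelProviderService.record_usage(
    tenant_id="tenant_123",
    provider="openai",
    model="gpt-4",
    tokens=result.usage.total_tokens,
)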
九、实践项目
项目 1:分析 OpenAI 适配实现
目标:理解 OpenAI 模型是如何被适配的
步骤:
- 阅读 large_language_model.py 基类
- 查看 OpenAI 的插件实现(通过插件系统)
- 追踪一次完整的模型调用
- 绘制调用时序图
项目 2:实现一个自定义模型适配
目标:为一个新的模型供应商创建适配器
要求:
python
class MyCustomLLM(LargeLanguageModel):
"""自定义 LLM 适配"""
def _invoke(self, model, credentials, prompt_messages, **kwargs):
# 实现非流式调用
pass
def _invoke_stream(self, model, credentials, prompt_messages, **kwargs):
# 实现流式调用
pass
def get_num_tokens(self, model, credentials, prompt_messages, **kwargs):
# 实现 Token 计数
        pass
项目 3:对比分析不同模型的适配差异
输出:
- OpenAI vs Anthropic 差异对比表
- 国产模型适配挑战总结
- 最佳实践建议
🎓 自测题
- Model Runtime 的三层架构分别是什么?各有什么作用?
- 为什么需要参数标准化?
- OpenAI 和 Anthropic 在系统消息处理上有什么区别?
- 流式响应相比非流式响应有什么优势?
- 如何实现错误重试机制?
- Token 计数在什么场景下很重要?
- 如何优化大模型的调用性能和成本?
✅ 小结
本文深入分析了 Dify 的 Model Runtime 模块:
关键要点:
- ✅ 三层架构:Factory - Provider - Model
- ✅ 统一接口:所有模型使用相同的调用方式
- ✅ 参数标准化:屏蔽不同供应商的差异
- ✅ 流式响应:通过 SSE 实现实时反馈
- ✅ 错误处理:重试机制和异常分类
- ✅ 配额管理:Token 计数和使用量控制
下一步:
学习进度: ⬜ 未开始 | 🚧 进行中 | ✅ 已完成