
大模型接入和管理

深入理解 Dify 的 Model Runtime 架构,掌握如何统一管理数百个大模型

📖 内容概述

本文将深入分析 Dify 的 Model Runtime 模块,理解如何通过统一的抽象层管理 OpenAI、Anthropic、国产模型等数百个大模型,掌握模型调用流程、参数标准化、流式响应、错误重试等核心机制。

🎯 学习目标

  • 理解 Model Runtime 的三层架构设计
  • 掌握模型供应商(Provider)抽象
  • 理解模型类型分类和能力定义
  • 分析 OpenAI 和 Anthropic 的适配差异
  • 掌握模型调用流程和参数标准化
  • 理解流式响应的实现原理
  • 学习错误处理和重试机制
  • 掌握模型接入的最佳实践

📂 源码路径

api/core/model_runtime/
├── README.md                    # Model Runtime 说明文档 ⭐
├── __init__.py

├── entities/                    # 实体定义
│   ├── model_entities.py        # 模型实体
│   ├── message_entities.py      # 消息实体
│   ├── llm_entities.py          # LLM 相关实体
│   └── provider_entities.py     # Provider 实体

├── model_providers/             # 模型供应商实现 ⭐
│   ├── __base/                  # 基类
│   │   ├── ai_model.py          # AI 模型基类
│   │   ├── large_language_model.py  # LLM 基类 ⭐
│   │   ├── text_embedding_model.py  # Embedding 基类
│   │   ├── rerank_model.py      # Rerank 基类
│   │   ├── tts_model.py         # TTS 基类
│   │   └── speech2text_model.py # STT 基类
│   │
│   └── model_provider_factory.py  # 工厂类 ⭐

├── callbacks/                   # 回调处理
│   ├── base_callback.py
│   └── logging_callback.py

├── errors/                      # 错误定义
│   ├── invoke.py
│   └── validate.py

└── schema_validators/           # 配置验证
    ├── provider_credential_schema_validator.py
    └── model_credential_schema_validator.py

一、Model Runtime 架构设计

1.1 三层架构

Model Runtime 采用经典的三层架构,从上到下分别是:

┌─────────────────────────────────────────────────┐
│              Factory Layer(工厂层)             │
│                                                 │
│  ModelProviderFactory                           │
│  - 获取所有 Provider                            │
│  - 获取所有模型列表                             │
│  - 获取 Provider 实例                           │
│  - 验证 Provider/模型凭据                      │
└────────────────┬────────────────────────────────┘

┌────────────────▼────────────────────────────────┐
│            Provider Layer(供应商层)           │
│                                                 │
│  各个模型供应商的实现(通过插件加载)           │
│  - OpenAI Provider                              │
│  - Anthropic Provider                           │
│  - 国产模型 Provider                            │
│  - 本地模型 Provider                            │
│  ...                                            │
└────────────────┬────────────────────────────────┘

┌────────────────▼────────────────────────────────┐
│              Model Layer(模型层)              │
│                                                 │
│  各类型模型的具体实现:                         │
│  - LargeLanguageModel(LLM)                   │
│  - TextEmbeddingModel(Embedding)             │
│  - RerankModel(重排序)                        │
│  - TTSModel(文本转语音)                       │
│  - Speech2TextModel(语音转文本)              │
│  - ModerationModel(内容审核)                 │
└─────────────────────────────────────────────────┘

设计优势

  • 统一接口:所有模型供应商使用相同的接口
  • 横向扩展:添加新供应商不影响现有代码
  • 类型安全:通过抽象类强制实现必要方法
  • 插件化:通过插件系统动态加载模型

1.2 模型类型分类

Dify 支持 6 种模型类型,每种类型有不同的能力:

python
# api/core/model_runtime/entities/model_entities.py
from enum import Enum

class ModelType(Enum):
    """模型类型枚举"""
    LLM = "llm"                      # 大语言模型
    TEXT_EMBEDDING = "text-embedding"  # 文本 Embedding
    RERANK = "rerank"                 # 重排序
    SPEECH2TEXT = "speech2text"       # 语音转文本
    TTS = "tts"                       # 文本转语音
    MODERATION = "moderation"         # 内容审核

各类型的能力

| 模型类型 | 主要能力 | 典型模型 |
| --- | --- | --- |
| LLM | 文本生成、对话、推理 | GPT-4, Claude, 文心一言 |
| Text Embedding | 文本向量化 | text-embedding-3, m3e |
| Rerank | 文档重排序 | Cohere Rerank, bge-reranker |
| Speech2Text | 语音识别 | Whisper, 讯飞语音 |
| TTS | 语音合成 | Azure TTS, 讯飞 TTS |
| Moderation | 内容审核 | OpenAI Moderation |

1.3 工厂模式实现

ModelProviderFactory 是整个 Model Runtime 的入口,负责获取和管理所有模型供应商:

python
# api/core/model_runtime/model_providers/model_provider_factory.py
class ModelProviderFactory:
    """
    模型供应商工厂类
    负责获取和管理所有模型供应商
    """
    
    def __init__(self, tenant_id: str):
        """
        初始化工厂
        
        Args:
            tenant_id: 租户 ID(多租户隔离)
        """
        from core.plugin.impl.model import PluginModelClient
        
        self.tenant_id = tenant_id
        self.plugin_model_manager = PluginModelClient()
    
    def get_providers(self) -> Sequence[ProviderEntity]:
        """
        获取所有模型供应商
        
        Returns:
            所有供应商的列表
        """
        # 从插件服务器获取所有 Provider
        plugin_providers = self.get_plugin_model_providers()
        return [provider.declaration for provider in plugin_providers]
    
    def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
        """
        获取所有插件模型供应商
        使用上下文缓存避免重复获取
        """
        # 获取或初始化上下文
        try:
            contexts.plugin_model_providers.get()
        except LookupError:
            contexts.plugin_model_providers.set(None)
            contexts.plugin_model_providers_lock.set(Lock())
        
        # 使用锁保证线程安全
        with contexts.plugin_model_providers_lock.get():
            plugin_model_providers = contexts.plugin_model_providers.get()
            if plugin_model_providers is not None:
                return plugin_model_providers
            
            # 从插件管理器获取 Provider
            plugin_providers = self.plugin_model_manager.fetch_model_providers(
                self.tenant_id
            )
            
            # 缓存结果
            plugin_model_providers = []
            for provider in plugin_providers:
                provider.declaration.provider = (
                    provider.plugin_id + "/" + provider.declaration.provider
                )
                plugin_model_providers.append(provider)
            
            contexts.plugin_model_providers.set(plugin_model_providers)
            return plugin_model_providers
    
    def get_model_type_instance(
        self, 
        provider: str, 
        model_type: ModelType
    ) -> AIModel:
        """
        获取模型类型实例
        
        Args:
            provider: 供应商名称(如 "openai")
            model_type: 模型类型
            
        Returns:
            对应类型的模型实例
        """
        plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(
            provider
        )
        
        # 初始化参数
        init_params = {
            "tenant_id": self.tenant_id,
            "plugin_id": plugin_id,
            "provider_name": provider_name,
            "plugin_model_provider": self.get_plugin_model_provider(provider),
        }
        
        # 根据类型创建实例
        if model_type == ModelType.LLM:
            return LargeLanguageModel.model_validate(init_params)
        elif model_type == ModelType.TEXT_EMBEDDING:
            return TextEmbeddingModel.model_validate(init_params)
        elif model_type == ModelType.RERANK:
            return RerankModel.model_validate(init_params)
        elif model_type == ModelType.SPEECH2TEXT:
            return Speech2TextModel.model_validate(init_params)
        elif model_type == ModelType.MODERATION:
            return ModerationModel.model_validate(init_params)
        elif model_type == ModelType.TTS:
            return TTSModel.model_validate(init_params)
    
    def provider_credentials_validate(
        self, 
        *, 
        provider: str, 
        credentials: dict
    ):
        """
        验证供应商凭据
        
        Args:
            provider: 供应商名称
            credentials: 凭据字典
            
        Raises:
            ValueError: 凭据验证失败
        """
        # 获取供应商
        plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
        
        # 获取凭据 Schema
        provider_credential_schema = (
            plugin_model_provider_entity.declaration.provider_credential_schema
        )
        if not provider_credential_schema:
            raise ValueError(f"Provider {provider} does not have credential schema")
        
        # 验证凭据格式
        validator = ProviderCredentialSchemaValidator(provider_credential_schema)
        filtered_credentials = validator.validate_and_filter(credentials)
        
        # 调用插件验证凭据
        self.plugin_model_manager.validate_provider_credentials(
            tenant_id=self.tenant_id,
            user_id="unknown",
            plugin_id=plugin_model_provider_entity.plugin_id,
            provider=plugin_model_provider_entity.provider,
            credentials=filtered_credentials,
        )
        
        return filtered_credentials
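
get_plugin_model_providers 中的缓存依赖 contexts 模块里基于 ContextVar 的上下文变量,其大致形态如下(最小示意,具体实现以 api/contexts 为准):

python
# api/contexts 的最小示意:用 ContextVar 在同一上下文内缓存 Provider 列表
from contextvars import ContextVar
from threading import Lock

# 未初始化时 get() 会抛出 LookupError,对应工厂代码中的 try/except 逻辑
plugin_model_providers: ContextVar = ContextVar("plugin_model_providers")

# 配套的锁,保证并发场景下只从插件服务拉取一次 Provider 列表
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")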

1.4 使用示例

python
# 获取模型实例并调用

# 1. 创建工厂
factory = ModelProviderFactory(tenant_id="tenant_123")

# 2. 获取 LLM 实例
llm = factory.get_model_type_instance(
    provider="openai",
    model_type=ModelType.LLM
)

# 3. 调用模型
result = llm.invoke(
    model="gpt-4",
    credentials={
        "openai_api_key": "sk-xxx"
    },
    prompt_messages=[
        UserPromptMessage(content="你好")  # 使用统一的消息实体,而非原始 dict
    ],
    stream=True
)

# 4. 处理流式响应
for chunk in result:
    print(chunk.delta.message.content)
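
同样的工厂和调用方式也适用于其他模型类型。下面以 Text Embedding 为例(invoke 的具体参数签名以 TextEmbeddingModel 基类为准,这里仅作示意):

python
# 获取 Embedding 模型实例并调用(示意)
embedding_model = factory.get_model_type_instance(
    provider="openai",
    model_type=ModelType.TEXT_EMBEDDING
)

# 假设 Embedding 基类的 invoke 接收 texts 列表,返回包含向量与用量的结果
result = embedding_model.invoke(
    model="text-embedding-3-small",
    credentials={"openai_api_key": "sk-xxx"},
    texts=["Dify 是什么?", "Model Runtime 如何管理模型?"],
)

print(len(result.embeddings))     # 向量数量,与输入文本一一对应
print(result.usage.total_tokens)  # 本次调用消耗的 Token 数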

二、LLM 模型抽象

2.1 LargeLanguageModel 基类

所有 LLM 都必须继承这个基类并实现核心方法:

python
# api/core/model_runtime/model_providers/__base/large_language_model.py
from abc import abstractmethod
from collections.abc import Generator

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError

from .ai_model import AIModel

class LargeLanguageModel(AIModel):
    """
    大语言模型基类
    所有 LLM 供应商必须继承此类
    """
    
    model_type: ModelType = ModelType.LLM
    
    def invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict | None = None,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
        """
        调用 LLM 模型
        
        Args:
            model: 模型名称(如 "gpt-4")
            credentials: 凭据字典(如 API Key)
            prompt_messages: Prompt 消息列表
            model_parameters: 模型参数(温度、top_p 等)
            tools: 工具列表(Function Calling)
            stop: 停止词列表
            stream: 是否流式返回
            user: 用户标识
            
        Returns:
            非流式:LLMResult
            流式:Generator[LLMResultChunk]
        """
        # 1. 验证凭据(轻量检查;完整验证见下方 validate_credentials)
        self._validate_credentials(model, credentials)
        
        # 2. 标准化参数(允许调用方不传参数)
        model_parameters = self._standardize_parameters(model_parameters or {})
        
        # 3. 调用具体实现
        if stream:
            return self._invoke_stream(
                model, 
                credentials, 
                prompt_messages,
                model_parameters, 
                tools, 
                stop, 
                user
            )
        else:
            return self._invoke(
                model, 
                credentials, 
                prompt_messages,
                model_parameters, 
                tools, 
                stop, 
                user
            )
    
    @abstractmethod
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        user: str | None = None,
    ) -> LLMResult:
        """
        非流式调用(子类必须实现)
        """
        raise NotImplementedError
    
    @abstractmethod
    def _invoke_stream(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        user: str | None = None,
    ) -> Generator[LLMResultChunk, None, None]:
        """
        流式调用(子类必须实现)
        """
        raise NotImplementedError
    
    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: list[PromptMessageTool] | None = None,
    ) -> int:
        """
        计算 Token 数量
        用于费用估算和限流
        """
        raise NotImplementedError
    
    def validate_credentials(
        self, 
        model: str, 
        credentials: dict
    ) -> None:
        """
        验证凭据是否有效
        通过实际调用模型来验证
        """
        try:
            self.invoke(
                model=model,
                credentials=credentials,
                prompt_messages=[
                    UserPromptMessage(content="ping")
                ],
                model_parameters={"max_tokens": 5},
                stream=False,
            )
        except Exception as e:
            raise CredentialsValidateFailedError(str(e))

2.2 消息实体

Dify 定义了统一的消息格式:

python
# api/core/model_runtime/entities/message_entities.py
from enum import Enum
from pydantic import BaseModel

class PromptMessageRole(Enum):
    """消息角色"""
    SYSTEM = "system"      # 系统消息
    USER = "user"          # 用户消息
    ASSISTANT = "assistant"  # 助手消息
    TOOL = "tool"          # 工具消息


class PromptMessage(BaseModel):
    """Prompt 消息基类"""
    role: PromptMessageRole
    content: str | list


class SystemPromptMessage(PromptMessage):
    """系统消息"""
    role: PromptMessageRole = PromptMessageRole.SYSTEM


class UserPromptMessage(PromptMessage):
    """用户消息"""
    role: PromptMessageRole = PromptMessageRole.USER


class AssistantPromptMessage(PromptMessage):
    """助手消息"""
    
    class ToolCall(BaseModel):
        """工具调用"""
        
        class ToolCallFunction(BaseModel):
            name: str
            arguments: str
        
        id: str
        type: str = "function"
        function: ToolCallFunction
    
    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    
    # Function Calling 相关
    tool_calls: list[ToolCall] = []


# 使用示例
messages = [
    SystemPromptMessage(content="你是一个友好的助手"),
    UserPromptMessage(content="你好"),
    AssistantPromptMessage(content="你好!有什么我可以帮助你的吗?"),
    UserPromptMessage(content="介绍一下北京"),
]
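
当模型发起 Function Calling 时,助手消息会携带 tool_calls 字段,构造方式如下(工具名与参数仅为演示数据):

python
# 构造一条带工具调用的助手消息(示例数据)
assistant_message = AssistantPromptMessage(
    content="",  # 发起工具调用时 content 通常为空
    tool_calls=[
        AssistantPromptMessage.ToolCall(
            id="call_001",
            type="function",
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name="get_weather",
                arguments='{"city": "北京"}',  # 参数是 JSON 字符串
            ),
        )
    ],
)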

2.3 参数标准化

不同模型供应商的参数格式不同,需要标准化:

python
# 标准参数映射
STANDARD_PARAMETERS = {
    "temperature": {
        "type": "float",
        "range": [0, 2],
        "default": 1.0
    },
    "top_p": {
        "type": "float",
        "range": [0, 1],
        "default": 1.0
    },
    "max_tokens": {
        "type": "int",
        "range": [1, None],
        "default": None
    },
    "presence_penalty": {
        "type": "float",
        "range": [-2, 2],
        "default": 0
    },
    "frequency_penalty": {
        "type": "float",
        "range": [-2, 2],
        "default": 0
    },
}

def _standardize_parameters(model_parameters: dict) -> dict:
    """
    标准化模型参数
    """
    standardized = {}
    
    for key, value in model_parameters.items():
        if key not in STANDARD_PARAMETERS:
            continue
        
        param_config = STANDARD_PARAMETERS[key]
        
        # 类型转换
        if param_config["type"] == "float":
            value = float(value)
        elif param_config["type"] == "int":
            value = int(value)
        
        # 范围检查
        min_val, max_val = param_config["range"]
        if min_val is not None and value < min_val:
            value = min_val
        if max_val is not None and value > max_val:
            value = max_val
        
        standardized[key] = value
    
    return standardized
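
下面用一组越界/未知参数演示标准化效果:

python
# 标准化示例:类型转换 + 范围收敛 + 过滤未知参数
raw_parameters = {
    "temperature": "2.5",      # 字符串被转成 float,并收敛到上限 2
    "top_p": 0.9,
    "max_tokens": 100000,      # 无上限,保留原值
    "custom_param": "ignored", # 不在标准参数表中,被丢弃
}

print(_standardize_parameters(raw_parameters))
# {'temperature': 2.0, 'top_p': 0.9, 'max_tokens': 100000}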

三、OpenAI 适配实现

3.1 OpenAI LLM 实现(简化示例)

python
# 基于插件的实现思路(简化示例)
from collections.abc import Generator

from openai import OpenAI

from core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
)
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

class OpenAILargeLanguageModel(LargeLanguageModel):
    """
    OpenAI LLM 实现
    """
    
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list | None = None,
        stop: list[str] | None = None,
        user: str | None = None,
    ) -> LLMResult:
        """
        非流式调用 OpenAI
        """
        # 1. 创建 OpenAI 客户端
        client = OpenAI(api_key=credentials["openai_api_key"])
        
        # 2. 转换消息格式
        openai_messages = self._convert_messages(prompt_messages)
        
        # 3. 调用 OpenAI API
        response = client.chat.completions.create(
            model=model,
            messages=openai_messages,
            temperature=model_parameters.get("temperature", 1.0),
            top_p=model_parameters.get("top_p", 1.0),
            max_tokens=model_parameters.get("max_tokens"),
            stop=stop,
            user=user,
            tools=self._convert_tools(tools) if tools else None,
        )
        
        # 4. 转换响应格式
        return self._convert_response(response)
    
    def _invoke_stream(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list | None = None,
        stop: list[str] | None = None,
        user: str | None = None,
    ) -> Generator[LLMResultChunk, None, None]:
        """
        流式调用 OpenAI
        """
        # 1. 创建 OpenAI 客户端
        client = OpenAI(api_key=credentials["openai_api_key"])
        
        # 2. 转换消息格式
        openai_messages = self._convert_messages(prompt_messages)
        
        # 3. 流式调用 OpenAI API
        stream = client.chat.completions.create(
            model=model,
            messages=openai_messages,
            temperature=model_parameters.get("temperature", 1.0),
            top_p=model_parameters.get("top_p", 1.0),
            max_tokens=model_parameters.get("max_tokens"),
            stop=stop,
            user=user,
            tools=self._convert_tools(tools) if tools else None,
            stream=True,  # 启用流式
        )
        
        # 4. 逐块返回
        for chunk in stream:
            yield self._convert_stream_chunk(chunk)
    
    def _convert_messages(
        self, 
        messages: list[PromptMessage]
    ) -> list[dict]:
        """
        转换消息格式到 OpenAI 格式
        """
        openai_messages = []
        
        for message in messages:
            openai_message = {
                "role": message.role.value,
                "content": message.content,
            }
            
            # 处理 Function Calling
            if isinstance(message, AssistantPromptMessage) and message.tool_calls:
                openai_message["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        }
                    }
                    for tc in message.tool_calls
                ]
            
            openai_messages.append(openai_message)
        
        return openai_messages
    
    def _convert_response(self, response) -> LLMResult:
        """
        转换 OpenAI 响应到统一格式
        """
        choice = response.choices[0]
        
        return LLMResult(
            model=response.model,
            prompt_messages=[],  # 省略
            message=AssistantPromptMessage(
                content=choice.message.content or "",
                tool_calls=self._extract_tool_calls(choice.message),
            ),
            usage=LLMUsage(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
            ),
        )
    
    def _convert_stream_chunk(self, chunk) -> LLMResultChunk:
        """
        转换流式响应块
        """
        delta = chunk.choices[0].delta
        
        return LLMResultChunk(
            model=chunk.model,
            delta=LLMResultChunkDelta(
                index=chunk.choices[0].index,
                message=AssistantPromptMessage(
                    content=delta.content or "",
                    tool_calls=self._extract_tool_calls(delta) if delta.tool_calls else [],
                ),
                finish_reason=chunk.choices[0].finish_reason,
            ),
        )
    
    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: list | None = None,
    ) -> int:
        """
        计算 Token 数量
        使用 tiktoken 库
        """
        import tiktoken
        
        # 获取对应的编码器
        encoding = tiktoken.encoding_for_model(model)
        
        # 计算消息 Token
        num_tokens = 0
        for message in prompt_messages:
            # 每条消息的格式开销
            num_tokens += 4  # <|start|>role<|message|>content<|end|>
            
            # 内容 Token
            num_tokens += len(encoding.encode(message.content))
        
        # 工具 Token
        if tools:
            for tool in tools:
                num_tokens += len(encoding.encode(str(tool)))
        
        return num_tokens
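
上面多处引用的 _convert_tools 负责把统一的工具定义转换成 OpenAI 的 tools 格式,一个简化的实现思路如下(假设 PromptMessageTool 含 name、description、parameters 字段):

python
    def _convert_tools(self, tools: list) -> list[dict]:
        """
        将统一的工具定义转换为 OpenAI function calling 的 tools 格式(简化示意)
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,  # JSON Schema 格式的参数定义
                },
            }
            for tool in tools
        ]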

四、Anthropic 适配对比

4.1 与 OpenAI 的主要差异

| 维度 | OpenAI | Anthropic |
| --- | --- | --- |
| API 库 | openai | anthropic |
| 消息格式 | messages | messages(类似但有差异) |
| 系统消息 | 在 messages 中 | 单独的 system 参数 |
| 流式响应 | stream=True | stream=True |
| Token 计数 | tiktoken | anthropic.count_tokens() |
| Function Calling | tools 参数 | tools 参数(格式不同) |

4.2 Anthropic 适配示例

python
# 核心差异点示例
from anthropic import Anthropic

class AnthropicLargeLanguageModel(LargeLanguageModel):
    """
    Anthropic (Claude) LLM 实现
    """
    
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        **kwargs
    ) -> LLMResult:
        # 1. 创建客户端
        client = Anthropic(api_key=credentials["anthropic_api_key"])
        
        # 2. 提取系统消息(与 OpenAI 不同)
        system_message = ""
        user_messages = []
        for msg in prompt_messages:
            if msg.role == PromptMessageRole.SYSTEM:
                system_message = msg.content
            else:
                user_messages.append(msg)
        
        # 3. 转换消息格式
        anthropic_messages = self._convert_messages(user_messages)
        
        # 4. 调用 Anthropic API
        response = client.messages.create(
            model=model,
            system=system_message,  # 系统消息单独传递 ⭐
            messages=anthropic_messages,
            max_tokens=model_parameters.get("max_tokens", 1024),
            temperature=model_parameters.get("temperature", 1.0),
        )
        
        return self._convert_response(response)
    
    def _invoke_stream(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        **kwargs
    ) -> Generator[LLMResultChunk, None, None]:
        client = Anthropic(api_key=credentials["anthropic_api_key"])
        
        # 提取系统消息
        system_message, user_messages = self._extract_system(prompt_messages)
        
        # 流式调用
        with client.messages.stream(
            model=model,
            system=system_message,
            messages=self._convert_messages(user_messages),
            max_tokens=model_parameters.get("max_tokens", 1024),
        ) as stream:
            for event in stream:
                if event.type == "content_block_delta":
                    yield LLMResultChunk(
                        model=model,
                        delta=LLMResultChunkDelta(
                            message=AssistantPromptMessage(
                                content=event.delta.text
                            ),
                        ),
                    )
    
    def _convert_messages(
        self, 
        messages: list[PromptMessage]
    ) -> list[dict]:
        """
        转换到 Anthropic 格式
        """
        anthropic_messages = []
        
        for message in messages:
            # Anthropic 的消息格式略有不同
            anthropic_message = {
                "role": "user" if message.role == PromptMessageRole.USER else "assistant",
                "content": message.content,
            }
            anthropic_messages.append(anthropic_message)
        
        return anthropic_messages
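
流式实现中引用的 _extract_system 与非流式代码里手动拆分系统消息的逻辑一致,可以抽成如下辅助方法(示意):

python
    def _extract_system(
        self,
        prompt_messages: list[PromptMessage],
    ) -> tuple[str, list[PromptMessage]]:
        """
        拆分系统消息与其余消息:Anthropic 要求 system 单独传参
        """
        system_message = ""
        other_messages = []
        for msg in prompt_messages:
            if msg.role == PromptMessageRole.SYSTEM:
                system_message = msg.content
            else:
                other_messages.append(msg)
        return system_message, other_messages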

五、国产模型适配

5.1 国产模型特点

| 模型 | 供应商 | 特点 |
| --- | --- | --- |
| 文心一言 | 百度 | API 格式接近 OpenAI |
| 通义千问 | 阿里 | 支持多模态 |
| 智谱AI | 智谱 | ChatGLM 系列 |
| 讯飞星火 | 科大讯飞 | WebSocket 流式 |
| MiniMax | MiniMax | 角色扮演能力强 |

5.2 适配挑战

python
# 以讯飞星火为例,使用 WebSocket 而非 HTTP
import websocket
import json

class SparkLLM(LargeLanguageModel):
    """
    讯飞星火大模型适配
    特点:使用 WebSocket 协议
    """
    
    def _invoke_stream(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        **kwargs
    ) -> Generator[LLMResultChunk, None, None]:
        # 1. 构建 WebSocket URL(需要签名)
        ws_url = self._build_ws_url(credentials)
        
        # 2. 构建请求数据
        request_data = {
            "header": {
                "app_id": credentials["app_id"],
                "uid": "user123"
            },
            "parameter": {
                "chat": {
                    "domain": model,
                    "temperature": model_parameters.get("temperature", 0.5),
                    "max_tokens": model_parameters.get("max_tokens", 2048),
                }
            },
            "payload": {
                "message": {
                    "text": self._convert_messages(prompt_messages)
                }
            }
        }
        
        # 3. 建立 WebSocket 连接
        ws = websocket.create_connection(ws_url)
        
        try:
            # 4. 发送请求
            ws.send(json.dumps(request_data))
            
            # 5. 接收流式响应
            while True:
                response = ws.recv()
                data = json.loads(response)
                
                # 解析响应
                if data["header"]["code"] != 0:
                    raise Exception(data["header"]["message"])
                
                # 提取内容
                content = data["payload"]["choices"]["text"][0]["content"]
                yield LLMResultChunk(
                    model=model,
                    delta=LLMResultChunkDelta(
                        message=AssistantPromptMessage(content=content),
                    ),
                )
                
                # 检查是否结束
                if data["header"]["status"] == 2:
                    break
        finally:
            ws.close()
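
其中 _convert_messages 需要把统一消息转换成星火 payload.message.text 要求的数组格式,大致如下(字段名以讯飞官方文档为准):

python
    def _convert_messages(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        转换为星火 payload.message.text 要求的格式(示意)
        """
        role_map = {
            PromptMessageRole.SYSTEM: "system",
            PromptMessageRole.USER: "user",
            PromptMessageRole.ASSISTANT: "assistant",
        }
        return [
            {"role": role_map.get(msg.role, "user"), "content": msg.content}
            for msg in prompt_messages
        ]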

六、流式响应实现

6.1 流式响应原理

Client                    Server                    LLM
  │                         │                        │
  │  1. HTTP Request       │                        │
  │──────────────────────> │                        │
  │                         │  2. Invoke LLM        │
  │                         │───────────────────────>│
  │                         │                        │
  │                         │  3. Stream Chunk 1    │
  │  4. SSE Event 1        │<───────────────────────│
  │<──────────────────────  │                        │
  │                         │                        │
  │                         │  5. Stream Chunk 2    │
  │  6. SSE Event 2        │<───────────────────────│
  │<──────────────────────  │                        │
  │                         │                        │
  │                         │  7. Stream Chunk N    │
  │  8. SSE Event N        │<───────────────────────│
  │<──────────────────────  │                        │
  │                         │                        │
  │  9. Connection Close   │                        │
  │<──────────────────────  │                        │

6.2 SSE (Server-Sent Events) 实现

python
# api/controllers/console/app/completion.py
import json
import time

from flask import Response, stream_with_context

@app.route('/completion-messages', methods=['POST'])
def completion():
    """
    流式对话接口
    """
    # 1. 调用 LLM(流式)
    result_generator = llm.invoke(
        model="gpt-4",
        credentials=credentials,
        prompt_messages=messages,
        stream=True,  # 启用流式
    )
    
    # 2. 定义 SSE 生成器
    def generate():
        """SSE 事件生成器"""
        try:
            for chunk in result_generator:
                # 构建 SSE 事件
                event_data = {
                    "event": "message",
                    "conversation_id": conversation_id,
                    "message_id": message_id,
                    "answer": chunk.delta.message.content,
                    "created_at": int(time.time()),
                }
                
                # 发送 SSE 格式数据
                yield f"data: {json.dumps(event_data)}\n\n"
            
            # 发送结束事件
            yield f"data: {json.dumps({'event': 'message_end'})}\n\n"
            
        except Exception as e:
            # 发送错误事件
            error_data = {
                "event": "error",
                "message": str(e),
            }
            yield f"data: {json.dumps(error_data)}\n\n"
    
    # 3. 返回 SSE 响应
    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',  # 禁用 Nginx 缓冲
        }
    )

前端接收 SSE

typescript
// web/service/completion.ts
function streamCompletion(params: CompletionParams) {
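  // 注:EventSource 仅支持 GET 请求,此处仅为流程示意;
  // 实际处理 POST 流式接口时,通常改用 fetch + ReadableStream 手动解析 SSE 数据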
  const eventSource = new EventSource('/api/completion-messages')
  
  eventSource.onmessage = (event) => {
    const data = JSON.parse(event.data)
    
    switch (data.event) {
      case 'message':
        // 更新 UI,显示新内容
        appendMessage(data.answer)
        break
      
      case 'message_end':
        // 对话结束
        eventSource.close()
        onComplete()
        break
      
      case 'error':
        // 处理错误
        eventSource.close()
        onError(data.message)
        break
    }
  }
  
  eventSource.onerror = () => {
    eventSource.close()
    onError('Connection failed')
  }
  
  return eventSource
}

七、错误处理和重试

7.1 错误类型定义

python
# api/core/model_runtime/errors/invoke.py
class InvokeError(Exception):
    """模型调用错误基类"""
    pass


class InvokeAuthorizationError(InvokeError):
    """授权错误(API Key 无效)"""
    description = "Invalid API key or insufficient permissions"


class InvokeRateLimitError(InvokeError):
    """速率限制错误"""
    description = "Rate limit exceeded"


class InvokeServerUnavailableError(InvokeError):
    """服务器不可用"""
    description = "Model server unavailable"


class InvokeConnectionError(InvokeError):
    """连接错误"""
    description = "Failed to connect to model server"


class InvokeBadRequestError(InvokeError):
    """请求参数错误"""
    description = "Invalid request parameters"

7.2 重试策略

python
import logging
import time
from functools import wraps

logger = logging.getLogger(__name__)

def retry_on_error(
    max_retries: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: tuple = (InvokeServerUnavailableError, InvokeConnectionError)
):
    """
    错误重试装饰器
    
    Args:
        max_retries: 最大重试次数
        delay: 初始延迟(秒)
        backoff: 延迟倍数
        exceptions: 需要重试的异常类型
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            current_delay = delay
            
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                    
                except exceptions as e:
                    if attempt == max_retries:
                        # 达到最大重试次数,抛出异常
                        raise
                    
                    # 记录重试日志
                    logger.warning(
                        f"Attempt {attempt + 1} failed: {e}. "
                        f"Retrying in {current_delay}s..."
                    )
                    
                    # 等待后重试
                    time.sleep(current_delay)
                    current_delay *= backoff
                    
                except Exception as e:
                    # 不可重试的异常直接抛出
                    raise
        
        return wrapper
    return decorator


# 使用示例
class OpenAILLM(LargeLanguageModel):
    
    @retry_on_error(max_retries=3, delay=1.0, backoff=2.0)
    def _invoke(self, model, credentials, prompt_messages, **kwargs):
        """调用时自动重试"""
        client = OpenAI(api_key=credentials["openai_api_key"])
        
        try:
            response = client.chat.completions.create(...)
            return self._convert_response(response)
            
        except openai.RateLimitError as e:
            raise InvokeRateLimitError(str(e))
        except openai.APIConnectionError as e:
            raise InvokeConnectionError(str(e))
        except openai.APIError as e:
            raise InvokeServerUnavailableError(str(e))

八、令牌计数和配额管理

8.1 Token 计数

python
# Token 计数对于成本控制至关重要
def get_num_tokens(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    tools: list | None = None,
) -> int:
    """
    计算 Token 数量
    """
    import json
    import tiktoken
    
    # 获取模型对应的编码器
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    
    num_tokens = 0
    
    # 计算消息 Token
    for message in prompt_messages:
        # 消息格式开销
        num_tokens += 4  # <|start|>role<|message|>content<|end|>
        
        # 内容 Token
        if isinstance(message.content, str):
            num_tokens += len(encoding.encode(message.content))
        elif isinstance(message.content, list):
            # 多模态内容
            for item in message.content:
                if item.get("type") == "text":
                    num_tokens += len(encoding.encode(item["text"]))
                elif item.get("type") == "image_url":
                    # 低分辨率图片的基础 Token 开销(高分辨率图片需按分块另行计算)
                    num_tokens += 85
    
    # 计算工具 Token
    if tools:
        tools_text = json.dumps([tool.dict() for tool in tools])
        num_tokens += len(encoding.encode(tools_text))
    
    # 响应格式开销
    num_tokens += 3  # <|start|>assistant<|message|>
    
    return num_tokens
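
拿到 Token 数后即可做费用估算,例如下面的示意(单价为假设值,实际以各供应商计费标准为准):

python
# 费用估算示例:单价为假设值,仅演示计算方式
HYPOTHETICAL_PRICE_PER_1K = {
    "prompt": 0.01,      # 假设:每 1K 输入 Token 0.01 美元
    "completion": 0.03,  # 假设:每 1K 输出 Token 0.03 美元
}

def estimate_cost(prompt_tokens: int, completion_tokens: int) -> float:
    """根据 Token 数估算单次调用成本(示意)"""
    return (
        prompt_tokens / 1000 * HYPOTHETICAL_PRICE_PER_1K["prompt"]
        + completion_tokens / 1000 * HYPOTHETICAL_PRICE_PER_1K["completion"]
    )

print(estimate_cost(prompt_tokens=1200, completion_tokens=300))  # ≈ 0.021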

8.2 配额管理

python
# api/services/model_provider_service.py
from datetime import datetime

class ModelProviderService:
    @staticmethod
    def check_quota(
        tenant_id: str,
        provider: str,
        model_type: ModelType,
        tokens: int
    ):
        """
        检查配额
        
        Args:
            tenant_id: 租户 ID
            provider: 供应商
            model_type: 模型类型
            tokens: 需要的 Token 数
            
        Raises:
            QuotaExceededError: 配额不足
        """
        from extensions.ext_redis import redis_client
        
        # 1. 获取租户配额配置
        quota_key = f"quota:{tenant_id}:{provider}:{model_type.value}"
        quota_config = redis_client.hgetall(quota_key)
        
        if not quota_config:
            # 没有配额限制
            return
        
        # 2. 检查月度配额
        monthly_limit = int(quota_config.get("monthly_limit", 0))
        if monthly_limit > 0:
            current_month = datetime.now().strftime("%Y-%m")
            usage_key = f"usage:{tenant_id}:{provider}:{current_month}"
            current_usage = int(redis_client.get(usage_key) or 0)
            
            if current_usage + tokens > monthly_limit:
                raise QuotaExceededError(
                    f"Monthly quota exceeded. "
                    f"Limit: {monthly_limit}, "
                    f"Used: {current_usage}, "
                    f"Requested: {tokens}"
                )
        
        # 3. 检查速率限制(RPM - Requests Per Minute)
        rpm_limit = int(quota_config.get("rpm_limit", 0))
        if rpm_limit > 0:
            rpm_key = f"rpm:{tenant_id}:{provider}:{datetime.now().minute}"
            current_rpm = redis_client.incr(rpm_key)
            redis_client.expire(rpm_key, 60)  # 1分钟后过期
            
            if current_rpm > rpm_limit:
                raise RateLimitError(
                    f"Rate limit exceeded. "
                    f"Limit: {rpm_limit} requests/min"
                )
    
    @staticmethod
    def record_usage(
        tenant_id: str,
        provider: str,
        model: str,
        tokens: int,
        cost: float | None = None
    ):
        """
        记录使用量
        """
        from extensions.ext_redis import redis_client
        from extensions.ext_database import db
        from models.provider import ProviderModelUsage
        
        # 1. 更新 Redis 缓存(快速)
        current_month = datetime.now().strftime("%Y-%m")
        usage_key = f"usage:{tenant_id}:{provider}:{current_month}"
        redis_client.incrby(usage_key, tokens)
        redis_client.expire(usage_key, 60 * 60 * 24 * 35)  # 35天过期
        
        # 2. 记录到数据库(持久化)
        usage = ProviderModelUsage(
            tenant_id=tenant_id,
            provider=provider,
            model=model,
            tokens=tokens,
            cost=cost,
            created_at=datetime.utcnow()
        )
        db.session.add(usage)
        db.session.commit()
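
把 Token 计数、配额检查和用量记录串联起来,一次受控的模型调用流程大致如下(示意):

python
# 一次带配额控制的调用流程(示意)
tokens = llm.get_num_tokens(model="gpt-4", credentials=credentials, prompt_messages=messages)

# 1. 调用前检查月度配额和速率限制
ModelProviderService.check_quota(
    tenant_id=tenant_id,
    provider="openai",
    model_type=ModelType.LLM,
    tokens=tokens,
)

# 2. 正常调用模型
result = llm.invoke(model="gpt-4", credentials=credentials, prompt_messages=messages, stream=False)

# 3. 调用后按实际用量记账
ModelProviderService.record_usage(
    tenant_id=tenant_id,
    provider="openai",
    model="gpt-4",
    tokens=result.usage.total_tokens,
)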

九、实践项目

项目 1:分析 OpenAI 适配实现

目标:理解 OpenAI 模型是如何被适配的

步骤

  1. 阅读 large_language_model.py 基类
  2. 查看 OpenAI 的插件实现(通过插件系统)
  3. 追踪一次完整的模型调用
  4. 绘制调用时序图

项目 2:实现一个自定义模型适配

目标:为一个新的模型供应商创建适配器

要求

python
class MyCustomLLM(LargeLanguageModel):
    """自定义 LLM 适配"""
    
    def _invoke(self, model, credentials, prompt_messages, **kwargs):
        # 实现非流式调用
        pass
    
    def _invoke_stream(self, model, credentials, prompt_messages, **kwargs):
        # 实现流式调用
        pass
    
    def get_num_tokens(self, model, credentials, prompt_messages, **kwargs):
        # 实现 Token 计数
        pass

项目 3:对比分析不同模型的适配差异

输出

  • OpenAI vs Anthropic 差异对比表
  • 国产模型适配挑战总结
  • 最佳实践建议

🎓 自测题

  1. Model Runtime 的三层架构分别是什么?各有什么作用?
  2. 为什么需要参数标准化?
  3. OpenAI 和 Anthropic 在系统消息处理上有什么区别?
  4. 流式响应相比非流式响应有什么优势?
  5. 如何实现错误重试机制?
  6. Token 计数在什么场景下很重要?
  7. 如何优化大模型的调用性能和成本?

✅ 小结

本文深入分析了 Dify 的 Model Runtime 模块:

关键要点

  • ✅ 三层架构:Factory - Provider - Model
  • ✅ 统一接口:所有模型使用相同的调用方式
  • ✅ 参数标准化:屏蔽不同供应商的差异
  • ✅ 流式响应:通过 SSE 实现实时反馈
  • ✅ 错误处理:重试机制和异常分类
  • ✅ 配额管理:Token 计数和使用量控制
