Redis Case Studies -- Reading the Dify Source Code

Overview

Related application source code

Caching

Relevant source locations

  • api/core/rag/embedding/cached_embedding.py

    Caches embedding vectors.

  • api/services/account_service.py

  • ……

Embedding vector cache

  • Cache key components: model provider, model name, and a hash of the input text

            embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
    
  • setex writes the value and sets a 10-minute (600 s) expiry in a single call.

            try:
                # encode embedding to base64
                embedding_vector = np.array(embedding_results)
                vector_bytes = embedding_vector.tobytes()
                # Transform to Base64
                encoded_vector = base64.b64encode(vector_bytes)
                # Transform to string
                encoded_str = encoded_vector.decode("utf-8")
                redis_client.setex(embedding_cache_key, 600, encoded_str)
            except Exception as ex:
                if dify_config.DEBUG:
                    logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
                raise ex
    
            return embedding_results
    
  • On a cache hit, expire refreshes the TTL back to 10 minutes.

            embedding = redis_client.get(embedding_cache_key)
            if embedding:
                redis_client.expire(embedding_cache_key, 600)
                decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
                return [float(x) for x in decoded_embedding]
    
    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""
        # use doc embedding cache or store if not exists
        hash = helper.generate_text_hash(text)
        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
        embedding = redis_client.get(embedding_cache_key)
        if embedding:
            redis_client.expire(embedding_cache_key, 600)
            decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
            return [float(x) for x in decoded_embedding]
        try:
            embedding_result = self._model_instance.invoke_text_embedding(
                texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
            )

            embedding_results = embedding_result.embeddings[0]
            # FIXME: type ignore for numpy here
            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()  # type: ignore
            if np.isnan(embedding_results).any():
                raise ValueError("Normalized embedding is nan please try again")
        except Exception as ex:
            if dify_config.DEBUG:
                logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
            raise ex

        try:
            # encode embedding to base64
            embedding_vector = np.array(embedding_results)
            vector_bytes = embedding_vector.tobytes()
            # Transform to Base64
            encoded_vector = base64.b64encode(vector_bytes)
            # Transform to string
            encoded_str = encoded_vector.decode("utf-8")
            redis_client.setex(embedding_cache_key, 600, encoded_str)
        except Exception as ex:
            if dify_config.DEBUG:
                logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
            raise ex

        return embedding_results
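
The cached value is just the raw float64 buffer of the normalized vector, base64-encoded into a string. A minimal standalone sketch of that encode/decode round trip (no Redis involved; the input values are illustrative):

import base64

import numpy as np

# Encode: float64 array -> raw bytes -> base64 -> UTF-8 string (what setex stores).
vector = np.array([0.1, 0.2, 0.3])
encoded_str = base64.b64encode(vector.tobytes()).decode("utf-8")

# Decode: base64 string -> raw bytes -> float64 array (what the cache-hit path returns).
decoded = np.frombuffer(base64.b64decode(encoded_str), dtype="float")
assert [float(x) for x in decoded] == [float(x) for x in vector]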

Model credentials cache

api/core/helper/model_provider_cache.py

class ProviderCredentialsCacheType(Enum):
    PROVIDER = "provider"
    MODEL = "provider_model"
    LOAD_BALANCING_MODEL = "load_balancing_provider_model"


class ProviderCredentialsCache:
    def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
        self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

    def get(self) -> Optional[dict]:
        """
        Get cached model provider credentials.

        :return:
        """
        cached_provider_credentials = redis_client.get(self.cache_key)
        if cached_provider_credentials:
            try:
                cached_provider_credentials = cached_provider_credentials.decode("utf-8")
                cached_provider_credentials = json.loads(cached_provider_credentials)
            except JSONDecodeError:
                return None

            return dict(cached_provider_credentials)
        else:
            return None

    def set(self, credentials: dict) -> None:
        """
        Cache model provider credentials.

        :param credentials: provider credentials
        :return:
        """
        redis_client.setex(self.cache_key, 86400, json.dumps(credentials))

    def delete(self) -> None:
        """
        Delete cached model provider credentials.

        :return:
        """
        redis_client.delete(self.cache_key)
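
The cache is used read-through: call get(), fall back to the real credential source on a miss (a corrupt JSON payload also returns None), then write back with set(). A hedged usage sketch of the class above (the tenant and provider identifiers, and the loader helper, are made up for illustration):

# Hypothetical read-through usage; identifiers are illustrative, not taken from Dify.
cache = ProviderCredentialsCache(
    tenant_id="tenant-123",
    identity_id="openai",
    cache_type=ProviderCredentialsCacheType.PROVIDER,
)
credentials = cache.get()
if credentials is None:
    credentials = load_and_decrypt_credentials()  # placeholder for the real lookup
    cache.set(credentials)                        # cached for 86400 s (24 h)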

Private key cache

api/libs/rsa.py

def get_decrypt_decoding(tenant_id):
    filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"

    cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
    private_key = redis_client.get(cache_key)
    if not private_key:
        try:
            private_key = storage.load(filepath)
        except FileNotFoundError:
            raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id))

        redis_client.setex(cache_key, 120, private_key)

    rsa_key = RSA.import_key(private_key)
    cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)

    return rsa_key, cipher_rsa

Account refresh_token

class AccountService:
    reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
    email_code_login_rate_limiter = RateLimiter(
        prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
    )
    email_code_account_deletion_rate_limiter = RateLimiter(
        prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
    )
    LOGIN_MAX_ERROR_LIMITS = 5
    FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5

    @staticmethod
    def _get_refresh_token_key(refresh_token: str) -> str:
        return f"{REFRESH_TOKEN_PREFIX}{refresh_token}"

    @staticmethod
    def _get_account_refresh_token_key(account_id: str) -> str:
        return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"

    @staticmethod
    def _store_refresh_token(refresh_token: str, account_id: str) -> None:
        redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
        redis_client.setex(
            AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
        )

    @staticmethod
    def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
        redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
        redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
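
The two setex calls keep a bidirectional mapping (refresh_token -> account_id and account_id -> refresh_token) that expires together, so a presented token can be resolved to an account and an account's current token can be revoked. A hedged sketch of how a login/logout flow might call these helpers (the token generation shown here is illustrative, not Dify's):

import secrets

def login(account_id: str) -> str:
    # Hypothetical flow: mint a refresh token and register both Redis mappings.
    refresh_token = secrets.token_urlsafe(64)
    AccountService._store_refresh_token(refresh_token, account_id)
    return refresh_token

def logout(refresh_token: str, account_id: str) -> None:
    # Drop both mappings so the token can no longer be used.
    AccountService._delete_refresh_token(refresh_token, account_id)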

Counters

Relevant source locations

  • api/tasks/ops_trace_task.py

  • ……

Source example

When a trace task raises an exception, the failure count for that app is incremented in Redis via redis_client.incr(failed_key).

@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):
    """
    Async process trace tasks
    Usage: process_trace_tasks.delay(tasks_data)
    """
    from core.ops.ops_trace_manager import OpsTraceManager

    app_id = file_info.get("app_id")
    file_id = file_info.get("file_id")
    file_path = f"{OPS_FILE_PATH}{app_id}/{file_id}.json"
    file_data = json.loads(storage.load(file_path))
    trace_info = file_data.get("trace_info")
    trace_info_type = file_data.get("trace_info_type")
    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)

    if trace_info.get("message_data"):
        trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"])
    if trace_info.get("workflow_data"):
        trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
    if trace_info.get("documents"):
        trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]]

    try:
        if trace_instance:
            with current_app.app_context():
                trace_type = trace_info_info_map.get(trace_info_type)
                if trace_type:
                    trace_info = trace_type(**trace_info)
                trace_instance.trace(trace_info)
        logging.info(f"Processing trace tasks success, app_id: {app_id}")
    except Exception:
        failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
        redis_client.incr(failed_key)
        logging.info(f"Processing trace tasks failed, app_id: {app_id}")
    finally:
        storage.delete(file_path)
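
INCR creates the key with an initial value of 1 when it does not exist, so no setup is needed before counting failures. A hedged sketch of reading the counter back (the helper name is made up; the key construction mirrors the task above):

def get_trace_failure_count(app_id: str) -> int:
    # Hypothetical helper: read the failure counter written by process_trace_tasks.
    value = redis_client.get(f"{OPS_TRACE_FAILED_KEY}_{app_id}")
    return int(value) if value else 0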

Distributed lock

Scenario: database migration

Source: api/commands.py

Ensures that only one instance runs the database migration at a time; a Redis lock prevents concurrent execution.

@click.command("upgrade-db", help="Upgrade the database")
def upgrade_db():
    click.echo("Preparing database migration...")
    lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
    if lock.acquire(blocking=False):
        try:
            click.echo(click.style("Starting database migration.", fg="green"))

            # run db migration
            import flask_migrate  # type: ignore

            flask_migrate.upgrade()

            click.echo(click.style("Database migration successful!", fg="green"))

        except Exception:
            logging.exception("Failed to execute database migration")
        finally:
            lock.release()
    else:
        click.echo("Database migration skipped")

Scenario: dataset service (saving documents)

api/services/dataset_service.py

# DocumentService
    @staticmethod
    def save_document_with_dataset_id(
        dataset: Dataset,
        knowledge_config: KnowledgeConfig,
        account: Account | Any,
        dataset_process_rule: Optional[DatasetProcessRule] = None,
        created_from: str = "web",
    ):
        # check document limit
        features = FeatureService.get_features(current_user.current_tenant_id)

        if features.billing.enabled:
            if not knowledge_config.original_document_id:
                count = 0
                if knowledge_config.data_source:
                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                        count = len(upload_file_list)
                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                        for notion_info in notion_info_list:  # type: ignore
                            count = count + len(notion_info.pages)
                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                        website_info = knowledge_config.data_source.info_list.website_info_list
                        count = len(website_info.urls)  # type: ignore
                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

                    if features.billing.subscription.plan == "sandbox" and count > 1:
                        raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
                    if count > batch_upload_limit:
                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

                    DocumentService.check_documents_upload_quota(count, features)

        # if dataset is empty, update dataset data_source_type
        if not dataset.data_source_type:
            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

        if not dataset.indexing_technique:
            if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is invalid")

            dataset.indexing_technique = knowledge_config.indexing_technique
            if knowledge_config.indexing_technique == "high_quality":
                model_manager = ModelManager()
                if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
                    dataset_embedding_model = knowledge_config.embedding_model
                    dataset_embedding_model_provider = knowledge_config.embedding_model_provider
                else:
                    embedding_model = model_manager.get_default_model_instance(
                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
                    )
                    dataset_embedding_model = embedding_model.model
                    dataset_embedding_model_provider = embedding_model.provider
                dataset.embedding_model = dataset_embedding_model
                dataset.embedding_model_provider = dataset_embedding_model_provider
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    dataset_embedding_model_provider, dataset_embedding_model
                )
                dataset.collection_binding_id = dataset_collection_binding.id
                if not dataset.retrieval_model:
                    default_retrieval_model = {
                        "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                        "reranking_enable": False,
                        "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                        "top_k": 2,
                        "score_threshold_enabled": False,
                    }

                    dataset.retrieval_model = (
                        knowledge_config.retrieval_model.model_dump()
                        if knowledge_config.retrieval_model
                        else default_retrieval_model
                    )  # type: ignore

        documents = []
        if knowledge_config.original_document_id:
            document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
            documents.append(document)
            batch = document.batch
        else:
            batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
            # save process rule
            if not dataset_process_rule:
                process_rule = knowledge_config.process_rule
                if process_rule:
                    if process_rule.mode in ("custom", "hierarchical"):
                        dataset_process_rule = DatasetProcessRule(
                            dataset_id=dataset.id,
                            mode=process_rule.mode,
                            rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                            created_by=account.id,
                        )
                    elif process_rule.mode == "automatic":
                        dataset_process_rule = DatasetProcessRule(
                            dataset_id=dataset.id,
                            mode=process_rule.mode,
                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                            created_by=account.id,
                        )
                    else:
                        logging.warn(
                            f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
                        )
                        return
                    db.session.add(dataset_process_rule)
                    db.session.commit()
            lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
            with redis_client.lock(lock_name, timeout=600):
                position = DocumentService.get_documents_position(dataset.id)
                document_ids = []
                duplicate_document_ids = []
                if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                    for file_id in upload_file_list:
                        file = (
                            db.session.query(UploadFile)
                            .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                            .first()
                        )

                        # raise error if file not found
                        if not file:
                            raise FileNotExistsError()

                        file_name = file.name
                        data_source_info = {
                            "upload_file_id": file_id,
                        }
                        # check duplicate
                        if knowledge_config.duplicate:
                            document = Document.query.filter_by(
                                dataset_id=dataset.id,
                                tenant_id=current_user.current_tenant_id,
                                data_source_type="upload_file",
                                enabled=True,
                                name=file_name,
                            ).first()
                            if document:
                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                                document.created_from = created_from
                                document.doc_form = knowledge_config.doc_form
                                document.doc_language = knowledge_config.doc_language
                                document.data_source_info = json.dumps(data_source_info)
                                document.batch = batch
                                document.indexing_status = "waiting"
                                db.session.add(document)
                                documents.append(document)
                                duplicate_document_ids.append(document.id)
                                continue
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            file_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                    if not notion_info_list:
                        raise ValueError("No notion info list found.")
                    exist_page_ids = []
                    exist_document = {}
                    documents = Document.query.filter_by(
                        dataset_id=dataset.id,
                        tenant_id=current_user.current_tenant_id,
                        data_source_type="notion_import",
                        enabled=True,
                    ).all()
                    if documents:
                        for document in documents:
                            data_source_info = json.loads(document.data_source_info)
                            exist_page_ids.append(data_source_info["notion_page_id"])
                            exist_document[data_source_info["notion_page_id"]] = document.id
                    for notion_info in notion_info_list:
                        workspace_id = notion_info.workspace_id
                        data_source_binding = DataSourceOauthBinding.query.filter(
                            db.and_(
                                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                                DataSourceOauthBinding.provider == "notion",
                                DataSourceOauthBinding.disabled == False,
                                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                            )
                        ).first()
                        if not data_source_binding:
                            raise ValueError("Data source binding not found.")
                        for page in notion_info.pages:
                            if page.page_id not in exist_page_ids:
                                data_source_info = {
                                    "notion_workspace_id": workspace_id,
                                    "notion_page_id": page.page_id,
                                    "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
                                    "type": page.type,
                                }
                                # Truncate page name to 255 characters to prevent DB field length errors
                                truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                                document = DocumentService.build_document(
                                    dataset,
                                    dataset_process_rule.id,  # type: ignore
                                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                                    knowledge_config.doc_form,
                                    knowledge_config.doc_language,
                                    data_source_info,
                                    created_from,
                                    position,
                                    account,
                                    truncated_page_name,
                                    batch,
                                )
                                db.session.add(document)
                                db.session.flush()
                                document_ids.append(document.id)
                                documents.append(document)
                                position += 1
                            else:
                                exist_document.pop(page.page_id)
                    # delete not selected documents
                    if len(exist_document) > 0:
                        clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
                    website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
                    if not website_info:
                        raise ValueError("No website info list found.")
                    urls = website_info.urls
                    for url in urls:
                        data_source_info = {
                            "url": url,
                            "provider": website_info.provider,
                            "job_id": website_info.job_id,
                            "only_main_content": website_info.only_main_content,
                            "mode": "crawl",
                        }
                        if len(url) > 255:
                            document_name = url[:200] + "..."
                        else:
                            document_name = url
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            document_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                db.session.commit()

                # trigger async task
                if document_ids:
                    document_indexing_task.delay(dataset.id, document_ids)
                if duplicate_document_ids:
                    duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

        return documents, batch
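
Two lock idioms appear in these examples: the migration command acquires the lock non-blockingly and skips the work if another instance already holds it, while save_document_with_dataset_id uses the blocking context-manager form and simply waits its turn. A minimal sketch of both redis-py patterns (the lock names and the protected function are illustrative):

def do_exclusive_work():
    ...  # placeholder for the critical section

# Non-blocking: skip if another process already holds the lock.
lock = redis_client.lock(name="example_lock", timeout=60)
if lock.acquire(blocking=False):
    try:
        do_exclusive_work()
    finally:
        lock.release()

# Blocking: wait for the lock; it is released automatically when the block exits.
with redis_client.lock("add_document_lock_dataset_id_example", timeout=600):
    do_exclusive_work()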

Rate limiter

api/libs/helper.py

RateLimiter constructor parameters

  • Key prefix

  • Maximum number of attempts

  • Time window (in seconds)

Recording a request: increment_rate_limit

Requests are rate-limited per email address, which uniquely identifies the account.

  1. Record the request time in a sorted set (zadd), using the current timestamp as both member and score so entries are naturally ordered by time.

  2. Set an expiry on the key so stale entries are cleaned up promptly and inactive data does not accumulate in Redis.

Checking whether the limit is reached: is_rate_limited

  1. Get the current timestamp: current_time

  2. Compute the lower bound of the valid time window

  3. Remove request records older than the window start (zremrangebyscore)

  4. Count the requests remaining in the window as attempts (zcard)

  5. Check whether attempts has reached the configured maximum

Full RateLimiter source

class RateLimiter:
    def __init__(self, prefix: str, max_attempts: int, time_window: int):
        self.prefix = prefix
        self.max_attempts = max_attempts
        self.time_window = time_window

    def _get_key(self, email: str) -> str:
        return f"{self.prefix}:{email}"

    def is_rate_limited(self, email: str) -> bool:
        key = self._get_key(email)
        current_time = int(time.time())
        window_start_time = current_time - self.time_window

        redis_client.zremrangebyscore(key, "-inf", window_start_time)
        attempts = redis_client.zcard(key)

        if attempts and int(attempts) >= self.max_attempts:
            return True
        return False

    def increment_rate_limit(self, email: str):
        key = self._get_key(email)
        current_time = int(time.time())

        redis_client.zadd(key, {current_time: current_time})
        redis_client.expire(key, self.time_window * 2)
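
Callers are expected to check is_rate_limited() before doing the work and call increment_rate_limit() afterwards. A hedged sketch of that pattern using the reset_password_rate_limiter declared on AccountService above (the email value and error handling are illustrative):

email = "user@example.com"

if AccountService.reset_password_rate_limiter.is_rate_limited(email):
    raise ValueError("Too many password reset requests, please try again later.")
# ... send the reset-password email here ...
AccountService.reset_password_rate_limiter.increment_rate_limit(email)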

api/core/app/features/rate_limiting/rate_limit.py

class RateLimit:
    _MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
    _ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
    _UNLIMITED_REQUEST_ID = "unlimited_request_id"
    _REQUEST_MAX_ALIVE_TIME = 10 * 60  # 10 minutes
    _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
    _instance_dict: dict[str, "RateLimit"] = {}

    def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
        if client_id not in cls._instance_dict:
            instance = super().__new__(cls)
            cls._instance_dict[client_id] = instance
        return cls._instance_dict[client_id]

    def __init__(self, client_id: str, max_active_requests: int):
        self.max_active_requests = max_active_requests
        # must be called after max_active_requests is set
        if self.disabled():
            return
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.client_id = client_id
        self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
        self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
        self.last_recalculate_time = float("-inf")
        self.flush_cache(use_local_value=True)
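
Because __new__ keys instances by client_id, constructing RateLimit twice for the same client returns the same in-process object, so the per-client counters are shared. A minimal sketch (client IDs are illustrative; first-time initialization calls flush_cache, so a reachable Redis is assumed):

a = RateLimit("app-123", max_active_requests=10)
b = RateLimit("app-123", max_active_requests=10)
assert a is b        # same per-client singleton

c = RateLimit("app-456", max_active_requests=10)
assert a is not c    # different client, different instance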