Redis案例应用--Dify源码阅读
概览
相关应用源码
缓存
相关源码位置
api/core/rag/embedding/cached_embedding.py
对嵌入向量进行缓存。
api/services/account_service.py
……
向量嵌入缓存
缓存 key 构造参数:模型提供方(provider)、模型名称(model)、输入文本的哈希值(由 helper.generate_text_hash 计算)。
# Key = model provider + model name + hash of the input text (`hash` is computed from the text in embed_query).
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
未命中缓存时调用模型计算嵌入向量,将向量转为字节并做 Base64 编码后,使用 setex 同时写入值并设置 10 分钟(600 秒)的过期时间。
try: # encode embedding to base64 embedding_vector = np.array(embedding_results) vector_bytes = embedding_vector.tobytes() # Transform to Base64 encoded_vector = base64.b64encode(vector_bytes) # Transform to string encoded_str = encoded_vector.decode("utf-8") redis_client.setex(embedding_cache_key, 600, encoded_str) except Exception as ex: if dify_config.DEBUG: logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'") raise ex return embedding_results
命中缓存后通过 expire 将过期时间重新设为 10 分钟(即每次命中都会续期),再将 Base64 字符串解码还原为浮点向量返回。
embedding = redis_client.get(embedding_cache_key) if embedding: redis_client.expire(embedding_cache_key, 600) decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float") return [float(x) for x in decoded_embedding]
def embed_query(self, text: str) -> list[float]:
    """Embed a query text, using Redis as a read-through cache.

    The cache key combines the model provider, the model name and a hash of
    ``text``. On a hit the key's TTL is reset to 600 seconds and the cached
    vector is returned. On a miss the model is invoked, the vector is
    L2-normalized, stored in Redis for 600 seconds and returned.

    :param text: query text to embed
    :return: the normalized embedding as a list of floats
    :raises ValueError: if the normalized embedding contains NaN
        (e.g. the model returned an all-zero vector, so the norm is 0)
    """
    # use doc embedding cache or store if not exists
    text_hash = helper.generate_text_hash(text)
    embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{text_hash}"
    embedding = redis_client.get(embedding_cache_key)
    if embedding:
        # Sliding expiration: every hit extends the entry's lifetime by another 600 seconds.
        redis_client.expire(embedding_cache_key, 600)
        # Must match the dtype used when the vector is encoded below.
        decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype=np.float64)
        return [float(x) for x in decoded_embedding]

    try:
        embedding_result = self._model_instance.invoke_text_embedding(
            texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
        )
        embedding_results = embedding_result.embeddings[0]
        # FIXME: type ignore for numpy here
        embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()  # type: ignore
        if np.isnan(embedding_results).any():
            raise ValueError("Normalized embedding is nan please try again")
    except Exception:
        if dify_config.DEBUG:
            logging.exception("Failed to embed query text '%s...(%d chars)'", text[:10], len(text))
        # Bare raise keeps the original traceback intact.
        raise

    try:
        # Serialize as raw float64 bytes, then Base64 text, so the cache-hit path can decode it.
        vector_bytes = np.array(embedding_results, dtype=np.float64).tobytes()
        encoded_str = base64.b64encode(vector_bytes).decode("utf-8")
        redis_client.setex(embedding_cache_key, 600, encoded_str)
    except Exception:
        if dify_config.DEBUG:
            logging.exception(
                "Failed to add embedding to redis for the text '%s...(%d chars)'", text[:10], len(text)
            )
        # NOTE(review): a cache-write failure still fails the whole call even though the
        # embedding was computed successfully; consider making caching best-effort.
        raise
    return embedding_results
模型凭证缓存
api/core/helper/model_provider_cache.py
class ProviderCredentialsCacheType(Enum):
    """Kind of credentials being cached.

    The member's value is interpolated into the Redis cache key built by
    ProviderCredentialsCache, so values must stay stable.
    """

    PROVIDER = "provider"
    MODEL = "provider_model"
    LOAD_BALANCING_MODEL = "load_balancing_provider_model"
class ProviderCredentialsCache:
    """Redis-backed cache for the credentials of one tenant/identity pair.

    Entries are stored as JSON strings and expire after 86400 seconds (one day).
    """

    def __init__(self, tenant_id: str, identity_id: str, cache_type: "ProviderCredentialsCacheType"):
        # The separator before the identity id is the literal text ":id:" (it had been
        # corrupted into an emoji character, which produced a different Redis key).
        self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

    def get(self) -> Optional[dict]:
        """
        Get cached model provider credentials.

        :return: the cached credentials dict, or None on a cache miss or when the
            cached payload cannot be decoded into a JSON object
        """
        cached_provider_credentials = redis_client.get(self.cache_key)
        if not cached_provider_credentials:
            return None
        try:
            # Accept both bytes and str, depending on how the Redis client decodes responses.
            if isinstance(cached_provider_credentials, bytes):
                cached_provider_credentials = cached_provider_credentials.decode("utf-8")
            credentials = json.loads(cached_provider_credentials)
        except (UnicodeDecodeError, JSONDecodeError):
            return None
        # set() always stores a JSON object; anything else is treated as a corrupt entry.
        if not isinstance(credentials, dict):
            return None
        return credentials

    def set(self, credentials: dict) -> None:
        """
        Cache model provider credentials.

        :param credentials: provider credentials (must be JSON-serializable)
        :return:
        """
        redis_client.setex(self.cache_key, 86400, json.dumps(credentials))

    def delete(self) -> None:
        """
        Delete cached model provider credentials.

        :return:
        """
        redis_client.delete(self.cache_key)
私钥缓存
api/libs/rsa.py
def get_decrypt_decoding(tenant_id):
    """Load a tenant's RSA private key and build a decryption cipher for it.

    The PEM data is read from storage at ``privkeys/<tenant_id>/private.pem``.
    It is cached in Redis for 120 seconds under a key derived from the SHA3-256
    hash of that path.

    :param tenant_id: tenant whose private key should be loaded
    :return: tuple of (RSA key object, cipher from gmpy2_pkcs10aep_cipher.new)
    :raises PrivkeyNotFoundError: if no private key file exists for the tenant
    """
    filepath = f"privkeys/{tenant_id}/private.pem"

    cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}"
    private_key = redis_client.get(cache_key)
    if not private_key:
        try:
            private_key = storage.load(filepath)
        except FileNotFoundError as e:
            # Chain the storage error so the root cause stays visible in tracebacks.
            raise PrivkeyNotFoundError(f"Private key not found, tenant_id: {tenant_id}") from e

        # NOTE(review): this stores the raw private key in Redis for 120 s; confirm that is acceptable.
        redis_client.setex(cache_key, 120, private_key)

    rsa_key = RSA.import_key(private_key)
    cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)

    return rsa_key, cipher_rsa
账号refresh_token
class AccountService:
    """Account service (excerpt): per-flow rate limiters and refresh-token bookkeeping in Redis."""

    # Each flow allows a single attempt per 60-second window.
    reset_password_rate_limiter = RateLimiter(
        prefix="reset_password_rate_limit",
        max_attempts=1,
        time_window=60,
    )
    email_code_login_rate_limiter = RateLimiter(
        prefix="email_code_login_rate_limit",
        max_attempts=1,
        time_window=60,
    )
    email_code_account_deletion_rate_limiter = RateLimiter(
        prefix="email_code_account_deletion_rate_limit",
        max_attempts=1,
        time_window=60,
    )

    LOGIN_MAX_ERROR_LIMITS = 5
    FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5

    @staticmethod
    def _get_refresh_token_key(refresh_token: str) -> str:
        """Redis key under which a refresh token maps to its account id."""
        return f"{REFRESH_TOKEN_PREFIX}{refresh_token}"

    @staticmethod
    def _get_account_refresh_token_key(account_id: str) -> str:
        """Redis key under which an account id maps to its refresh token."""
        return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"

    @staticmethod
    def _store_refresh_token(refresh_token: str, account_id: str) -> None:
        """Write both directions of the token/account mapping, each with REFRESH_TOKEN_EXPIRY."""
        token_key = AccountService._get_refresh_token_key(refresh_token)
        account_key = AccountService._get_account_refresh_token_key(account_id)
        redis_client.setex(token_key, REFRESH_TOKEN_EXPIRY, account_id)
        redis_client.setex(account_key, REFRESH_TOKEN_EXPIRY, refresh_token)

    @staticmethod
    def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
        """Remove both directions of the token/account mapping."""
        for key in (
            AccountService._get_refresh_token_key(refresh_token),
            AccountService._get_account_refresh_token_key(account_id),
        ):
            redis_client.delete(key)
计数器
相关源码位置
api/tasks/ops_trace_task.py
……
源码示例
执行 trace 上报出现异常时,使用 incr 在 Redis 中累加该应用的失败次数(key 由 OPS_TRACE_FAILED_KEY 与 app_id 拼接而成)。
# Count trace failures per app in Redis (failed_key is built from OPS_TRACE_FAILED_KEY and app_id in process_trace_tasks).
redis_client.incr(failed_key)
@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):
    """
    Async process trace tasks
    Usage: process_trace_tasks.delay(tasks_data)

    Loads the trace payload stored at ``{OPS_FILE_PATH}{app_id}/{file_id}.json``,
    rebuilds typed objects from it and passes it to the app's ops trace instance.
    If tracing fails, the per-app failure counter in Redis is incremented. The
    payload file is deleted once the trace attempt has been made.

    :param file_info: dict providing at least ``app_id`` and ``file_id``
    """
    from core.ops.ops_trace_manager import OpsTraceManager

    app_id = file_info.get("app_id")
    file_id = file_info.get("file_id")
    file_path = f"{OPS_FILE_PATH}{app_id}/{file_id}.json"
    # NOTE(review): failures while loading/parsing the payload happen outside the try below,
    # so they neither bump the failure counter nor delete the file.
    file_data = json.loads(storage.load(file_path))
    trace_info = file_data.get("trace_info")
    trace_info_type = file_data.get("trace_info_type")
    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)

    # Rehydrate serialized model objects before building the typed trace info.
    if trace_info.get("message_data"):
        trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"])
    if trace_info.get("workflow_data"):
        trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
    if trace_info.get("documents"):
        trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]]

    try:
        if trace_instance:
            with current_app.app_context():
                trace_type = trace_info_info_map.get(trace_info_type)
                if trace_type:
                    trace_info = trace_type(**trace_info)
                trace_instance.trace(trace_info)
        logging.info("Processing trace tasks success, app_id: %s", app_id)
    except Exception:
        failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
        redis_client.incr(failed_key)
        # Tracing stays best-effort, but keep the traceback instead of an INFO line without the cause.
        logging.exception("Processing trace tasks failed, app_id: %s", app_id)
    finally:
        storage.delete(file_path)
分布式锁(Lock)
场景:数据库迁移
源码:api/commands.py
说明:使用 Redis 锁(超时 60 秒)确保同一时间只有一个实例在执行数据库迁移,防止并发迁移带来的问题;锁以非阻塞方式获取,未获取到锁的实例直接跳过迁移。
@click.command("upgrade-db", help="Upgrade the database")
def upgrade_db():
    """Run flask_migrate.upgrade(), guarded by a Redis lock so only one instance migrates at a time.

    Instances that fail to acquire the lock skip the migration. Migration errors are
    logged, not raised.
    """
    click.echo("Preparing database migration...")
    # Non-blocking acquire: an instance that loses the race skips instead of waiting.
    lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
    if lock.acquire(blocking=False):
        try:
            click.echo(click.style("Starting database migration.", fg="green"))

            # run db migration
            import flask_migrate  # type: ignore

            flask_migrate.upgrade()

            click.echo(click.style("Database migration successful!", fg="green"))

        except Exception:
            logging.exception("Failed to execute database migration")
        finally:
            # The lock auto-expires after `timeout` seconds. If the migration outlived it,
            # release() would raise (redis-py LockNotOwnedError) and mask the real outcome,
            # so only release a lock we still own.
            # NOTE(review): a migration longer than 60 s also lets another instance acquire the lock.
            if lock.owned():
                lock.release()
    else:
        click.echo("Database migration skipped")
场景:数据集服务(保存)
api/services/dataset_service.py
# DocumentService
@staticmethod
def save_document_with_dataset_id(
    dataset: Dataset,
    knowledge_config: KnowledgeConfig,
    account: Account | Any,
    dataset_process_rule: Optional[DatasetProcessRule] = None,
    created_from: str = "web",
):
    """Create (or update) documents in ``dataset`` from ``knowledge_config``.

    Steps, in order:
      1. If billing is enabled and this is not an update of an existing document,
         enforce the batch-upload limit and the document upload quota.
      2. Initialise dataset-level settings that are still empty (data source type,
         indexing technique, embedding model, retrieval model).
      3. Update the document named by ``original_document_id``. Otherwise, create
         documents for the uploaded files / Notion pages / crawled URLs while holding
         a per-dataset Redis lock, then enqueue the indexing tasks.

    :return: ``(documents, batch)`` -- the created/updated documents and the batch id.
        Returns None when the process rule mode is not recognised.
    :raises ValueError: on billing-limit violations, an invalid indexing technique,
        or missing data-source information
    :raises FileNotExistsError: when an uploaded file id cannot be found for the tenant
    """
    # check document limit
    features = FeatureService.get_features(current_user.current_tenant_id)

    if features.billing.enabled:
        if not knowledge_config.original_document_id:
            count = 0
            if knowledge_config.data_source:
                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                    count = len(upload_file_list)
                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                    for notion_info in notion_info_list:  # type: ignore
                        count = count + len(notion_info.pages)
                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                    website_info = knowledge_config.data_source.info_list.website_info_list
                    count = len(website_info.urls)  # type: ignore
                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

                if features.billing.subscription.plan == "sandbox" and count > 1:
                    raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
                if count > batch_upload_limit:
                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

                DocumentService.check_documents_upload_quota(count, features)

    # if dataset is empty, update dataset data_source_type
    if not dataset.data_source_type:
        dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

    if not dataset.indexing_technique:
        if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
            raise ValueError("Indexing technique is invalid")

        dataset.indexing_technique = knowledge_config.indexing_technique
        if knowledge_config.indexing_technique == "high_quality":
            model_manager = ModelManager()
            if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
                dataset_embedding_model = knowledge_config.embedding_model
                dataset_embedding_model_provider = knowledge_config.embedding_model_provider
            else:
                embedding_model = model_manager.get_default_model_instance(
                    tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
                )
                dataset_embedding_model = embedding_model.model
                dataset_embedding_model_provider = embedding_model.provider
            dataset.embedding_model = dataset_embedding_model
            dataset.embedding_model_provider = dataset_embedding_model_provider
            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                dataset_embedding_model_provider, dataset_embedding_model
            )
            dataset.collection_binding_id = dataset_collection_binding.id
            if not dataset.retrieval_model:
                default_retrieval_model = {
                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                    "reranking_enable": False,
                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                    "top_k": 2,
                    "score_threshold_enabled": False,
                }

                dataset.retrieval_model = (
                    knowledge_config.retrieval_model.model_dump()
                    if knowledge_config.retrieval_model
                    else default_retrieval_model
                )  # type: ignore

    documents = []
    if knowledge_config.original_document_id:
        document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
        documents.append(document)
        batch = document.batch
    else:
        batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
        # save process rule
        if not dataset_process_rule:
            process_rule = knowledge_config.process_rule
            if process_rule:
                if process_rule.mode in ("custom", "hierarchical"):
                    dataset_process_rule = DatasetProcessRule(
                        dataset_id=dataset.id,
                        mode=process_rule.mode,
                        rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                        created_by=account.id,
                    )
                elif process_rule.mode == "automatic":
                    dataset_process_rule = DatasetProcessRule(
                        dataset_id=dataset.id,
                        mode=process_rule.mode,
                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                        created_by=account.id,
                    )
                else:
                    # logging.warn is a deprecated alias of logging.warning.
                    logging.warning(
                        "Invalid process rule mode: %s, can not find dataset process rule", process_rule.mode
                    )
                    # NOTE(review): returns None here, unlike the (documents, batch) tuple below.
                    return
                db.session.add(dataset_process_rule)
                db.session.commit()
        lock_name = f"add_document_lock_dataset_id_{dataset.id}"
        # Per-dataset lock held while positions are read and new documents are inserted
        # (presumably so concurrent saves do not receive overlapping positions).
        with redis_client.lock(lock_name, timeout=600):
            position = DocumentService.get_documents_position(dataset.id)
            document_ids = []
            duplicate_document_ids = []
            if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                for file_id in upload_file_list:
                    file = (
                        db.session.query(UploadFile)
                        .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                        .first()
                    )

                    # raise error if file not found
                    if not file:
                        raise FileNotExistsError()

                    file_name = file.name
                    data_source_info = {
                        "upload_file_id": file_id,
                    }
                    # check duplicate
                    if knowledge_config.duplicate:
                        document = Document.query.filter_by(
                            dataset_id=dataset.id,
                            tenant_id=current_user.current_tenant_id,
                            data_source_type="upload_file",
                            enabled=True,
                            name=file_name,
                        ).first()
                        if document:
                            # Re-use the existing same-named document and queue it for re-indexing.
                            document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                            document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                            document.created_from = created_from
                            document.doc_form = knowledge_config.doc_form
                            document.doc_language = knowledge_config.doc_language
                            document.data_source_info = json.dumps(data_source_info)
                            document.batch = batch
                            document.indexing_status = "waiting"
                            db.session.add(document)
                            documents.append(document)
                            duplicate_document_ids.append(document.id)
                            continue
                    document = DocumentService.build_document(
                        dataset,
                        dataset_process_rule.id,  # type: ignore
                        knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                        knowledge_config.doc_form,
                        knowledge_config.doc_language,
                        data_source_info,
                        created_from,
                        position,
                        account,
                        file_name,
                        batch,
                    )
                    db.session.add(document)
                    db.session.flush()
                    document_ids.append(document.id)
                    documents.append(document)
                    position += 1
            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
                notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                if not notion_info_list:
                    raise ValueError("No notion info list found.")
                # Existing enabled Notion documents of this dataset, indexed by page id.
                # A dedicated variable name is used on purpose: rebinding `documents` here
                # would drop the list this function returns and leak pre-existing documents
                # (including ones scheduled for cleanup) into the result.
                exist_page_ids = set()
                exist_document = {}
                existing_notion_documents = Document.query.filter_by(
                    dataset_id=dataset.id,
                    tenant_id=current_user.current_tenant_id,
                    data_source_type="notion_import",
                    enabled=True,
                ).all()
                for existing_document in existing_notion_documents:
                    existing_info = json.loads(existing_document.data_source_info)
                    exist_page_ids.add(existing_info["notion_page_id"])
                    exist_document[existing_info["notion_page_id"]] = existing_document.id
                for notion_info in notion_info_list:
                    workspace_id = notion_info.workspace_id
                    data_source_binding = DataSourceOauthBinding.query.filter(
                        db.and_(
                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                            DataSourceOauthBinding.provider == "notion",
                            DataSourceOauthBinding.disabled == False,  # noqa: E712 (SQLAlchemy expression)
                            DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                        )
                    ).first()
                    if not data_source_binding:
                        raise ValueError("Data source binding not found.")
                    for page in notion_info.pages:
                        if page.page_id not in exist_page_ids:
                            data_source_info = {
                                "notion_workspace_id": workspace_id,
                                "notion_page_id": page.page_id,
                                "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
                                "type": page.type,
                            }
                            # Truncate page name to 255 characters to prevent DB field length errors
                            truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                            document = DocumentService.build_document(
                                dataset,
                                dataset_process_rule.id,  # type: ignore
                                knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                                knowledge_config.doc_form,
                                knowledge_config.doc_language,
                                data_source_info,
                                created_from,
                                position,
                                account,
                                truncated_page_name,
                                batch,
                            )
                            db.session.add(document)
                            db.session.flush()
                            document_ids.append(document.id)
                            documents.append(document)
                            position += 1
                        else:
                            # Page is still selected: keep its document out of the cleanup below.
                            # A default avoids KeyError when the same page is listed twice.
                            exist_document.pop(page.page_id, None)
                # delete not selected documents
                if len(exist_document) > 0:
                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
                website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
                if not website_info:
                    raise ValueError("No website info list found.")
                urls = website_info.urls
                for url in urls:
                    data_source_info = {
                        "url": url,
                        "provider": website_info.provider,
                        "job_id": website_info.job_id,
                        "only_main_content": website_info.only_main_content,
                        "mode": "crawl",
                    }
                    if len(url) > 255:
                        document_name = url[:200] + "..."
                    else:
                        document_name = url
                    document = DocumentService.build_document(
                        dataset,
                        dataset_process_rule.id,  # type: ignore
                        knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                        knowledge_config.doc_form,
                        knowledge_config.doc_language,
                        data_source_info,
                        created_from,
                        position,
                        account,
                        document_name,
                        batch,
                    )
                    db.session.add(document)
                    db.session.flush()
                    document_ids.append(document.id)
                    documents.append(document)
                    position += 1
            db.session.commit()

            # trigger async task
            if document_ids:
                document_indexing_task.delay(dataset.id, document_ids)
            if duplicate_document_ids:
                duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

    return documents, batch
速率限制器
api/libs/helper.py
RateLimiter构造要素
前缀
最大请求次数(max_attempts)
时间窗口(time_window,单位:秒)
记录一次请求:increment_rate_limit
按照email(唯一标识账号)对请求速率进行限制。
使用有序集合记录每次请求(zadd),分数(score)为当前时间戳,便于按时间范围筛选和清理请求记录。
设置 key 的过期时间(时间窗口的 2 倍),确保不再需要的数据被及时清理,避免大量不活跃数据堆积在 Redis 中。
判断是否达到速率限制:is_rate_limited
获取当前时刻的时间戳:current_time
计算有效时间窗口下限
去除有效时间窗口下限前的请求记录(zremrangebyscore)
统计有效时间窗口范围内的请求次数attempts(zcard)
判断 attempts 是否已达到(>=)预设的次数上限 max_attempts
RateLimiter完整源码
class RateLimiter:
    """Sliding-window rate limiter backed by one Redis sorted set per identifier.

    Each recorded attempt is a sorted-set member scored by its Unix timestamp in
    seconds. An identifier is limited once at least ``max_attempts`` attempts fall
    inside the last ``time_window`` seconds.
    """

    def __init__(self, prefix: str, max_attempts: int, time_window: int):
        self.prefix = prefix
        self.max_attempts = max_attempts
        self.time_window = time_window

    def _get_key(self, email: str) -> str:
        return f"{self.prefix}:{email}"

    @staticmethod
    def _new_member(timestamp: int) -> str:
        """Return a unique sorted-set member for an attempt made at ``timestamp``.

        The score already carries the time. The random suffix keeps two attempts made
        within the same second from collapsing into one member: zadd would otherwise
        overwrite the existing member and the attempt would be under-counted.
        """
        import uuid

        return f"{timestamp}:{uuid.uuid4().hex}"

    def is_rate_limited(self, email: str) -> bool:
        key = self._get_key(email)
        current_time = int(time.time())
        window_start_time = current_time - self.time_window

        # Drop attempts that fell out of the window, then count the remaining ones.
        redis_client.zremrangebyscore(key, "-inf", window_start_time)
        attempts = redis_client.zcard(key)

        if attempts and int(attempts) >= self.max_attempts:
            return True
        return False

    def increment_rate_limit(self, email: str):
        key = self._get_key(email)
        current_time = int(time.time())

        redis_client.zadd(key, {self._new_member(current_time): current_time})
        # Let idle keys expire so inactive identifiers do not accumulate in Redis.
        redis_client.expire(key, self.time_window * 2)
api/core/app/features/rate_limiting/rate_limit.py
class RateLimit:
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
return cls._instance_dict[client_id]
def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests
# must be called after max_active_requests is set
if self.disabled():
return
if hasattr(self, "initialized"):
return
self.initialized = True
self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float("-inf")
self.flush_cache(use_local_value=True)