Celery案例应用——Dify源码分析
Dify如何使用Celery
项目骨架
Celery配置对象:CeleryConfig
配置对象包含:任务结果存储后端、消息中间件(broker)地址、是否使用Redis哨兵实现高可用、哨兵模式下Redis主节点的名称、哨兵Socket超时时间,以及继承自DatabaseConfig的数据库连接信息等。
class CeleryConfig(DatabaseConfig):
    # Celery-related settings: result backend selection, broker URL and the
    # Redis Sentinel options. Database settings come from DatabaseConfig.

    CELERY_BACKEND: str = Field(
        default="database",
        description="Backend for Celery task results. Options: 'database', 'redis'.",
    )

    CELERY_BROKER_URL: Optional[str] = Field(
        default=None,
        description="URL of the message broker for Celery tasks.",
    )

    CELERY_USE_SENTINEL: Optional[bool] = Field(
        default=False,
        description="Whether to use Redis Sentinel for high availability.",
    )

    CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
        default=None,
        description="Name of the Redis Sentinel master.",
    )

    CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
        default=0.1,
        description="Timeout for Redis Sentinel socket operations in seconds.",
    )

    @computed_field
    def CELERY_RESULT_BACKEND(self) -> str | None:
        # "database" stores results through the SQLAlchemy URI (prefixed with
        # "db+"); any other value reuses the broker URL as the result backend.
        if self.CELERY_BACKEND == "database":
            return "db+{}".format(self.SQLALCHEMY_DATABASE_URI)
        return self.CELERY_BROKER_URL

    @property
    def BROKER_USE_SSL(self) -> bool:
        # SSL is inferred purely from the "rediss://" scheme of the broker URL;
        # a missing/empty URL means no SSL.
        broker_url = self.CELERY_BROKER_URL
        if not broker_url:
            return False
        return broker_url.startswith("rediss://")
数据库连接配置DatabaseConfig
class DatabaseConfig(BaseSettings):
    # Connection settings for the SQL database, together with the SQLAlchemy
    # URI and engine options that are computed from them.

    DB_HOST: str = Field(
        default="localhost",
        description="Hostname or IP address of the database server.",
    )

    DB_PORT: PositiveInt = Field(
        default=5432,
        description="Port number for database connection.",
    )

    DB_USERNAME: str = Field(
        default="postgres",
        description="Username for database authentication.",
    )

    DB_PASSWORD: str = Field(
        default="",
        description="Password for database authentication.",
    )

    DB_DATABASE: str = Field(
        default="dify",
        description="Name of the database to connect to.",
    )

    DB_CHARSET: str = Field(
        default="",
        description="Character set for database connection.",
    )

    DB_EXTRAS: str = Field(
        default="",
        description="Additional database connection parameters. Example: 'keepalives_idle=60&keepalives=1'",
    )

    SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
        default="postgresql",
        description="Database URI scheme for SQLAlchemy connection.",
    )

    @computed_field
    def SQLALCHEMY_DATABASE_URI(self) -> str:
        # Builds "<scheme>://<user>:<password>@<host>:<port>/<database>[?<extras>]".
        # User name and password are URL-quoted; the charset, when set, is
        # appended to the extras as "client_encoding".
        if self.DB_CHARSET:
            db_extras = f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
        else:
            db_extras = self.DB_EXTRAS
        # Remove the dangling "&" that appears when DB_EXTRAS is empty.
        db_extras = db_extras.strip("&")
        if db_extras:
            db_extras = f"?{db_extras}"
        else:
            db_extras = ""

        return (
            f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
            f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
            f"{db_extras}"
        )

    SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
        default=30,
        description="Maximum number of database connections in the pool.",
    )

    SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
        default=10,
        description="Maximum number of connections that can be created beyond the pool_size.",
    )

    SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
        default=3600,
        description="Number of seconds after which a connection is automatically recycled.",
    )

    SQLALCHEMY_POOL_PRE_PING: bool = Field(
        default=False,
        description="If True, enables connection pool pre-ping feature to check connections.",
    )

    SQLALCHEMY_ECHO: bool | str = Field(
        default=False,
        description="If True, SQLAlchemy will log all SQL statements.",
    )

    RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
        # Evaluated once, when the class body is executed.
        default=os.cpu_count(),
        description="Number of processes for the retrieval service, default to CPU cores.",
    )

    @computed_field
    def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
        # Pool tuning comes from the fields above; the connect option pins the
        # session time zone to UTC.
        pool_settings = {
            "pool_size": self.SQLALCHEMY_POOL_SIZE,
            "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
            "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
            "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
        }
        return {**pool_settings, "connect_args": {"options": "-c timezone=UTC"}}
中间件配置
CeleryConfig会作为Dify应用配置的中间件配置类(MiddlewareConfig)的父类之一。
# Aggregates the middleware-related configuration classes (Celery, database,
# keyword store, Redis, storage providers and vector-database providers) into a
# single settings class through multiple inheritance. The class body defines no
# fields of its own; everything is inherited from the listed bases.
# NOTE: the order of the base classes determines the method resolution order.
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
DatabaseConfig,
KeywordStoreConfig,
RedisConfig,
# configs of storage and storage providers
StorageConfig,
AliyunOSSStorageConfig,
AzureBlobStorageConfig,
BaiduOBSStorageConfig,
GoogleCloudStorageConfig,
HuaweiCloudOBSStorageConfig,
OCIStorageConfig,
OpenDALStorageConfig,
S3StorageConfig,
SupabaseStorageConfig,
TencentCloudCOSStorageConfig,
VolcengineTOSStorageConfig,
# configs of vdb and vdb providers
VectorStoreConfig,
AnalyticdbConfig,
ChromaConfig,
MilvusConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,
PGVectorConfig,
PGVectoRSConfig,
QdrantConfig,
RelytConfig,
TencentVectorDBConfig,
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
CouchbaseConfig,
InternalTestConfig,
VikingDBConfig,
UpstashConfig,
TidbOnQdrantConfig,
LindormConfig,
OceanBaseVectorConfig,
BaiduVectorDBConfig,
OpenGaussConfig,
TableStoreConfig,
):
pass
初始化celery app: api/extensions/ext_celery.py
自定义Celery任务基类FlaskTask,使任务运行时处于Flask应用上下文中。
如果使用redis哨兵模式,则配置broker_transport_options
定义并配置需要导入的定时任务模块列表(imports)
配置定时任务
def init_app(app: DifyApp) -> Celery:
    """
    Create the Celery application for ``app`` and register it as a Flask extension.

    The function:

    * defines a task base class that runs every task inside ``app``'s
      application context,
    * passes Sentinel transport options to the broker when
      ``CELERY_USE_SENTINEL`` is enabled,
    * applies logging, time zone, optional broker SSL and optional log-file
      settings,
    * makes the instance the default Celery app and stores it in
      ``app.extensions["celery"]``,
    * registers the periodic task modules and their beat schedule.

    :param app: the Dify Flask application
    :return: the configured Celery application
    """

    class FlaskTask(Task):
        """Task base class whose invocation is wrapped in the Flask app context."""

        def __call__(self, *args: object, **kwargs: object) -> object:
            with app.app_context():
                return self.run(*args, **kwargs)

    # Sentinel needs the master name and a socket timeout; every other
    # deployment passes no extra transport options.
    transport_options = (
        {
            "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
            "sentinel_kwargs": {
                "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
            },
        }
        if dify_config.CELERY_USE_SENTINEL
        else {}
    )

    celery_app = Celery(
        app.name,
        task_cls=FlaskTask,
        broker=dify_config.CELERY_BROKER_URL,
        backend=dify_config.CELERY_BACKEND,
        task_ignore_result=True,
    )

    settings = {
        "result_backend": dify_config.CELERY_RESULT_BACKEND,
        "broker_transport_options": transport_options,
        "broker_connection_retry_on_startup": True,
        "worker_log_format": dify_config.LOG_FORMAT,
        "worker_task_log_format": dify_config.LOG_FORMAT,
        "worker_hijack_root_logger": False,
        "timezone": pytz.timezone(dify_config.LOG_TZ or "UTC"),
    }

    if dify_config.BROKER_USE_SSL:
        # SSL options for the broker connection; every entry is left as None.
        settings["broker_use_ssl"] = {
            "ssl_cert_reqs": None,
            "ssl_ca_certs": None,
            "ssl_certfile": None,
            "ssl_keyfile": None,
        }

    if dify_config.LOG_FILE:
        settings["worker_logfile"] = dify_config.LOG_FILE

    celery_app.conf.update(**settings)

    celery_app.set_default()
    app.extensions["celery"] = celery_app

    # Modules that contain the periodic tasks referenced by the schedule below.
    scheduled_modules = [
        "schedule.clean_embedding_cache_task",
        "schedule.clean_unused_datasets_task",
        "schedule.create_tidb_serverless_task",
        "schedule.update_tidb_serverless_status_task",
        "schedule.clean_messages",
        "schedule.mail_clean_document_notify_task",
    ]

    # The clean-up jobs share one configurable interval, expressed in days.
    cleanup_interval = timedelta(days=dify_config.CELERY_BEAT_SCHEDULER_TIME)
    periodic_tasks = {
        "clean_embedding_cache_task": {
            "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
            "schedule": cleanup_interval,
        },
        "clean_unused_datasets_task": {
            "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
            "schedule": cleanup_interval,
        },
        # at minute 0 of every hour
        "create_tidb_serverless_task": {
            "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
            "schedule": crontab(minute="0", hour="*"),
        },
        "update_tidb_serverless_status_task": {
            "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
            "schedule": timedelta(minutes=10),
        },
        "clean_messages": {
            "task": "schedule.clean_messages.clean_messages",
            "schedule": cleanup_interval,
        },
        # every Monday
        "mail_clean_document_notify_task": {
            "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
            "schedule": crontab(minute="0", hour="10", day_of_week="1"),
        },
    }
    celery_app.conf.update(beat_schedule=periodic_tasks, imports=scheduled_modules)

    return celery_app
定义celery任务: api/tasks
队列规划
dataset
处理数据集相关的程序
mail
处理邮件相关需求
ops_trace
app_deletion
处理删除app时相关数据删除需求
mail队列相关任务
发送登录时邮箱验证码
import logging
import time

import click
from celery import shared_task  # type: ignore
from flask import render_template

from extensions.ext_mail import mail


@shared_task(queue="mail")
def send_email_code_login_mail_task(language: str, to: str, code: str):
    """
    Send the e-mail login verification code asynchronously on the "mail" queue.

    Does nothing when the mail extension is not initialised. Failures while
    rendering or sending are logged and not re-raised.

    :param language: "zh-Hans" selects the Chinese template; any other value
        selects the English template
    :param to: Recipient email address
    :param code: Email code to be included in the email
    """
    if not mail.is_inited():
        return

    logging.info(click.style("Start email code login mail to {}".format(to), fg="green"))
    start_at = time.perf_counter()

    try:
        # Pick the localized template and subject, then render and send once.
        if language == "zh-Hans":
            template_name = "email_code_login_mail_template_zh-CN.html"
            subject = "邮箱验证码"
        else:
            template_name = "email_code_login_mail_template_en-US.html"
            subject = "Email Code"

        html_content = render_template(template_name, to=to, code=code)
        mail.send(to=to, subject=subject, html=html_content)

        latency = time.perf_counter() - start_at
        logging.info(
            click.style(
                "Send email code login mail to {} succeeded: latency: {}".format(to, latency), fg="green"
            )
        )
    except Exception:
        logging.exception("Send email code login mail to {} failed".format(to))
邀请成员加入工作空间
@shared_task(queue="mail")
def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str):
    """
    Send a workspace invitation e-mail asynchronously on the "mail" queue.

    Does nothing when the mail extension is not initialised. Failures while
    rendering or sending are logged and not re-raised.

    :param language: "zh-Hans" selects the Chinese template; any other value
        selects the English template
    :param to: Recipient email address
    :param token: Token placed in the activation link
    :param inviter_name: Name of the inviting member, passed to the template
    :param workspace_name: Name of the workspace, passed to the template

    Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name)
    """
    if not mail.is_inited():
        return

    logging.info(
        click.style("Start send invite member mail to {} in workspace {}".format(to, workspace_name), fg="green")
    )
    start_at = time.perf_counter()

    try:
        url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}"

        # Pick the localized template and subject, then render and send once.
        if language == "zh-Hans":
            template_name = "invite_member_mail_template_zh-CN.html"
            subject = "立即加入 Dify 工作空间"
        else:
            template_name = "invite_member_mail_template_en-US.html"
            subject = "Join Dify Workspace Now"

        html_content = render_template(
            template_name,
            to=to,
            inviter_name=inviter_name,
            workspace_name=workspace_name,
            url=url,
        )
        mail.send(to=to, subject=subject, html=html_content)

        latency = time.perf_counter() - start_at
        logging.info(
            click.style(
                "Send invite member mail to {} succeeded: latency: {}".format(to, latency), fg="green"
            )
        )
    except Exception:
        logging.exception("Send invite member mail to {} failed".format(to))
发送重置密码邮件
import logging
import time

import click
from celery import shared_task  # type: ignore
from flask import render_template

from extensions.ext_mail import mail


@shared_task(queue="mail")
def send_reset_password_mail_task(language: str, to: str, code: str):
    """
    Send the password-reset e-mail asynchronously on the "mail" queue.

    Does nothing when the mail extension is not initialised. Failures while
    rendering or sending are logged and not re-raised.

    :param language: "zh-Hans" selects the Chinese template; any other value
        selects the English template
    :param to: Recipient email address
    :param code: Reset password code
    """
    if not mail.is_inited():
        return

    logging.info(click.style("Start password reset mail to {}".format(to), fg="green"))
    start_at = time.perf_counter()

    try:
        # Pick the localized template and subject, then render and send once.
        if language == "zh-Hans":
            template_name = "reset_password_mail_template_zh-CN.html"
            subject = "设置您的 Dify 密码"
        else:
            template_name = "reset_password_mail_template_en-US.html"
            subject = "Set Your Dify Password"

        html_content = render_template(template_name, to=to, code=code)
        mail.send(to=to, subject=subject, html=html_content)

        latency = time.perf_counter() - start_at
        logging.info(
            click.style(
                "Send password reset mail to {} succeeded: latency: {}".format(to, latency), fg="green"
            )
        )
    except Exception:
        logging.exception("Send password reset mail to {} failed".format(to))
……
celery应用运行环境
与api service同样使用基于dify-api:1.1.3镜像的环境。
# Multi-stage build with three stages:
#   base       - Python runtime plus Poetry and its configuration
#   packages   - build toolchain + dependency installation into the in-project virtualenv
#   production - runtime image; receives only the virtualenv from "packages"
# base image
FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api
# Install Poetry
ENV POETRY_VERSION=2.0.1
# if you located in China, you can use aliyun mirror to speed up
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry
# The virtualenv is created inside the project directory (/app/api/.venv) so
# that a later stage can copy it by path.
ENV POETRY_CACHE_DIR=/tmp/poetry_cache
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
ENV POETRY_VIRTUALENVS_CREATE=true
ENV POETRY_REQUESTS_TIMEOUT=15
# packages stage: compilers and development headers are installed here only
FROM base AS packages
# if you located in China, you can use aliyun mirror to speed up
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
# Install Python dependencies
# Only the dependency manifests are copied before installing, so this layer
# does not depend on the application source.
COPY pyproject.toml poetry.lock ./
RUN poetry install --sync --no-cache --no-root
# production stage
# Starts again from "base", so nothing installed in "packages" is inherited.
FROM base AS production
ENV FLASK_APP=app.py
ENV EDITION=SELF_HOSTED
ENV DEPLOY_ENV=PRODUCTION
ENV CONSOLE_API_URL=http://127.0.0.1:5001
ENV CONSOLE_WEB_URL=http://127.0.0.1:3000
ENV SERVICE_API_URL=http://127.0.0.1:5001
ENV APP_WEB_URL=http://127.0.0.1:3000
EXPOSE 5001
# set timezone
ENV TZ=UTC
WORKDIR /app/api
RUN \
apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a package to improve the accuracy of guessing mime type and file extension
media-types \
# install libmagic to support the use of python-magic guess MIMETYPE
libmagic1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*
# Copy Python environment and packages
# Putting the virtualenv's bin directory first on PATH makes "python" and the
# installed console scripts resolve to the virtualenv.
ENV VIRTUAL_ENV=/app/api/.venv
COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
# Load the gpt2 encoding at build time with the tiktoken cache directory set
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
# Copy source code
COPY . /app/api/
# Copy entrypoint
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Build argument re-exported as an environment variable of the image
ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
# The same image serves every role; the entrypoint script decides what to start.
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
部署celery worker和celery beat实例:api/docker/entrypoint.sh
启动脚本定义于api/docker/entrypoint.sh中:当环境变量MODE为worker时启动celery worker进程,为beat时启动celery beat进程,否则启动API服务(DEBUG为true时使用flask run,否则使用gunicorn)。
#!/bin/bash
# Container entrypoint. Optionally runs database migrations, then starts one of:
#   MODE=worker -> Celery worker
#   MODE=beat   -> Celery beat scheduler
#   otherwise   -> the API server (flask run when DEBUG=true, gunicorn otherwise)
# Abort on the first failing command (e.g. a failed migration).
set -e
if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
echo "Running migrations"
flask upgrade-db
fi
if [[ "${MODE}" == "worker" ]]; then
# Get the number of available CPU cores
# "${VAR,,}" lower-cases the value, so "TRUE"/"True" also enable autoscaling.
if [ "${CELERY_AUTO_SCALE,,}" = "true" ]; then
# Set MAX_WORKERS to the number of available cores if not specified
AVAILABLE_CORES=$(nproc)
MAX_WORKERS=${CELERY_MAX_WORKERS:-$AVAILABLE_CORES}
MIN_WORKERS=${CELERY_MIN_WORKERS:-1}
CONCURRENCY_OPTION="--autoscale=${MAX_WORKERS},${MIN_WORKERS}"
else
# Fixed concurrency; defaults to a single worker.
CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
fi
# $CONCURRENCY_OPTION is left unquoted on purpose so that "-c N" is split into
# two arguments. exec replaces this shell with the celery process.
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
else
# Any other MODE value (including unset) starts the API server.
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
else
exec gunicorn \
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
fi
fi
docker-compose.yml中关于celery worker的配置
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.1.3
restart: always
environment:
# Use the shared environment variables.
# "<<" is a YAML merge key: it pulls in every entry of the mapping anchored as
# "shared-api-worker-env" (the anchor itself is defined outside this excerpt).
<<: *shared-api-worker-env
# Startup mode, 'worker' starts the Celery worker for processing the queue.
MODE: worker
SENTRY_DSN: ${API_SENTRY_DSN:-}
SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}
SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0}
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
# NOTE(review): the fallback key below is a fixed value visible in this file;
# set PLUGIN_DIFY_INNER_API_KEY explicitly for any non-local deployment.
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
- db
- redis
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage
networks:
- ssrf_proxy_network
- default
networks:
# create a network between sandbox, api and ssrf_proxy, and can not access outside.
ssrf_proxy_network:
driver: bridge
internal: true
milvus:
driver: bridge
opensearch-net:
driver: bridge
internal: true
volumes:
oradata:
dify_es01_data:
在应用服务中触发celery任务
如在api/services/account_service.py中,账号服务类的方法会发起异步任务。
部分节选代码如下:
from tasks.delete_account_task import delete_account_task
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
from tasks.mail_email_code_login import send_email_code_login_mail_task
from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task
class AccountService:
    """
    Account-related service operations (excerpt).

    The methods shown here issue a verification code, store it in a token and
    hand the actual e-mail delivery to a Celery task via ``.delay(...)``.
    """

    @staticmethod
    def _generate_verification_code(length: int = 6) -> str:
        """
        Return a numeric verification code of ``length`` decimal digits.

        The digits come from the ``secrets`` module rather than ``random``,
        because the code is used as an authentication factor and must not be
        predictable.
        """
        import secrets

        return "".join(str(secrets.randbelow(10)) for _ in range(length))

    @classmethod
    def send_reset_password_email(
        cls,
        account: Optional[Account] = None,
        email: Optional[str] = None,
        language: Optional[str] = "en-US",
    ):
        """
        Issue a password-reset code and queue the reset e-mail.

        :param account: account whose e-mail address is used when given
        :param email: address to use when ``account`` is not given
        :param language: language passed on to the mail task
        :return: the token produced by ``TokenManager.generate_token``
        :raises ValueError: when neither an account nor an e-mail is provided
        :raises PasswordResetRateLimitExceededError: when the address is
            currently rate limited
        """
        account_email = account.email if account else email
        if account_email is None:
            raise ValueError("Email must be provided.")

        if cls.reset_password_rate_limiter.is_rate_limited(account_email):
            # Imported lazily, as in the original code.
            from controllers.console.auth.error import PasswordResetRateLimitExceededError

            raise PasswordResetRateLimitExceededError()

        code = cls._generate_verification_code()
        token = TokenManager.generate_token(
            account=account, email=email, token_type="reset_password", additional_data={"code": code}
        )
        # Delivery happens asynchronously in the Celery task.
        send_reset_password_mail_task.delay(
            language=language,
            to=account_email,
            code=code,
        )
        cls.reset_password_rate_limiter.increment_rate_limit(account_email)
        return token

    @classmethod
    def send_email_code_login_email(
        cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
    ):
        """
        Issue an e-mail login code and queue the login-code e-mail.

        :param account: account whose e-mail address is used when given
        :param email: address to use when ``account`` is not given
        :param language: language passed on to the mail task
        :return: the token produced by ``TokenManager.generate_token``
        :raises ValueError: when neither an account nor an e-mail is provided
        :raises EmailCodeLoginRateLimitExceededError: when the address is
            currently rate limited
        """
        email = account.email if account else email
        if email is None:
            raise ValueError("Email must be provided.")

        if cls.email_code_login_rate_limiter.is_rate_limited(email):
            # Imported lazily, as in the original code.
            from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError

            raise EmailCodeLoginRateLimitExceededError()

        code = cls._generate_verification_code()
        token = TokenManager.generate_token(
            account=account, email=email, token_type="email_code_login", additional_data={"code": code}
        )
        # ``email`` already holds the resolved recipient (account e-mail first).
        send_email_code_login_mail_task.delay(
            language=language,
            to=email,
            code=code,
        )
        cls.email_code_login_rate_limiter.increment_rate_limit(email)
        return token