Processing Pipeline¶
This document describes the Unifiles document processing pipeline in detail, covering the full flow from file upload to knowledge-base indexing.
Pipeline Overview¶
┌──────────────┐   ┌──────────────┐   ┌──────────────┐   ┌──────────────┐   ┌──────────────┐
│ Upload       │ → │ Validate     │ → │ Convert      │ → │ OCR          │ → │ Post-process │
└──────────────┘   └──────────────┘   └──────────────┘   └──────────────┘   └──────────────┘
                                                                                    │
                                                                                    ▼
┌──────────────┐   ┌──────────────┐   ┌──────────────┐   ┌──────────────┐   ┌──────────────┐
│ Search       │ ← │ Index        │ ← │ Embed        │ ← │ Chunk        │ ← │ Markdown     │
└──────────────┘   └──────────────┘   └──────────────┘   └──────────────┘   └──────────────┘
Stage Details¶
Stage 1: File Upload¶
# Upload flow
async def upload_file(file: UploadFile, user_id: str) -> File:
    # 1. Generate a file ID
    file_id = generate_uuid()

    # 2. Compute the content hash (used for deduplication)
    content = await file.read()
    content_hash = hashlib.sha256(content).hexdigest()

    # 3. Check whether an identical file already exists
    existing = await find_file_by_hash(user_id, content_hash)
    if existing:
        return existing  # Return the existing file directly

    # 4. Store the raw file in MinIO
    storage_key = f"raw/{user_id}/{file_id}/{file.filename}"
    await storage.upload(
        bucket="unifiles-raw",
        key=storage_key,
        data=content
    )

    # 5. Save the metadata to PostgreSQL
    file_record = await save_file_record(
        id=file_id,
        user_id=user_id,
        filename=file.filename,
        storage_key=storage_key,
        content_hash=content_hash,
        size=len(content),
        status="uploaded"
    )
    return file_record
Stage 2: File Validation¶
The validation stage checks that the uploaded file is legitimate:
class FileValidator:
    ALLOWED_TYPES = {
        "application/pdf": ".pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
        "text/plain": ".txt",
        "text/markdown": ".md",
        "text/html": ".html",
        "image/png": ".png",
        "image/jpeg": ".jpg",
    }

    async def validate(self, file: UploadFile) -> ValidationResult:
        errors = []

        # 1. Check the declared MIME type
        if file.content_type not in self.ALLOWED_TYPES:
            errors.append(f"Unsupported file type: {file.content_type}")

        # 2. Check the file size
        content = await file.read()
        if len(content) > MAX_FILE_SIZE:
            errors.append(f"File too large: {len(content)} bytes")

        # 3. Check the actual content (magic bytes)
        detected_type = magic.from_buffer(content, mime=True)
        if detected_type != file.content_type:
            errors.append(f"Content type mismatch: declared {file.content_type}, detected {detected_type}")

        # 4. Check whether the file is corrupted
        if not await self.is_valid_file(content, file.content_type):
            errors.append("File appears to be corrupted")

        return ValidationResult(
            valid=len(errors) == 0,
            errors=errors
        )
Stage 3: Format Conversion¶
Convert the various input formats into a common, processable intermediate representation:
class FormatConverter:
    """Format converter"""

    async def convert(self, file_path: str, mime_type: str) -> ConversionResult:
        if mime_type == "application/pdf":
            return await self.process_pdf(file_path)
        elif mime_type.startswith("application/vnd.openxmlformats"):
            return await self.process_office(file_path)
        elif mime_type.startswith("image/"):
            return await self.process_image(file_path)
        elif mime_type.startswith("text/"):
            return await self.process_text(file_path)
        else:
            raise UnsupportedFormatError(mime_type)

    async def process_pdf(self, file_path: str) -> ConversionResult:
        """Process a PDF file"""
        pages = []

        with fitz.open(file_path) as doc:
            for page_num, page in enumerate(doc):
                # Extract text
                text = page.get_text("text")

                # Extract tables
                tables = page.find_tables()

                # Very little text usually means a scanned page that needs OCR
                if len(text.strip()) < 100:
                    # Render the page as an image
                    pix = page.get_pixmap(dpi=300)
                    image_path = f"/tmp/page_{page_num}.png"
                    pix.save(image_path)

                    pages.append(PageResult(
                        page_num=page_num,
                        text=text,
                        tables=tables,
                        needs_ocr=True,
                        image_path=image_path
                    ))
                else:
                    pages.append(PageResult(
                        page_num=page_num,
                        text=text,
                        tables=tables,
                        needs_ocr=False
                    ))

        return ConversionResult(pages=pages)
Stage 4: OCR Recognition¶
Run text recognition on the pages flagged for OCR:
class OCRProcessor:
    """OCR processor"""

    def __init__(self, provider: str = "internal"):
        self.provider = self._create_provider(provider)

    async def process(self, pages: list[PageResult]) -> list[PageResult]:
        """Process the pages that need OCR"""
        tasks = []
        for page in pages:
            if page.needs_ocr:
                tasks.append(self._ocr_page(page))
            else:
                tasks.append(asyncio.sleep(0, result=page))
        return await asyncio.gather(*tasks)

    async def _ocr_page(self, page: PageResult) -> PageResult:
        """Run OCR on a single page"""
        # Call the OCR provider
        result = await self.provider.recognize(page.image_path)

        # Merge the OCR result back into the page
        page.text = result.text
        page.ocr_confidence = result.confidence
        page.ocr_boxes = result.boxes  # Text position information
        return page
Supported OCR providers:
| Provider | Notes |
|---|---|
| Built-in (Tesseract) | Free, offline, baseline accuracy |
| Alibaba Cloud OCR | High accuracy, strong Chinese support |
| Azure Document Intelligence | Structured extraction |
| Custom | Extend via the provider interface (see the sketch below) |
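As an illustration of the custom extension point, here is a minimal sketch of what a provider implementation might look like. The OCRProvider base class, the OCRResult fields, and the recognize() signature are assumptions inferred from how OCRProcessor uses self.provider above, not a documented API; httpx is likewise just an example HTTP client:
from dataclasses import dataclass, field

import httpx  # example async HTTP client; any async client works


@dataclass
class OCRResult:
    """Assumed result shape, mirroring how OCRProcessor reads it."""
    text: str
    confidence: float
    boxes: list = field(default_factory=list)  # text position information


class OCRProvider:
    """Assumed provider interface expected by OCRProcessor."""

    async def recognize(self, image_path: str) -> OCRResult:
        raise NotImplementedError


class HttpOCRProvider(OCRProvider):
    """Hypothetical provider that forwards the page image to an external OCR service."""

    def __init__(self, endpoint: str):
        self.endpoint = endpoint

    async def recognize(self, image_path: str) -> OCRResult:
        async with httpx.AsyncClient() as client:
            with open(image_path, "rb") as f:
                response = await client.post(self.endpoint, files={"file": f})
        data = response.json()
        return OCRResult(
            text=data.get("text", ""),
            confidence=data.get("confidence", 0.0),
            boxes=data.get("boxes", []),
        )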
Stage 5: Post-processing¶
Convert the OCR output into structured Markdown:
class PostProcessor:
    """Post-processor: generates Markdown"""

    async def process(self, pages: list[PageResult]) -> MarkdownResult:
        markdown_parts = []

        for page in pages:
            # 1. Extract headings
            headings = self._extract_headings(page)

            # 2. Extract paragraphs
            paragraphs = self._extract_paragraphs(page)

            # 3. Format tables
            tables = self._format_tables(page.tables)

            # 4. Extract image references
            images = self._extract_images(page)

            # 5. Assemble the page Markdown
            page_md = self._assemble_markdown(
                headings, paragraphs, tables, images
            )
            markdown_parts.append(page_md)

        # Merge all pages
        full_markdown = "\n\n".join(markdown_parts)

        # Extract metadata
        metadata = self._extract_metadata(pages)

        return MarkdownResult(
            content=full_markdown,
            metadata=metadata,
            page_count=len(pages),
            word_count=len(full_markdown.split())
        )

    def _format_tables(self, tables: list) -> str:
        """Convert tables to Markdown format"""
        result = []
        for table in tables:
            # Header row
            header = "| " + " | ".join(table[0]) + " |"
            separator = "| " + " | ".join(["---"] * len(table[0])) + " |"

            # Body rows
            body = []
            for row in table[1:]:
                body.append("| " + " | ".join(row) + " |")

            result.append("\n".join([header, separator] + body))
        return "\n\n".join(result)
Stage 6: Text Chunking¶
Split the Markdown content into chunks suitable for retrieval:
class TextChunker:
    """Text chunker"""

    def __init__(
        self,
        strategy: str = "semantic",
        chunk_size: int = 512,
        chunk_overlap: int = 50
    ):
        self.strategy = strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def chunk(self, text: str) -> list[Chunk]:
        if self.strategy == "fixed":
            return self._fixed_chunk(text)
        elif self.strategy == "semantic":
            return self._semantic_chunk(text)
        elif self.strategy == "paragraph":
            return self._paragraph_chunk(text)
        else:
            raise ValueError(f"Unknown strategy: {self.strategy}")

    def _semantic_chunk(self, text: str) -> list[Chunk]:
        """Semantic chunking: split at semantic boundaries"""
        chunks = []

        # Split into paragraphs first
        paragraphs = text.split("\n\n")

        current_chunk = []
        current_size = 0

        for para in paragraphs:
            para_tokens = self._count_tokens(para)

            if current_size + para_tokens > self.chunk_size:
                # Flush the current chunk
                if current_chunk:
                    chunks.append(Chunk(
                        content="\n\n".join(current_chunk),
                        token_count=current_size
                    ))

                # Start a new chunk (including the overlap)
                overlap_paras = self._get_overlap(current_chunk)
                current_chunk = overlap_paras + [para]
                current_size = sum(self._count_tokens(p) for p in current_chunk)
            else:
                current_chunk.append(para)
                current_size += para_tokens

        # Flush the last chunk
        if current_chunk:
            chunks.append(Chunk(
                content="\n\n".join(current_chunk),
                token_count=current_size
            ))

        return chunks
Stage 7: Vector Embeddings¶
Generate a vector embedding for each chunk:
class EmbeddingService:
    """Embedding service"""

    def __init__(
        self,
        provider: str = "openai",
        model: str = "text-embedding-3-small"
    ):
        self.provider = provider
        self.model = model
        self.client = self._create_client()

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings in batches"""
        # Process in batches (OpenAI request limits)
        batch_size = 100
        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = await self.client.embeddings.create(
                model=self.model,
                input=batch
            )
            embeddings = [e.embedding for e in response.data]
            all_embeddings.extend(embeddings)

        return all_embeddings

    async def embed_single(self, text: str) -> list[float]:
        """Generate a single embedding"""
        embeddings = await self.embed([text])
        return embeddings[0]
Stage 8: Vector Indexing¶
Store the chunks and their embeddings in pgvector:
class VectorIndexer:
    """Vector indexer"""

    async def index(
        self,
        document_id: str,
        kb_id: str,
        chunks: list[Chunk],
        embeddings: list[list[float]]
    ):
        """Index document chunks"""
        async with self.pool.acquire() as conn:
            # Bulk insert (assumes the pgvector codec is registered on the
            # connection so Python lists can be bound to the vector column)
            await conn.executemany(
                """
                INSERT INTO chunks (
                    id, document_id, knowledge_base_id,
                    content, token_count, chunk_index,
                    embedding, metadata
                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                """,
                [
                    (
                        str(uuid.uuid4()),
                        document_id,
                        kb_id,
                        chunk.content,
                        chunk.token_count,
                        i,
                        embedding,
                        json.dumps(chunk.metadata or {})
                    )
                    for i, (chunk, embedding) in enumerate(zip(chunks, embeddings))
                ]
            )

            # Update the document statistics
            await conn.execute(
                """
                UPDATE documents
                SET chunk_count = $1, status = 'indexed', indexed_at = NOW()
                WHERE id = $2
                """,
                len(chunks),
                document_id
            )
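Similarity search over the embedding column is typically backed by an approximate index created once per table. A minimal sketch, assuming pgvector 0.5+ with HNSW support; the index name and parameters are illustrative, not part of the documented schema:
# One-off DDL sketch (assumes pgvector >= 0.5 with HNSW support).
# vector_cosine_ops matches the cosine operator (<=>) used by the search stage.
async def ensure_chunk_embedding_index(conn):
    await conn.execute(
        """
        CREATE INDEX IF NOT EXISTS chunks_embedding_hnsw_idx
        ON chunks USING hnsw (embedding vector_cosine_ops)
        WITH (m = 16, ef_construction = 64)
        """
    )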
Stage 9: Vector Search¶
Run a semantic similarity search:
class VectorSearch:
    """Vector search"""

    async def search(
        self,
        kb_id: str,
        query: str,
        top_k: int = 10,
        filters: dict | None = None
    ) -> list[SearchResult]:
        """Run a vector similarity search"""
        # 1. Embed the query
        query_embedding = await self.embedding_service.embed_single(query)

        # 2. Build the query
        sql = """
            SELECT
                c.id,
                c.content,
                c.metadata,
                c.document_id,
                d.title as document_title,
                1 - (c.embedding <=> $1::vector) as similarity
            FROM chunks c
            JOIN documents d ON c.document_id = d.id
            WHERE c.knowledge_base_id = $2
        """
        params = [query_embedding, kb_id]

        # 3. Apply filters
        if filters:
            if "metadata" in filters:
                sql += " AND c.metadata @> $3"
                params.append(json.dumps(filters["metadata"]))

        # 4. Order and limit
        sql += " ORDER BY c.embedding <=> $1::vector LIMIT $" + str(len(params) + 1)
        params.append(top_k)

        # 5. Execute the query
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)

        # 6. Build the results
        results = [
            SearchResult(
                chunk_id=row["id"],
                content=row["content"],
                score=row["similarity"],
                document_id=row["document_id"],
                document_title=row["document_title"],
                metadata=json.loads(row["metadata"])
            )
            for row in rows
        ]
        return results
Processing State Machine¶
     ┌──────────────────────────────────────────┐
     │                                          │
     ▼                                          │
┌──────────┐                                    │
│ uploaded │ ──────────────────────┐            │
└────┬─────┘                       │            │
     │                             │            │
     ▼                             ▼            │
┌──────────┐                 ┌──────────┐       │
│processing│ ───────────────→│  error   │ ──────┘
└────┬─────┘                 └──────────┘  (retry)
     │
     ▼
┌──────────┐
│extracted │
└────┬─────┘
     │
     ▼
┌──────────┐
│ indexed  │
└──────────┘
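A minimal sketch of how these transitions might be enforced in code. The transition map mirrors the diagram above; the helper name and exception type are illustrative, not part of the documented API:
# Allowed status transitions, mirroring the state machine diagram above.
ALLOWED_TRANSITIONS = {
    "uploaded":   {"processing", "error"},
    "processing": {"extracted", "error"},
    "extracted":  {"indexed"},
    "error":      {"uploaded"},  # retry re-enters the pipeline
    "indexed":    set(),
}


def assert_transition(current: str, target: str) -> None:
    """Raise if a status change is not allowed by the state machine."""
    if target not in ALLOWED_TRANSITIONS.get(current, set()):
        raise ValueError(f"Invalid status transition: {current} -> {target}")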
Error Handling¶
class PipelineProcessor:
    """Pipeline processor"""

    async def process(self, file_id: str) -> ProcessingResult:
        try:
            # Log the start of processing
            await self.log_stage(file_id, "started")

            # Run each stage
            for stage in self.stages:
                try:
                    await self.log_stage(file_id, stage.name, "started")
                    result = await stage.execute(file_id)
                    await self.log_stage(file_id, stage.name, "completed")
                except StageError as e:
                    await self.log_stage(file_id, stage.name, "failed", str(e))
                    if stage.required:
                        raise
                    # An optional stage failed; continue with the next stage

            await self.update_status(file_id, "completed")
            return ProcessingResult(success=True)

        except Exception as e:
            await self.update_status(file_id, "error", str(e))
            return ProcessingResult(success=False, error=str(e))
Performance Optimization¶
Parallel Processing¶
# Run OCR on multiple pages in parallel
async def parallel_ocr(pages: list[PageResult]) -> list[PageResult]:
    semaphore = asyncio.Semaphore(4)  # Limit concurrency

    async def process_with_limit(page):
        async with semaphore:
            return await ocr_processor.process_page(page)

    return await asyncio.gather(*[
        process_with_limit(page)
        for page in pages if page.needs_ocr
    ])
Batch Embedding¶
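Embedding chunks in batches avoids one API round trip per chunk. A minimal sketch built on the EmbeddingService above; the helper name and batch size are illustrative:
# Embed a document's chunks in batches instead of one request per chunk.
# EmbeddingService.embed() already splits each call into provider-sized batches.
async def embed_in_batches(
    service: EmbeddingService,
    chunks: list[Chunk],
    batch_size: int = 100,  # illustrative
) -> list[list[float]]:
    embeddings: list[list[float]] = []
    for i in range(0, len(chunks), batch_size):
        texts = [chunk.content for chunk in chunks[i:i + batch_size]]
        embeddings.extend(await service.embed(texts))
    return embeddings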
Stream Processing¶
# Stream-process a large file
async def stream_process_large_file(file_path: str):
    async for page in pdf_streamer.stream(file_path):
        # Process a single page
        result = await process_page(page)
        yield result
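One way the stream might be consumed downstream; the per-page handler here is hypothetical:
# Illustrative consumer of the streaming generator above.
async def index_large_file(file_path: str):
    async for page_result in stream_process_large_file(file_path):
        await handle_page_result(page_result)  # hypothetical downstream handler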