Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions frontend/public/config/error-code.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
{
"0": "成功",
"rag.0001": "RAG 配置错误",
"rag.0002": "知识库不存在",
"rag.0003": "知识库名称已存在",
"cleaning.0001": "清洗任务不存在",
"cleaning.0002": "清洗任务名称重复",
"cleaning.0003": "清洗模板不存在",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ export default function CreateKnowledgeBase({
setOpen(false);
onUpdate();
} catch (error) {
message.error(t("knowledgeBase.create.messages.operationFailed") + error.data.message);
// 错误已由全局拦截器统一处理,此处不再重复提示
console.error("知识库操作失败:", error);
}
};

Expand Down
98 changes: 52 additions & 46 deletions frontend/src/pages/KnowledgeBase/components/KnowledgeGraphView.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import React, {useMemo, useRef, useEffect} from "react";
import React, { useMemo, useRef, useEffect, useState } from "react";
import ForceGraph2D from "react-force-graph-2d";
import type {KnowledgeGraphEdge, KnowledgeGraphNode} from "../knowledge-base.model";
import type { KnowledgeGraphEdge, KnowledgeGraphNode } from "../knowledge-base.model";

export type GraphEntitySelection =
| { type: "node"; data: KnowledgeGraphNode }
Expand All @@ -16,19 +16,45 @@ interface KnowledgeGraphViewProps {
const COLOR_PALETTE = ["#60a5fa", "#f87171", "#fbbf24", "#34d399", "#a78bfa", "#fb7185", "#22d3ee", "#818cf8", "#fb923c", "#4ade80"];

const KnowledgeGraphView: React.FC<KnowledgeGraphViewProps> = ({
nodes,
edges,
height = 520,
onSelectEntity,
}) => {
nodes,
edges,
height = 520,
onSelectEntity,
}) => {
const graphRef = useRef<any>();
// 新增:用于监听尺寸的容器引用
const containerRef = useRef<HTMLDivElement>(null);
// 新增:保存当前实际宽高的状态
const [dimensions, setDimensions] = useState({ width: 0, height: 0 });

// --- 核心修复:监听容器大小变化 ---
useEffect(() => {
if (!containerRef.current) return;

const resizeObserver = new ResizeObserver((entries) => {
for (const entry of entries) {
const { width, height } = entry.contentRect;
setDimensions({ width, height });

// 强制通知 force-graph 组件更新内部 canvas 尺寸
if (graphRef.current) {
graphRef.current.width(width);
graphRef.current.height(height);
// 可选:如果希望尺寸变化后图谱自动居中,取消下行注释
// graphRef.current.zoomToFit(400);
}
}
});

resizeObserver.observe(containerRef.current);
return () => resizeObserver.disconnect();
}, []);

useEffect(() => {
if (graphRef.current) {
// 1. 调整力导向平衡:减小斥力让独立图块靠近,增加向心力防止飘散
graphRef.current.d3Force("charge").strength(-250); // 斥力适中
graphRef.current.d3Force("link").distance(120); // 边长适中
graphRef.current.d3Force("center").strength(0.8); // 增强向心力,让孤立集群往中间靠
graphRef.current.d3Force("charge").strength(-250);
graphRef.current.d3Force("link").distance(120);
graphRef.current.d3Force("center").strength(0.8);
}
}, [nodes]);

Expand All @@ -43,7 +69,7 @@ const KnowledgeGraphView: React.FC<KnowledgeGraphViewProps> = ({
nodes: nodes.map((node) => ({
...node,
color: typeColorMap.get(node.properties?.entity_type || (node.labels && node.labels[0]) || 'default'),
val: 8 // 统一基础大小,使视觉更整洁
val: 8
})),
links: edges.map((edge) => ({
...edge,
Expand All @@ -53,9 +79,15 @@ const KnowledgeGraphView: React.FC<KnowledgeGraphViewProps> = ({
}), [nodes, edges, typeColorMap]);

return (
<div style={{width: "100%", height, background: "#01030f"}}>
<div
ref={containerRef}
style={{ width: "100%", height, background: "#01030f", overflow: "hidden" }}
>
<ForceGraph2D
ref={graphRef}
// 传入动态计算的宽高
width={dimensions.width}
height={dimensions.height}
graphData={graphData}
backgroundColor="#01030f"

Expand All @@ -67,8 +99,8 @@ const KnowledgeGraphView: React.FC<KnowledgeGraphViewProps> = ({
linkCurvature={0.1}

// --- 节点绘制 ---
nodeCanvasObject={(node: any, ctx, globalScale) => {
const {x, y, val: radius, color, id} = node;
nodeCanvasObject={(node: never, ctx, globalScale) => {
const { x, y, val: radius, color, id } = node;
if (!Number.isFinite(x) || !Number.isFinite(y)) return;

ctx.save();
Expand All @@ -79,7 +111,6 @@ const KnowledgeGraphView: React.FC<KnowledgeGraphViewProps> = ({
ctx.shadowColor = color;
ctx.fill();

// 节点名称
if (globalScale > 0.4) {
const fontSize = 12 / globalScale;
ctx.font = `${fontSize}px Sans-Serif`;
Expand All @@ -95,82 +126,57 @@ const KnowledgeGraphView: React.FC<KnowledgeGraphViewProps> = ({
linkPointerAreaPaint={(link: any, color, ctx, globalScale) => {
const label = link.keywords;
if (!label || globalScale < 1.1) return;

const start = link.source;
const end = link.target;
if (typeof start !== 'object' || typeof end !== 'object') return;

const fontSize = 9 / globalScale;
const textPos = {x: start.x + (end.x - start.x) * 0.5, y: start.y + (end.y - start.y) * 0.5};
const textPos = { x: start.x + (end.x - start.x) * 0.5, y: start.y + (end.y - start.y) * 0.5 };
const angle = Math.atan2(end.y - start.y, end.x - start.x);
const bRotate = angle > Math.PI / 2 || angle < -Math.PI / 2;

ctx.save();
ctx.translate(textPos.x, textPos.y);
ctx.rotate(bRotate ? angle + Math.PI : angle);

ctx.font = `${fontSize}px Sans-Serif`;
const textWidth = ctx.measureText(label).width;

// 绘制一个与文字大小相同的透明矩形,颜色必须使用参数中的 'color'
// 这是 react-force-graph 识别点击对象的关键(Color-picking 技术)
ctx.fillStyle = color;
ctx.fillRect(-textWidth / 2 - 2, -fontSize / 2 - 2, textWidth + 4, fontSize + 4);
ctx.restore();
}}

// --- 边文字绘制:优化大小、位置和翻转逻辑 ---
linkCanvasObjectMode={() => 'after'}
linkCanvasObject={(link: any, ctx, globalScale) => {
const MAX_DISPLAY_SCALE = 1.1;
if (globalScale < MAX_DISPLAY_SCALE) return;

const label = link.keywords;
const start = link.source;
const end = link.target;
if (typeof start !== 'object' || typeof end !== 'object') return;

// 边文字比节点文字小一点点(节点12,边11)
const fontSize = 11 / globalScale;

const textPos = {
x: start.x + (end.x - start.x) * 0.5,
y: start.y + (end.y - start.y) * 0.5
};

const textPos = { x: start.x + (end.x - start.x) * 0.5, y: start.y + (end.y - start.y) * 0.5 };
let angle = Math.atan2(end.y - start.y, end.x - start.x);

// --- 核心修复:防止文字倒挂 ---
// 如果角度在 90度 到 270度 之间,旋转180度让文字保持正向
const bRotate = angle > Math.PI / 2 || angle < -Math.PI / 2;

ctx.save();
ctx.translate(textPos.x, textPos.y);
ctx.rotate(bRotate ? angle + Math.PI : angle);

ctx.font = `${fontSize}px Sans-Serif`;
const textWidth = ctx.measureText(label).width;

// 绘制极小的背景遮罩,紧贴文字
ctx.fillStyle = 'rgba(1, 3, 15, 0.7)';
ctx.fillRect(-textWidth / 2 - 1, -fontSize / 2, textWidth + 2, fontSize);

ctx.fillStyle = '#94e2d5';
ctx.textAlign = 'center';
ctx.textBaseline = 'middle';
// y轴偏移设为0,使其紧贴线条中心
ctx.fillText(label, 0, 0);
ctx.restore();
}}

onNodeClick={(node: any) => onSelectEntity?.({type: "node", data: node})}
onNodeClick={(node: any) => onSelectEntity?.({ type: "node", data: node })}
onLinkClick={(link: any) => {
const originalData = link.__originalEdge || link;
onSelectEntity?.({type: "edge", data: originalData});
onSelectEntity?.({ type: "edge", data: originalData });
}}
onBackgroundClick={() => onSelectEntity?.(null)}
cooldownTicks={120}
d3VelocityDecay={0.4} // 增加阻力,使布局更快稳定
d3VelocityDecay={0.4}
/>
</div>
);
Expand Down
53 changes: 28 additions & 25 deletions runtime/datamate-python/app/module/rag/service/file_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def _process_graph_files(
logger.exception("初始化知识图谱失败: %s", e)
for rag_file in files:
file_repo = RagFileRepository(db)
await self._mark_failed(db, file_repo, str(rag_file.id), f"知识图谱初始化失败: {str(e)}") # type: ignore
await self._mark_failed(db, file_repo, str(rag_file.id), f"知识图谱初始化失败: {str(e)}")

@staticmethod
async def _initialize_graph_rag(db: AsyncSession, knowledge_base: KnowledgeBase):
Expand All @@ -145,32 +145,35 @@ async def _process_single_graph_file(
file_repo = RagFileRepository(db)

try:
await self._update_status(db, file_repo, str(rag_file.id), FileStatus.PROCESSING, 10) # type: ignore
await self._update_status(db, file_repo, str(rag_file.id), FileStatus.PROCESSING, 10)
await db.commit()

file_path = get_file_path(rag_file)
if not file_path or not Path(file_path).exists():
await self._mark_failed(db, file_repo, str(rag_file.id), "文件不存在") # type: ignore
await self._mark_failed(db, file_repo, str(rag_file.id), "文件不存在")
return

documents = load_documents(file_path)
if not documents:
await self._mark_failed(db, file_repo, str(rag_file.id), "文件解析失败,未生成文档") # type: ignore
await self._mark_failed(db, file_repo, str(rag_file.id), "文件解析失败,未生成文档")
return

await self._update_progress(db, file_repo, str(rag_file.id), 30) # type: ignore
await self._update_progress(db, file_repo, str(rag_file.id), 30)
await db.commit()

for idx, doc in enumerate(documents):
logger.info("插入文档到知识图谱: %s, 进度: %d/%d", str(rag_file.file_name), idx + 1, len(documents)) # type: ignore
await rag_instance.ainsert(input=doc.page_content, file_paths=[file_path])
all_content = "\n\n".join(doc.page_content for doc in documents)
doc_id = str(rag_file.id)
logger.info("插入文档到知识图谱: %s, doc_id=%s, 文档数=%d", str(rag_file.file_name), doc_id, len(documents))
await rag_instance.ainsert(input=all_content, file_paths=[file_path], ids=doc_id)

await self._mark_success(db, file_repo, str(rag_file.id), len(documents)) # type: ignore
logger.info("文件 %s 知识图谱处理完成", str(rag_file.file_name))
doc_status_data = await rag_instance.doc_status.get_by_id(doc_id)
chunk_count = len(doc_status_data.get("chunks_list", [])) if doc_status_data else 0
await self._mark_success(db, file_repo, str(rag_file.id), chunk_count)
logger.info("文件 %s 知识图谱处理完成, 实际分块数: %d", str(rag_file.file_name), chunk_count)

except Exception as e:
logger.exception("文件 %s 知识图谱处理失败: %s", str(rag_file.file_name), e) # type: ignore
await self._mark_failed(db, file_repo, str(rag_file.id), str(e)) # type: ignore
logger.exception("文件 %s 知识图谱处理失败: %s", str(rag_file.file_name), e)
await self._mark_failed(db, file_repo, str(rag_file.id), str(e))

async def _process_single_file(
self,
Expand All @@ -182,12 +185,12 @@ async def _process_single_file(
file_repo = RagFileRepository(db)

try:
await self._update_status(db, file_repo, rag_file.id, FileStatus.PROCESSING, 5) # type: ignore
await self._update_status(db, file_repo, rag_file.id, FileStatus.PROCESSING, 5)
await db.commit()

file_path = get_file_path(rag_file)
if not file_path or not Path(file_path).exists():
await self._mark_failed(db, file_repo, rag_file.id, "文件不存在") # type: ignore
await self._mark_failed(db, file_repo, rag_file.id, "文件不存在")
return

base_metadata = MetadataBuilder.build_chunk_metadata(rag_file, knowledge_base)
Expand All @@ -201,39 +204,39 @@ async def _process_single_file(
)

if not chunks:
await self._mark_failed(db, file_repo, rag_file.id, "文档解析后未生成任何分块") # type: ignore
await self._mark_failed(db, file_repo, rag_file.id, "文档解析后未生成任何分块")
return

logger.info("文件 %s 分块完成,共 %d 个分块", rag_file.file_name, len(chunks))

valid_chunks = self._filter_and_clean_chunks(chunks, rag_file)
if not valid_chunks:
await self._mark_failed(db, file_repo, rag_file.id, "文件没有有效的分块内容") # type: ignore
await self._mark_failed(db, file_repo, rag_file.id, "文件没有有效的分块内容")
return

embedding = await self._get_embeddings(db, knowledge_base)
vectorstore = VectorStoreFactory.create(
collection_name=str(knowledge_base.name), # type: ignore
collection_name=str(knowledge_base.name),
embedding=embedding,
)

await self._update_progress(db, file_repo, rag_file.id, 60) # type: ignore
await self._update_progress(db, file_repo, rag_file.id, 60)
await db.commit()

MetadataBuilder.add_to_chunks(valid_chunks, {
"rag_file_id": str(rag_file.id), # type: ignore
"original_file_id": str(rag_file.file_id), # type: ignore
"knowledge_base_id": str(knowledge_base.id), # type: ignore
"rag_file_id": str(rag_file.id),
"original_file_id": str(rag_file.file_id),
"knowledge_base_id": str(knowledge_base.id),
})

await BatchProcessor.store_in_batches(vectorstore, valid_chunks)

await self._mark_success(db, file_repo, rag_file.id, len(valid_chunks)) # type: ignore
await self._mark_success(db, file_repo, rag_file.id, len(valid_chunks))
logger.info("文件 %s ETL 处理完成", rag_file.file_name)

except Exception as e:
logger.exception("文件 %s 处理失败: %s", rag_file.file_name, e)
await self._mark_failed(db, file_repo, rag_file.id, str(e)) # type: ignore
await self._mark_failed(db, file_repo, rag_file.id, str(e))

@staticmethod
def _filter_and_clean_chunks(chunks: list, rag_file: RagFile) -> list:
Expand All @@ -256,12 +259,12 @@ def _filter_and_clean_chunks(chunks: list, rag_file: RagFile) -> list:

@staticmethod
async def _get_embeddings(db: AsyncSession, knowledge_base: KnowledgeBase):
embedding_entity = await get_model_by_id(db, str(knowledge_base.embedding_model)) # type: ignore
embedding_entity = await get_model_by_id(db, str(knowledge_base.embedding_model))
if not embedding_entity:
raise ValueError(f"嵌入模型不存在: {knowledge_base.embedding_model}")

return EmbeddingFactory.create_embeddings(
model_name=str(embedding_entity.model_name), # type: ignore
model_name=str(embedding_entity.model_name),
base_url=getattr(embedding_entity, "base_url", None),
api_key=getattr(embedding_entity, "api_key", None),
)
Expand Down
Loading
Loading