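# Third-party packages imported below (pip package names are a best-effort
# guess, adjust to your environment): tkinterdnd2, pandas, openpyxl, tabulate,
# pypdf, python-docx, python-pptx, beautifulsoup4, ebooklib, faiss-cpu, numpy,
# fastembed. openpyxl/tabulate are only needed by the .xlsx loader
# (pd.read_excel / DataFrame.to_markdown); beautifulsoup4/ebooklib only for .epub.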
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import tkinterdnd2 as tkdnd
import threading
import queue
import time
import os
import sys
import json
from pathlib import Path
import traceback
import pandas as pd
from pypdf import PdfReader
from docx import Document
from pptx import Presentation
try:
    # Optional dependencies, used only for EPUB support; the app still runs
    # without them (load_epub simply returns an empty string).
    import bs4
    import ebooklib
    from ebooklib import epub
except ImportError:
    pass
# 设置DPI感知(在高DPI屏幕上更清晰)
try:
    from ctypes import windll
    windll.shcore.SetProcessDpiAwareness(1)
except Exception:
    # Not on Windows (or the API is unavailable); skip high-DPI scaling
    pass
import faiss
import numpy as np
from fastembed import TextEmbedding
# 设置国内镜像,解决下载慢的问题
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
MODEL_NAME = "BAAI/bge-small-zh-v1.5"
BASE_DIR = "." # 扫描当前目录
INDEX_FILE = "kb.index"
CHUNKS_FILE = "kb_chunks.npy"
ID_MAP_FILE = "kb_id_map.npy" # id 储存
FILE_STATE_FILE = "kb_file_state.json" # 文件状态记录
CHUNK_INFO_FILE = "kb_chunk_info.json" # chunk详细信息
# 也可设置 分割 size 和 重叠
def split_text(text, size=500, overlap=50):
"""分割文本为chunks"""
res = []
for i in range(0, len(text), size - overlap):
chunk = text[i:i + size].strip()
if len(chunk) > 20:
res.append(chunk)
return res
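# Quick sanity check for split_text (illustrative only, not used by the app):
# with size=500 and overlap=50 the window advances 450 characters per step,
# so consecutive chunks share 50 characters.
#
#   >>> parts = split_text("x" * 1000, size=500, overlap=50)
#   >>> [len(p) for p in parts]
#   [500, 500, 100]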
class FaissIncrementalIndex:
def __init__(self, dimension=128, use_l2=True, use_id_map=True):
"""
初始化FAISS增量索引
Args:
dimension: 向量维度
use_l2: 是否使用L2距离(True为L2,False为内积IP)
use_id_map: 是否使用ID映射以支持稳定ID
"""
self.dimension = dimension
self.use_l2 = use_l2
self.use_id_map = use_id_map
# 创建基础索引
if use_l2:
self.base_index = faiss.IndexFlatL2(dimension)
else:
self.base_index = faiss.IndexFlatIP(dimension)
# 包装为ID映射索引以支持稳定ID
if use_id_map:
self.index = faiss.IndexIDMap(self.base_index)
else:
self.index = self.base_index
# 跟踪已使用的ID(仅在use_id_map=True时有效)
self.used_ids = set()
self.next_id = 0
def _generate_ids(self, n_vectors):
"""生成唯一ID"""
ids = []
for _ in range(n_vectors):
while self.next_id in self.used_ids: #从头 取没有用的 id
self.next_id += 1
ids.append(self.next_id)
self.used_ids.add(self.next_id)
self.next_id += 1
return np.array(ids, dtype=np.int64)
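    # Example of the ID policy above (illustrative): with used_ids = {0, 1, 3}
    # and next_id = 0, requesting 3 IDs yields [2, 4, 5] and next_id ends at 6.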
def add_vectors(self, vectors, ids=None):
"""
添加向量到索引
Args:
vectors: 向量数组,形状为(n, dimension)
ids: 可选的ID数组,如果为None则自动生成
"""
vectors = np.asarray(vectors, dtype=np.float32)
if vectors.ndim != 2:
raise ValueError(f"vectors should be 2D, got {vectors.ndim}D")
if vectors.shape[1] != self.dimension:
raise ValueError(f"Vector dimension mismatch: expected {self.dimension}, got {vectors.shape[1]}")
if self.use_id_map:
if ids is None:
ids = self._generate_ids(len(vectors))
else:
ids = np.asarray(ids, dtype=np.int64)
# 检查ID是否已存在
for id_val in ids:
if id_val in self.used_ids:
raise ValueError(f"ID {id_val} already exists in index. Use update_vectors to update.")
self.used_ids.update(ids)
self.index.add_with_ids(vectors, ids)
print(f"Added {len(vectors)} vectors with IDs: {ids[:5]}{'...' if len(ids) > 5 else ''}")
else:
self.index.add(vectors)
print(f"Added {len(vectors)} vectors without ID mapping")
def update_vectors(self, vectors, ids):
"""
更新已存在的向量
Args:
vectors: 新的向量数组
ids: 要更新的ID数组
"""
if not self.use_id_map:
raise RuntimeError("Cannot update vectors without ID mapping. Initialize with use_id_map=True.")
vectors = np.asarray(vectors, dtype=np.float32)
ids = np.asarray(ids, dtype=np.int64)
if len(vectors) != len(ids):
raise ValueError(f"Number of vectors ({len(vectors)}) and IDs ({len(ids)}) must match")
# 检查ID是否存在
for id_val in ids:
if id_val not in self.used_ids:
raise ValueError(f"ID {id_val} not found in index")
# 移除旧向量,添加新向量
self.index.remove_ids(ids)
self.index.add_with_ids(vectors, ids)
print(f"Updated {len(vectors)} vectors with IDs: {ids}")
def remove_vectors(self, ids):
"""
从索引中移除向量
Args:
ids: 要移除的ID数组
"""
if not self.use_id_map:
raise RuntimeError("Cannot remove vectors without ID mapping. Initialize with use_id_map=True.")
ids = np.asarray(ids, dtype=np.int64)
# 从跟踪集合中移除ID
for id_val in ids:
self.used_ids.discard(id_val)
# 从索引中移除
removed_count = self.index.remove_ids(ids)
print(f"Removed {removed_count} vectors with IDs: {ids}")
def search(self, query_vectors, k=5):
"""
搜索最近的k个邻居
Args:
query_vectors: 查询向量
k: 返回的最近邻数量
Returns:
distances, indices
"""
query_vectors = np.asarray(query_vectors, dtype=np.float32)
if query_vectors.ndim == 1:
query_vectors = query_vectors.reshape(1, -1)
distances, indices = self.index.search(query_vectors, k)
return distances, indices
def save_index(self, filepath):
"""保存索引到文件"""
faiss.write_index(self.index, filepath)
print(f"Index saved to {filepath}")
    def load_index(self, filepath):
        """从文件加载索引"""
        self.index = faiss.read_index(filepath)
        if self.use_id_map:
            # Rebuild used_ids from the IDs stored inside the IndexIDMap so the
            # duplicate checks in add/update/remove keep working after a reload.
            self.used_ids.clear()
            try:
                inner = faiss.downcast_index(self.index)
                stored = faiss.vector_to_array(inner.id_map)
                self.used_ids.update(int(i) for i in stored)
                self.next_id = max(self.used_ids) + 1 if self.used_ids else 0
            except Exception:
                # If the IDs cannot be recovered, fall back to an empty set;
                # callers that track their own IDs are unaffected.
                pass
            print(f"Loaded index from {filepath} ({len(self.used_ids)} IDs restored)")
        else:
            print(f"Loaded index from {filepath}")
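def _demo_faiss_incremental_index():
    """Minimal usage sketch for FaissIncrementalIndex (not called by the app).
    Shows add -> search -> remove with auto-generated IDs; the dimension and
    random data below are made up purely for illustration."""
    rng = np.random.default_rng(0)
    vectors = rng.random((4, 8), dtype=np.float32)
    idx = FaissIncrementalIndex(dimension=8, use_l2=True, use_id_map=True)
    idx.add_vectors(vectors)                      # auto-assigned IDs 0..3
    distances, ids = idx.search(vectors[0], k=2)  # nearest neighbour is ID 0 itself
    idx.remove_vectors(np.array([0, 1]))          # frees IDs 0 and 1 for reuse
    return distances, ids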
def read_text_file(filename):
"""读取文本文件,尝试多种编码"""
encodings = ['utf-8', 'gb2312', 'gbk', 'gb18030', 'latin1']
for encoding in encodings:
try:
with open(filename, 'r', encoding=encoding) as f:
return f.read().strip()
except UnicodeDecodeError:
continue
# 如果所有编码都失败,使用错误处理
with open(filename, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
def load_pdf(path):
    """加载PDF文件"""
    try:
        reader = PdfReader(path)
        # Extract each page once; extract_text() can return None for empty pages
        texts = (page.extract_text() or "" for page in reader.pages)
        return "\n".join(t.strip() for t in texts if t.strip())
    except Exception:
        return ""
def load_docx(path):
    """加载DOCX文件"""
    try:
        doc = Document(path)
        return "\n".join([p.text.strip() for p in doc.paragraphs if p.text.strip()])
    except Exception:
        return ""
def load_pptx(path):
    """加载PPTX文件"""
    try:
        prs = Presentation(path)
        out = []
        for slide in prs.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text_frame"):
                    for p in shape.text_frame.paragraphs:
                        if p.text.strip():
                            out.append(p.text.strip())
        return "\n".join(out)
    except Exception:
        return ""
def load_xlsx(path):
    """加载XLSX文件"""
    try:
        return pd.read_excel(path).to_markdown(index=False)
    except Exception:
        return ""
def load_epub(path):
    """加载EPUB文件"""
    try:
        book = epub.read_epub(path)
        texts = []
        for item in book.get_items_of_type(ebooklib.ITEM_DOCUMENT):
            soup = bs4.BeautifulSoup(item.get_content(), 'html.parser')
            texts.append(soup.get_text().strip())
        return "\n".join(texts)
    except Exception:
        return ""
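# The two if/elif dispatch chains further down (update_chunks_incrementally and
# EmbeddingApp.load_file_content) could also be driven by a table like the
# sketch below; it is defined only as an illustration and is not used anywhere.
_LOADERS_SKETCH = {
    '.pdf': load_pdf,
    '.docx': load_docx,
    '.pptx': load_pptx,
    '.xlsx': load_xlsx,
    '.epub': load_epub,
    '.txt': read_text_file,
}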
def save_chunks_with_ids(chunks, chunk_infos, ids=None, previous_id_slot=None):
    """保存chunks和它们的元数据"""
    # 保存chunks文本
    np.save(CHUNKS_FILE, np.array(chunks, dtype=object))
    # Append the reusable-ID slot as the last element without mutating the
    # caller's chunk_infos list in place
    infos_to_save = list(chunk_infos) + [{"left_last_time": previous_id_slot or []}]
    # 保存chunk元数据
    with open(CHUNK_INFO_FILE, 'w', encoding='utf-8') as f:
        json.dump(infos_to_save, f, ensure_ascii=False, indent=2)
    # 如果有ID,保存ID映射 #第一次建立是有的
    if ids is not None:
        np.save(ID_MAP_FILE, np.array(ids, dtype=np.int64))
    print(f"\u2705 保存了 {len(chunks)} 个chunks和元数据")
def load_chunks_with_ids():
    """加载chunks和它们的元数据"""
    chunks = []
    chunk_infos = []
    ids = None
    # Default slot so the function still returns sensible values when the
    # metadata file is missing or unreadable
    previous_id_slot = {"left_last_time": []}
    try:
        if os.path.exists(CHUNKS_FILE):
            chunks = np.load(CHUNKS_FILE, allow_pickle=True).tolist()
        if os.path.exists(CHUNK_INFO_FILE):
            with open(CHUNK_INFO_FILE, 'r', encoding='utf-8') as f:
                # The last element is the reusable-ID slot written by save_chunks_with_ids
                *chunk_infos, previous_id_slot = json.load(f)
        if os.path.exists(ID_MAP_FILE):
            ids = np.load(ID_MAP_FILE, allow_pickle=True)
    except Exception as e:
        print(f"\u26a0\ufe0f 加载chunks数据失败: {e}")
    return chunks, chunk_infos, ids, previous_id_slot.get("left_last_time", [])
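# Expected layout of CHUNK_INFO_FILE as written by save_chunks_with_ids
# (values are illustrative):
#
#   [
#     {"id": 0, "file_path": "notes/a.txt", "file_name": "a.txt", "chunk_index": 0},
#     {"id": 1, "file_path": "notes/a.txt", "file_name": "a.txt", "chunk_index": 1},
#     {"left_last_time": [7, 12]}
#   ]
#
# The last element is the pool of reusable IDs freed by earlier updates, which
# is why load_chunks_with_ids unpacks it as `*chunk_infos, previous_id_slot`.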
def update_chunks_incrementally(base_dir, selected_exts, existing_chunks, existing_chunk_infos, existing_ids,
previous_id_slot:list):
"""
增量更新chunks
返回: (新增chunks, 新增chunk_infos, 新增ids, 需要删除的ids)
"""
# 加载历史文件状态
if os.path.exists(FILE_STATE_FILE):
with open(FILE_STATE_FILE, 'r', encoding='utf-8') as f:
history_state = json.load(f)
else:
history_state = {}
# 扫描当前目录状态
current_state = {}
current_state[base_dir] = {}
for root, _, filenames in os.walk(base_dir):
for f in filenames:
ext = os.path.splitext(f)[1].lower()
if ext in selected_exts:
file_path = os.path.join(root, f)
# 使用相对路径作为key
rel_path = os.path.relpath(file_path, base_dir)
stat = os.stat(file_path)
current_state[base_dir][rel_path] = {
"mtime": stat.st_mtime,
"size": stat.st_size
}
# 检测变化
    history_files = set(history_state.get(base_dir, {}).keys())
current_files = set(current_state[base_dir].keys())
added = list(current_files - history_files)
deleted = list(history_files - current_files)
modified = []
for file in history_files & current_files:
if history_state[base_dir][file] != current_state[base_dir][file]:
modified.append(file)
print(f"\U0001f50d 检测到变化: 新增{len(added)}个, 修改{len(modified)}个, 删除{len(deleted)}个")
# 收集需要删除的chunk IDs
ids_to_remove = []
if existing_chunk_infos:
# 对于删除的文件,找到对应的所有chunk IDs
for file_path in deleted:
file_ids = [info["id"] for info in existing_chunk_infos if info.get("file_path") == file_path]
ids_to_remove.extend(file_ids)
print(f"\U0001f5d1\ufe0f 文件 {file_path} 被删除,将移除 {len(file_ids)} 个chunks")
# 对于修改的文件,先删除旧的,再添加新的
for file_path in modified:
file_ids = [info["id"] for info in existing_chunk_infos if info.get("file_path") == file_path]
ids_to_remove.extend(file_ids)
print(f"\u270f\ufe0f 文件 {file_path} 被修改,将更新 {len(file_ids)} 个chunks")
# 处理新增和修改的文件(修改的文件需要重新处理)
files_to_process = added + modified
new_chunks = []
new_chunk_infos = []
new_ids = []
    # Assign IDs for new chunks: reuse IDs freed by deleted/modified files first,
    # then continue counting up from the current maximum ID.
    if existing_ids is not None and len(existing_ids) > 0:
        next_id = int(existing_ids.max()) + 1
    else:
        next_id = 0
    # IDs freed in this run join the reusable pool; their vectors are removed
    # from the FAISS index before the new vectors are added.
    previous_id_slot.extend(list(ids_to_remove))
for file_rel_path in files_to_process:
file_path = os.path.join(base_dir, file_rel_path)
ext = os.path.splitext(file_path)[1].lower()
# 加载文件内容
content = ""
if ext == '.pdf':
content = load_pdf(file_path)
elif ext == '.docx':
content = load_docx(file_path)
elif ext == '.pptx':
content = load_pptx(file_path)
elif ext == '.xlsx':
content = load_xlsx(file_path)
elif ext == '.epub':
content = load_epub(file_path)
elif ext == '.txt':
content = read_text_file(file_path)
if content:
file_name = os.path.basename(file_path)
chunk_list = split_text(content)
for chunk_text in chunk_list:
                if previous_id_slot:
                    chunk_id = previous_id_slot.pop()
                else:
                    chunk_id = next_id
                    next_id += 1
new_ids.append(chunk_id)
new_chunks.append(f"【来源:{file_name}】\n{chunk_text}")
new_chunk_infos.append({
"id": chunk_id,
"file_path": file_rel_path,
"file_name": file_name,
"chunk_index": len(new_chunks) - 1
})
# 保存新的文件状态
with open(FILE_STATE_FILE, 'w', encoding='utf-8') as f:
json.dump(current_state, f, ensure_ascii=False, indent=2)
return new_chunks, new_chunk_infos, new_ids, ids_to_remove,previous_id_slot
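# FILE_STATE_FILE keeps one entry per scanned file, keyed first by the scan
# root and then by the path relative to it (values are illustrative):
#
#   {
#     ".": {
#       "docs/manual.pdf": {"mtime": 1710000000.0, "size": 102400},
#       "readme.txt":      {"mtime": 1710000123.0, "size": 2048}
#     }
#   }
#
# Comparing this snapshot with a fresh os.walk yields the added / modified /
# deleted sets above, and the IDs of removed chunks go back into
# previous_id_slot so they can be reused for new chunks.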
class EmbeddingApp:
def __init__(self, root):
self.root = root
self.root.title("智能知识库构建与搜索系统 (支持增量更新)")
self.root.geometry("1200x800")
# 设置图标和样式
self.setup_styles()
# 创建队列用于线程间通信
self.build_queue = queue.Queue()
self.search_queue = queue.Queue()
# 设置变量
self.base_dir = tk.StringVar(value=".") # 知识库根目录,准备扫描的
self.progress_var = tk.DoubleVar(value=0)
self.status_var = tk.StringVar(value="就绪")
self.search_mode = tk.StringVar(value="语义搜索")
# 支持的扩展名及其变量
self.extensions = {
'.pdf': tk.BooleanVar(value=True),
'.docx': tk.BooleanVar(value=True),
'.pptx': tk.BooleanVar(value=True),
'.xlsx': tk.BooleanVar(value=True),
'.epub': tk.BooleanVar(value=True),
'.txt': tk.BooleanVar(value=True)
}
# 初始化组件
self.setup_ui()
# 检查拖放支持
self.setup_drag_drop()
# 开始处理队列的循环
self.process_queue()
# 绑定窗口关闭事件
self.root.protocol("WM_DELETE_WINDOW", self.on_closing)
self.load_index_embeder() # 加载embeder和index,chunks
self.update_incremental_button_state() #检测知识库 + 切换增量更新按钮
def setup_styles(self):
"""设置应用样式"""
style = ttk.Style()
style.theme_use('clam')
# 自定义颜色
self.root.configure(bg='#f0f0f0')
# 配置样式
style.configure('Title.TLabel', font=('微软雅黑', 16, 'bold'), background="#ffffff")
style.configure('Subtitle.TLabel', font=('微软雅黑', 12), background="#ffffff")
style.configure('Status.TLabel', font=('微软雅黑', 10), background='#f0f0f0', foreground='#666666')
style.configure('Accent.TButton', font=('微软雅黑', 10), padding=10)
style.configure('Frame.TFrame', background='#ffffff', relief=tk.RAISED, borderwidth=1)
# 进度条样式
style.configure("Custom.Horizontal.TProgressbar",
troughcolor='#e0e0e0',
background='#4CAF50',
lightcolor='#4CAF50',
darkcolor='#4CAF50',
bordercolor='#e0e0e0',
borderwidth=1)
def setup_ui(self):
"""设置用户界面"""
# 创建主容器
main_container = tk.PanedWindow(self.root, orient=tk.HORIZONTAL, sashwidth=5, sashrelief=tk.RAISED)
main_container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# 左侧:构建面板
build_frame = ttk.Frame(main_container, style='Frame.TFrame')
main_container.add(build_frame, minsize=400)
# 右侧:搜索面板
search_frame = ttk.Frame(main_container, style='Frame.TFrame')
main_container.add(search_frame, minsize=400)
# 构建左侧面板
self.setup_build_panel(build_frame)
# 构建右侧面板
self.setup_search_panel(search_frame)
# 底部状态栏
self.setup_status_bar()
def setup_build_panel(self, parent):
"""构建知识库面板"""
# 标题
title_label = ttk.Label(parent, text="\U0001f527 知识库构建 (支持增量更新)", style='Title.TLabel')
title_label.pack(pady=(15, 10))
# 目录选择区域
dir_frame = ttk.Frame(parent)
dir_frame.pack(fill=tk.X, padx=20, pady=5)
ttk.Label(dir_frame, text="扫描目录:").pack(side=tk.LEFT)
dir_entry = ttk.Entry(dir_frame, textvariable=self.base_dir, width=40)
dir_entry.pack(side=tk.LEFT, padx=5, fill=tk.X, expand=True)
browse_btn = ttk.Button(dir_frame, text="浏览...", command=self.browse_directory, width=10)
browse_btn.pack(side=tk.LEFT)
# 拖放区域
drop_frame = ttk.Frame(parent, relief=tk.SUNKEN, borderwidth=2)
drop_frame.pack(fill=tk.X, padx=20, pady=10, ipady=30)
self.drop_label = ttk.Label(drop_frame, text="\U0001f4c1 拖放文件夹到此处",
font=('微软雅黑', 12), foreground='#666666')
self.drop_label.pack(expand=True)
# 文件类型选择
type_frame = ttk.LabelFrame(parent, text="支持的文件类型", padding=10)
type_frame.pack(fill=tk.X, padx=20, pady=10)
# 创建两列复选框
col1_frame = ttk.Frame(type_frame)
col1_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
col2_frame = ttk.Frame(type_frame)
col2_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
ext_items = list(self.extensions.items())
mid = len(ext_items) // 2
for i, (ext, var) in enumerate(ext_items[:mid]):
cb = ttk.Checkbutton(col1_frame, text=ext, variable=var)
cb.pack(anchor=tk.W, pady=2)
for i, (ext, var) in enumerate(ext_items[mid:]):
cb = ttk.Checkbutton(col2_frame, text=ext, variable=var)
if ext in {".epub", ".txt"}:
var.set(False)
cb.pack(anchor=tk.W, pady=2)
# 构建按钮和进度条
build_frame = ttk.Frame(parent)
build_frame.pack(fill=tk.X, padx=20, pady=10)
# 按钮容器
button_container = ttk.Frame(build_frame)
button_container.pack(fill=tk.X)
# 增量更新按钮
self.incremental_btn = ttk.Button(
button_container,
text="增量更新",
command=self.start_incremental_thread,
state=tk.DISABLED # 初始禁用
)
self.incremental_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 5))
# 构建按钮
self.build_btn = ttk.Button(
button_container,
text="开始构建知识库",
command=self.start_build_thread,
style='Accent.TButton'
)
self.build_btn.pack(side=tk.LEFT, fill=tk.X, expand=True)
# 进度条
self.progress_bar = ttk.Progressbar(
build_frame,
variable=self.progress_var,
style="Custom.Horizontal.TProgressbar"
)
self.progress_bar.pack(fill=tk.X, pady=5)
# 日志区域
log_label = ttk.Label(parent, text="构建日志:", style='Subtitle.TLabel')
log_label.pack(anchor=tk.W, padx=20, pady=(10, 5))
self.build_log = scrolledtext.ScrolledText(parent, height=12, font=('Consolas', 9))
self.build_log.pack(fill=tk.BOTH, expand=True, padx=20, pady=(0, 10))
# 清除日志按钮
clear_btn = ttk.Button(parent, text="清除日志", command=self.clear_build_log)
clear_btn.pack(anchor=tk.E, padx=20, pady=(0, 10))
def setup_search_panel(self, parent):
"""构建搜索面板"""
# 标题
title_label = ttk.Label(parent, text="\U0001f50d 知识库搜索", style='Title.TLabel')
title_label.pack(pady=(15, 10))
# 搜索模式选择
mode_frame = ttk.Frame(parent)
mode_frame.pack(fill=tk.X, padx=20, pady=5)
ttk.Label(mode_frame, text="搜索模式:").pack(side=tk.LEFT)
semantic_rb = ttk.Radiobutton(mode_frame, text="语义搜索",
variable=self.search_mode, value="语义搜索")
semantic_rb.pack(side=tk.LEFT, padx=10)
keyword_rb = ttk.Radiobutton(mode_frame, text="关键词搜索",
variable=self.search_mode, value="关键词搜索")
keyword_rb.pack(side=tk.LEFT)
# 搜索输入区域
search_frame = ttk.Frame(parent)
search_frame.pack(fill=tk.X, padx=20, pady=10)
self.search_entry = ttk.Entry(search_frame, font=('微软雅黑', 11))
self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 10))
self.search_entry.bind('<Return>', lambda e: self.start_search_thread())
self.search_btn = ttk.Button(search_frame, text="搜索",
command=self.start_search_thread, width=10)
self.search_btn.pack(side=tk.LEFT)
# 搜索结果数量
result_frame = ttk.Frame(parent)
result_frame.pack(fill=tk.X, padx=20, pady=(5, 0))
ttk.Label(result_frame, text="显示结果数量:").pack(side=tk.LEFT)
self.result_count = tk.IntVar(value=5)
result_spin = ttk.Spinbox(result_frame, from_=1, to=20, width=5,
textvariable=self.result_count)
result_spin.pack(side=tk.LEFT, padx=5)
# 搜索结果区域
result_label = ttk.Label(parent, text="搜索结果:", style='Subtitle.TLabel')
result_label.pack(anchor=tk.W, padx=20, pady=(10, 5))
# 使用PanedWindow实现自适应结果区域
result_paned = tk.PanedWindow(parent, orient=tk.VERTICAL, sashwidth=3)
result_paned.pack(fill=tk.BOTH, expand=True, padx=20, pady=(0, 10))
# 列表区域
list_frame = ttk.Frame(result_paned)
result_paned.add(list_frame, minsize=100)
# 创建Treeview显示结果
columns = ('score', 'source', 'preview')
self.result_tree = ttk.Treeview(list_frame, columns=columns, show='tree headings', height=8)
# 设置列
self.result_tree.heading('#0', text='序号', anchor=tk.W)
self.result_tree.column('#0', width=50, stretch=False)
self.result_tree.heading('score', text='相关性', anchor=tk.W)
self.result_tree.column('score', width=80, stretch=False)
self.result_tree.heading('source', text='来源文件', anchor=tk.W)
self.result_tree.column('source', width=150, stretch=False)
self.result_tree.heading('preview', text='内容预览', anchor=tk.W)
# 添加滚动条
scrollbar = ttk.Scrollbar(list_frame, orient=tk.VERTICAL, command=self.result_tree.yview)
self.result_tree.configure(yscrollcommand=scrollbar.set)
self.result_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
# 绑定选择事件
self.result_tree.bind('<<TreeviewSelect>>', self.on_result_select)
# 详情区域
detail_frame = ttk.Frame(result_paned)
result_paned.add(detail_frame, minsize=100)
detail_label = ttk.Label(detail_frame, text="详细内容:", style='Subtitle.TLabel')
detail_label.pack(anchor=tk.W, pady=(5, 5))
self.detail_text = scrolledtext.ScrolledText(detail_frame, font=('微软雅黑', 10))
self.detail_text.pack(fill=tk.BOTH, expand=True)
# 复制按钮
copy_btn = ttk.Button(detail_frame, text="复制内容", command=self.copy_detail)
copy_btn.pack(anchor=tk.E, pady=(5, 0))
def setup_status_bar(self):
"""设置状态栏"""
status_bar = ttk.Frame(self.root, relief=tk.SUNKEN, borderwidth=1)
status_bar.pack(side=tk.BOTTOM, fill=tk.X)
# 状态标签
status_label = ttk.Label(status_bar, textvariable=self.status_var,
style='Status.TLabel')
status_label.pack(side=tk.LEFT, padx=10)
# 知识库状态
self.kb_status = tk.StringVar(value="知识库: 未构建")
kb_label = ttk.Label(status_bar, textvariable=self.kb_status,
style='Status.TLabel')
kb_label.pack(side=tk.RIGHT, padx=10)
def setup_drag_drop(self):
"""设置拖放功能"""
try:
# 注册拖放
self.root.drop_target_register(tkdnd.DND_FILES)
self.root.dnd_bind('<<Drop>>', self.on_drop)
except Exception as e:
self.log_build(f"\u26a0\ufe0f 拖放功能初始化失败: {str(e)}")
def on_drop(self, event):
"""处理拖放事件"""
try:
# 解析拖放的文件列表
if hasattr(event, 'data'):
files = self.root.tk.splitlist(event.data)
else:
files = [event.data]
# 只处理第一个文件夹
if files:
path = files[0].strip('{}')
if os.path.isdir(path):
self.base_dir.set(path)
self.log_build(f"\U0001f4c2 已选择文件夹: {path}")
# 更新增量更新按钮状态
                    self.update_incremental_button_state(load_file_by_user=True)
else:
self.log_build(f"\u26a0\ufe0f 请拖放文件夹而不是文件: {path}")
except Exception as e:
self.log_build(f"\u274c 拖放处理错误: {str(e)}")
def browse_directory(self):
"""浏览选择目录"""
directory = filedialog.askdirectory(title="选择扫描目录", initialdir=self.base_dir.get())
if directory:
self.base_dir.set(directory)
self.log_build(f"\U0001f4c2 已选择文件夹: {directory}")
# 更新增量更新按钮状态
self.update_incremental_button_state(load_file_by_user=True)
def update_incremental_button_state(self,load_file_by_user:bool=False):
"""更新增量更新按钮状态"""
# 检查是否存在索引文件和ID映射文件
has_index = os.path.exists(INDEX_FILE) and os.path.exists(ID_MAP_FILE)
has_chunks = os.path.exists(CHUNKS_FILE) and os.path.exists(CHUNK_INFO_FILE)
        if has_index and has_chunks and os.path.exists(FILE_STATE_FILE):
            # 加载历史文件状态
            with open(FILE_STATE_FILE, 'r', encoding='utf-8') as f:
                history_state = json.load(f)
            if not history_state:
                # State file exists but is empty; treat it as no knowledge base
                self.incremental_btn.config(state=tk.DISABLED)
                return
            history_state_root_dir = list(history_state.keys())[0]
            if load_file_by_user:
                if self.base_dir.get() != history_state_root_dir:
                    # The chosen directory does not match the stored knowledge base
                    self.incremental_btn.config(state=tk.DISABLED)
                    return
            self.base_dir.set(history_state_root_dir)
            self.incremental_btn.config(state=tk.NORMAL)
            self.log_build("\u2705 检测到已有知识库,增量更新功能已启用")
        else:
            self.incremental_btn.config(state=tk.DISABLED)
            self.log_build("\u26a0\ufe0f 未检测到完整知识库,请先构建知识库")
def start_build_thread(self):
"""启动构建线程"""
if not self.base_dir.get():
messagebox.showwarning("警告", "请先选择扫描目录!")
return
# 检查是否选择了至少一种文件类型
selected_exts = [ext for ext, var in self.extensions.items() if var.get()]
if not selected_exts:
messagebox.showwarning("警告", "请至少选择一种文件类型!")
return
# 禁用构建按钮和增量更新按钮
self.build_btn.config(state=tk.DISABLED, text="构建中...")
self.incremental_btn.config(state=tk.DISABLED)
self.status_var.set("正在构建知识库...")
# 启动构建线程
thread = threading.Thread(target=self.build_knowledge_base,
args=(selected_exts,), daemon=True)
thread.start()
def start_incremental_thread(self):
"""启动增量更新线程"""
if not self.base_dir.get():
messagebox.showwarning("警告", "请先选择扫描目录!")
return
# 检查是否选择了至少一种文件类型
selected_exts = [ext for ext, var in self.extensions.items() if var.get()]
if not selected_exts:
messagebox.showwarning("警告", "请至少选择一种文件类型!")
return
# 检查是否存在知识库
if not (os.path.exists(INDEX_FILE) and os.path.exists(CHUNKS_FILE)):
messagebox.showwarning("警告", "请先构建知识库!")
return
# 禁用构建按钮和增量更新按钮
self.build_btn.config(state=tk.DISABLED)
self.incremental_btn.config(state=tk.DISABLED, text="增量更新中...")
self.status_var.set("正在增量更新知识库...")
# 启动增量更新线程
thread = threading.Thread(target=self.incremental_update_knowledge_base,
args=(selected_exts,), daemon=True)
thread.start()
def build_knowledge_base(self, selected_exts):
"""构建知识库(全量构建)"""
try:
# 保存当前支持的扩展
BASE_DIR = self.base_dir.get()
self.build_queue.put(('progress', 10))
self.build_queue.put(('log', f"\U0001f4c2 开始扫描目录: {BASE_DIR}"))
# 扫描文件并生成chunks
chunks = []
chunk_infos = []
chunk_id = 0
for root, _, filenames in os.walk(BASE_DIR):
for f in filenames:
ext = os.path.splitext(f)[1].lower()
if ext in selected_exts:
file_path = os.path.join(root, f)
rel_path = os.path.relpath(file_path, BASE_DIR)
# 加载文件内容
content = self.load_file_content(file_path, ext)
if content:
file_name = os.path.basename(file_path)
chunk_list = split_text(content)
for chunk_text in chunk_list:
chunks.append(f"【来源:{file_name}】\n{chunk_text}")
chunk_infos.append({
"id": chunk_id,
"file_path": rel_path,
"file_name": file_name,
"chunk_index": len(chunks) - 1
})
chunk_id += 1
if not chunks:
self.build_queue.put(('error', "没有找到有效内容!"))
return
self.build_queue.put(('progress', 30))
self.build_queue.put(('log', f"\U0001f4e6 生成 {len(chunks)} 个文本块"))
# 保存文件状态
current_state = {}
current_state[self.base_dir.get()] = {}
for root, _, filenames in os.walk(BASE_DIR):
for f in filenames:
ext = os.path.splitext(f)[1].lower()
if ext in selected_exts:
file_path = os.path.join(root, f)
rel_path = os.path.relpath(file_path, BASE_DIR)
stat = os.stat(file_path)
current_state[self.base_dir.get()][rel_path] = {
"mtime": stat.st_mtime,
"size": stat.st_size
}
with open(FILE_STATE_FILE, 'w', encoding='utf-8') as f:
json.dump(current_state, f, ensure_ascii=False, indent=2)
self.build_queue.put(('progress', 40))
self.build_queue.put(('log', "\U0001f4dd 保存文件状态记录"))
# 初始化模型
self.build_queue.put(('log', "\U0001f680 正在加载嵌入模型..."))
model_name = MODEL_NAME
try:
embedder = TextEmbedding(model_name=model_name)
except Exception as e:
self.build_queue.put(('log', f"\u26a0\ufe0f 模型加载失败: {e}"))
self.build_queue.put(('log', "尝试清理缓存并重新下载..."))
import shutil
cache_dir = os.path.join(os.getcwd(), ".fastembed")
if os.path.exists(cache_dir):
shutil.rmtree(cache_dir)
embedder = TextEmbedding(model_name=model_name)
# 向量化
self.build_queue.put(('log', "正在将文本转化为向量..."))
embeddings_generator = embedder.embed(chunks)
embeddings_np = np.array(list(embeddings_generator)).astype('float32')
# 创建增量索引
dim = embedder.embedding_size
index = FaissIncrementalIndex(dimension=dim, use_l2=False, use_id_map=True)
# 生成ID
ids = np.arange(len(chunks), dtype=np.int64)
# 添加向量到索引
index.add_vectors(embeddings_np, ids)
# 保存索引
index.save_index(INDEX_FILE)
# 保存chunks和ID映射
save_chunks_with_ids(chunks, chunk_infos, ids)
# 重置缓存
self.index = self.chunks = self.embedder = None
self.build_queue.put(('progress', 100))
self.build_queue.put(('log', "\u2705 知识库构建完成!"))
self.build_queue.put(('complete', len(chunks))) #切换 增量更新按钮
except Exception as e:
self.build_queue.put(('error', f"构建失败: {str(e)}\n{traceback.format_exc()}"))
def incremental_update_knowledge_base(self, selected_exts):
"""增量更新知识库"""
try:
BASE_DIR = self.base_dir.get()
self.build_queue.put(('progress', 10))
self.build_queue.put(('log', "\U0001f504 开始增量更新知识库..."))
# 加载现有的chunks和元数据
existing_chunks, existing_chunk_infos, existing_ids, previous_id_slot= load_chunks_with_ids()
if not existing_chunks:
self.build_queue.put(('error', "无法加载现有知识库数据!(知识库数据为空,请重新建立)"))
return
self.build_queue.put(('progress', 20))
self.build_queue.put(('log', f"\U0001f4ca 现有知识库: {len(existing_chunks)} 个chunks"))
# 检测变化并获取需要更新的数据
new_chunks, new_chunk_infos, new_ids, ids_to_remove,previous_id_slot_ = update_chunks_incrementally(
BASE_DIR, selected_exts, existing_chunks, existing_chunk_infos, existing_ids,previous_id_slot
)
self.build_queue.put(('progress', 40))
self.build_queue.put(('log', f"\U0001f4c8 检测到 {len(new_chunks)} 个新增/修改的chunks"))
self.build_queue.put(('log', f"\U0001f5d1\ufe0f 需要删除 {len(ids_to_remove)} 个chunks"))
if not new_chunks and not ids_to_remove:
self.build_queue.put(('progress', 100))
self.build_queue.put(('log', "\u2705 没有检测到变化,知识库已是最新"))
self.build_queue.put(('complete', len(existing_chunks)))
return
# 加载模型
self.build_queue.put(('log', "\U0001f680 正在加载嵌入模型..."))
embedder = TextEmbedding(model_name=MODEL_NAME)
# 加载增量索引
if os.path.exists(INDEX_FILE):
incremental_index = FaissIncrementalIndex(dimension=embedder.embedding_size, use_l2=False, use_id_map=True)
incremental_index.load_index(INDEX_FILE)
                # load_index above tries to restore used_ids from the IDs stored
                # inside the IndexIDMap, so no separate ID file is needed here
else:
self.build_queue.put(('error', "索引文件不存在!"))
return
self.build_queue.put(('progress', 60))
# 处理删除
if ids_to_remove:
self.build_queue.put(('log', f"正在删除 {len(ids_to_remove)} 个chunks..."))
incremental_index.remove_vectors(np.array(ids_to_remove, dtype=np.int64))
# 从现有数据中删除
ids_to_remove_set = set(ids_to_remove)
updated_chunks = [chunk for i, chunk in enumerate(existing_chunks)
if existing_chunk_infos[i]["id"] not in ids_to_remove_set]
updated_chunk_infos = [info for info in existing_chunk_infos
if info["id"] not in ids_to_remove_set]
updated_ids = [id_val for id_val in existing_ids if id_val not in ids_to_remove_set]
else:
updated_chunks = existing_chunks
updated_chunk_infos = existing_chunk_infos
updated_ids = existing_ids
self.build_queue.put(('progress', 70))
# 处理新增
if new_chunks:
self.build_queue.put(('log', f"正在处理 {len(new_chunks)} 个新增chunks..."))
# 向量化新chunks
embeddings_generator = embedder.embed(new_chunks)
new_embeddings = np.array(list(embeddings_generator)).astype('float32')
# 添加到索引
incremental_index.add_vectors(new_embeddings, new_ids)
# 更新数据
updated_chunks.extend(new_chunks)
updated_chunk_infos.extend(new_chunk_infos)
if updated_ids is not None:
updated_ids = np.concatenate([updated_ids, new_ids])
else:
updated_ids = new_ids
self.build_queue.put(('progress', 80))
self.build_queue.put(('log', "\U0001f4be 保存更新后的数据..."))
# 保存更新后的索引
incremental_index.save_index(INDEX_FILE)
# 保存更新后的chunks和ID映射
save_chunks_with_ids(updated_chunks, updated_chunk_infos, updated_ids,previous_id_slot_)
# 重置缓存
self.index = self.chunks = self.embedder = None
self.build_queue.put(('progress', 100))
self.build_queue.put(('log', f"\u2705 增量更新完成!知识库现有 {len(updated_chunks)} 个chunks"))
self.build_queue.put(('complete', len(updated_chunks)))
except Exception as e:
self.build_queue.put(('error', f"增量更新失败: {str(e)}\n{traceback.format_exc()}"))
def scan_files_with_filter(self, selected_exts):
"""带过滤的扫描文件函数"""
chunks = []
for root, _, filenames in os.walk(self.base_dir.get()):
for f in filenames:
ext = os.path.splitext(f)[1].lower()
if ext in selected_exts:
file_path = os.path.join(root, f)
# 加载文件内容
content = self.load_file_content(file_path, ext)
if content:
file_name = os.path.basename(file_path)
chunk_list = split_text(content)
chunks.extend([f"【来源:{file_name}】\n{c}" for c in chunk_list])
return chunks
def load_file_content(self, file_path, ext):
"""加载文件内容"""
try:
if ext == '.pdf':
return load_pdf(file_path)
elif ext == '.docx':
return load_docx(file_path)
elif ext == '.pptx':
return load_pptx(file_path)
elif ext == '.xlsx':
return load_xlsx(file_path)
elif ext == '.epub':
return load_epub(file_path)
elif ext == '.txt':
return read_text_file(file_path)
        except Exception as e:
            # Report through the queue so Tk widgets are only touched on the main thread
            self.build_queue.put(('log', f"\u274c 读取文件失败 {file_path}: {str(e)}"))
return ""
def start_search_thread(self):
"""启动搜索线程"""
query = self.search_entry.get().strip()
if not query:
messagebox.showwarning("警告", "请输入搜索内容!")
return
if not os.path.exists(INDEX_FILE):
messagebox.showwarning("警告", "请先构建知识库!")
return
# 禁用搜索按钮
self.search_btn.config(state=tk.DISABLED, text="搜索中...")
self.status_var.set("正在搜索...")
# 清空之前的搜索结果
self.result_tree.delete(*self.result_tree.get_children())
self.detail_text.delete(1.0, tk.END)
        # Reload the index/embedder if they have not been loaded yet or were
        # reset after a (re)build
        if self.index is None:
            self.load_index_embeder()
k = self.result_count.get()
# 启动搜索线程
thread = threading.Thread(target=self.search_knowledge_base,
args=(query,k), daemon=True)
thread.start()
def load_index_embeder(self):
"""加载索引、嵌入器和chunks"""
self.index = self.chunks = self.embedder = None
try:
# 加载索引
if os.path.exists(INDEX_FILE):
self.index = faiss.read_index(INDEX_FILE)
if os.path.exists(CHUNKS_FILE):
# 加载chunks和ID映射
self.chunks, self.chunk_infos, self.ids, _ = load_chunks_with_ids()
# 建立ID到chunk的映射
self.id_to_chunk_map = {}
if self.chunk_infos and len(self.chunks) == len(self.chunk_infos):
for info, chunk in zip(self.chunk_infos, self.chunks):
self.id_to_chunk_map[info["id"]] = chunk
# 加载模型
self.embedder = TextEmbedding(model_name=MODEL_NAME)
except Exception as e:
self.log_build(f"\u26a0\ufe0f 加载索引失败: {e}")
def search_knowledge_base(self, query,top_K:int):
"""搜索知识库(在后台线程中运行)"""
try:
# 生成查询向量
query_vec = list(self.embedder.embed([query]))[0]
# 搜索
D, I = self.index.search(np.array([query_vec], dtype="float32"), k=top_K)
# 准备结果
results = []
for i, (score, idx) in enumerate(zip(D[0], I[0]), 1):
if idx >= 0:
# 使用ID到chunk的映射(如果有的话)
if hasattr(self, 'id_to_chunk_map') and idx in self.id_to_chunk_map:
chunk = self.id_to_chunk_map[idx]
else:
# 如果没有映射,假设ID就是列表索引
if idx < len(self.chunks):
chunk = self.chunks[idx]
else:
# ID超出范围,跳过
continue
# 提取来源和内容
lines = chunk.split('\n', 1)
source = lines[0].replace('【来源:', '').strip('】') if len(lines) > 0 else "未知"
content = lines[1] if len(lines) > 1 else chunk
# 缩短预览
preview = content[:500]
results.append({
'index': i,
'score': f"{score:.3f}",
'source': source,
'preview': preview,
'full_content': chunk
})
self.search_queue.put(('results', results))
except Exception as e:
self.search_queue.put(('error', f"搜索失败: {str(e)}"))
def on_result_select(self, event):
"""处理结果选择事件"""
selection = self.result_tree.selection()
if selection:
item = self.result_tree.item(selection[0])
values = item['values']
# 在详情区域显示完整内容
self.detail_text.delete(1.0, tk.END)
# 查找完整内容
if values:
self.detail_text.insert(tk.END, f"来源: {values[1]}\n")
self.detail_text.insert(tk.END, f"相关性: {values[0]}\n")
self.detail_text.insert(tk.END, "\n内容:\n")
self.detail_text.insert(tk.END, values[2])
def copy_detail(self):
"""复制详情内容到剪贴板"""
content = self.detail_text.get(1.0, tk.END)
if content.strip():
self.root.clipboard_clear()
self.root.clipboard_append(content.strip())
messagebox.showinfo("成功", "内容已复制到剪贴板!")
def clear_build_log(self):
"""清除构建日志"""
self.build_log.delete(1.0, tk.END)
    def log_build(self, message):
        """Append a line to the build log (main thread only; worker threads
        should post ('log', message) to build_queue instead)."""
timestamp = time.strftime("%H:%M:%S")
self.build_log.insert(tk.END, f"[{timestamp}] {message}\n")
self.build_log.see(tk.END)
self.root.update_idletasks()
def process_queue(self):
"""处理线程队列"""
# 处理构建队列
try:
while True:
msg_type, data = self.build_queue.get_nowait()
if msg_type == 'progress':
self.progress_var.set(data)
elif msg_type == 'log':
self.log_build(data)
elif msg_type == 'complete':
self.progress_var.set(100)
self.build_btn.config(state=tk.NORMAL, text="开始构建知识库")
self.incremental_btn.config(state=tk.NORMAL, text="增量更新")
self.status_var.set("构建完成")
                    self.kb_status.set(f"知识库: {data}个文本块")
messagebox.showinfo("成功", f"知识库构建完成!共处理{data}个文本块。")
# 更新增量更新按钮状态
self.update_incremental_button_state()
elif msg_type == 'error':
self.build_btn.config(state=tk.NORMAL, text="开始构建知识库")
self.incremental_btn.config(state=tk.NORMAL, text="增量更新")
self.status_var.set("构建失败")
messagebox.showerror("错误", data)
except queue.Empty:
pass
# 处理搜索队列
try:
while True:
msg_type, data = self.search_queue.get_nowait()
if msg_type == 'results':
# 显示结果
self.result_tree.delete(*self.result_tree.get_children())
for result in data:
self.result_tree.insert('', tk.END,
text=str(result['index']),
values=(result['score'],
result['source'],
result['preview']))
self.search_btn.config(state=tk.NORMAL, text="搜索")
self.status_var.set(f"找到 {len(data)} 个结果")
elif msg_type == 'error':
self.search_btn.config(state=tk.NORMAL, text="搜索")
self.status_var.set("搜索失败")
messagebox.showerror("错误", data)
except queue.Empty:
pass
# 每100ms检查一次队列
self.root.after(100, self.process_queue)
def on_closing(self):
"""关闭应用程序"""
if messagebox.askokcancel("退出", "确定要退出应用程序吗?"):
self.root.destroy()
sys.exit()
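# A minimal, GUI-free query helper (sketch only; nothing in the app calls it).
# It assumes the knowledge-base files produced by this script already exist.
def _headless_search_sketch(query, k=5):
    index = faiss.read_index(INDEX_FILE)
    chunks, chunk_infos, _ids, _slot = load_chunks_with_ids()
    id_to_chunk = {info["id"]: chunk for info, chunk in zip(chunk_infos, chunks)}
    embedder = TextEmbedding(model_name=MODEL_NAME)
    query_vec = np.array(list(embedder.embed([query])), dtype="float32")
    distances, ids = index.search(query_vec, k)
    # Map returned IDs back to chunk texts; FAISS uses -1 to mark empty slots
    return [(float(d), id_to_chunk.get(int(i), ""))
            for d, i in zip(distances[0], ids[0]) if i >= 0]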
def main():
"""主函数"""
# 创建主窗口
root = tkdnd.TkinterDnD.Tk() # 使用TkinterDnD的Tk
# 创建应用
app = EmbeddingApp(root)
# 启动主循环
root.mainloop()
if __name__ == "__main__":
main()
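# Run with:  python <this_script>.py
# The knowledge-base files (kb.index, kb_chunks.npy, kb_id_map.npy,
# kb_chunk_info.json, kb_file_state.json) are written to the current working directory.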