调用

2026年06月01日 16:39

import os
import json
import torch
from pathlib import Path
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Core workaround: make Path.read_text default to UTF-8 to prevent errors
# when reading files that contain Chinese text on Windows.
_original_read_text = Path.read_text


def _forced_utf8_read_text(self, *args, **kwargs):
    """Read the file as text, defaulting the encoding to UTF-8.

    Drop-in replacement for ``Path.read_text``. Every argument is forwarded
    to the original method unchanged; ``encoding='utf-8'`` is injected only
    when the caller supplied no encoding, either by keyword or positionally.
    """
    # ``encoding`` is the first positional parameter of Path.read_text, so any
    # positional argument means the caller already chose one. Injecting the
    # keyword as well would raise "got multiple values for argument 'encoding'".
    if not args and 'encoding' not in kwargs:
        kwargs['encoding'] = 'utf-8'
    return _original_read_text(self, *args, **kwargs)


Path.read_text = _forced_utf8_read_text

# ==================== Configuration ====================
# Path to the local base model (make sure it matches what you downloaded)
BASE_MODEL = "./models/qwen/Qwen2.5-7B-Instruct"
# Path where the LoRA weights were saved after training (change this to your actual output folder)
LORA_PATH = "./qwen2.5-7b-sft-output/checkpoint-12"
# Directory where chat logs are saved
CHAT_LOG_DIR = "./chat_logs"


# ================================================

def load_model():
    """Load the tokenizer and base model, then attach the LoRA adapter.

    Reads the module-level ``BASE_MODEL`` and ``LORA_PATH`` settings.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to
        eval mode.
    """
    print("📦 正在加载基座模型与分词器...")
    tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    # Fall back to the EOS token for padding when no pad token is defined.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Author's note: the 5070 Ti has enough VRAM, so load in bfloat16
    # instead of using 4-bit quantization.
    load_options = dict(
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_options)

    print(f"⚙️ 正在合并 LoRA 权重: {LORA_PATH}")
    # Wrap the base model with the fine-tuned LoRA adapter.
    lora_model = PeftModel.from_pretrained(base, LORA_PATH)
    lora_model.eval()

    first_param = next(lora_model.parameters())
    print(f"✅ 模型加载完成 | Device: {first_param.device}")
    return lora_model, tok


def chat(model, tokenizer, messages: list, max_new_tokens=512):
    """Generate one assistant reply for a multi-turn conversation.

    The message history is rendered to a prompt string with the tokenizer's
    ``apply_chat_template`` (generation prompt appended), tokenized, and
    passed to ``model.generate`` with sampling enabled.

    Args:
        model: Model object exposing ``device`` and ``generate``.
        tokenizer: Tokenizer used for templating, encoding and decoding.
        messages: Conversation history passed to ``apply_chat_template``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply with special tokens skipped and surrounding
        whitespace stripped.
    """
    rendered = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(rendered, return_tensors="pt").to(model.device)

    sampling = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling)

    # The output sequence starts with the prompt tokens; keep only what
    # comes after them.
    prompt_len = encoded.input_ids.shape[-1]
    new_tokens = generated[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def save_history(history):
    """Write the conversation history to a timestamped JSON file.

    Does nothing when ``history`` is empty/None or holds at most one entry.
    Otherwise the file ``chat_<YYYYmmdd_HHMMSS>.json`` is created inside
    ``CHAT_LOG_DIR`` (the directory is created if missing) and its path is
    printed.
    """
    # Guard: nothing worth persisting.
    if not history or len(history) <= 1:
        return

    os.makedirs(CHAT_LOG_DIR, exist_ok=True)
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    target = os.path.join(CHAT_LOG_DIR, "chat_" + stamp + ".json")

    with open(target, "w", encoding="utf-8") as out:
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        json.dump(history, out, ensure_ascii=False, indent=2)

    print(f"\n💾 聊天记录已成功保存至: {target}")

def load_latest_history():
    """Offer to resume from the most recent saved chat log.

    Looks in ``CHAT_LOG_DIR`` for files named ``chat_*.json``, picks the one
    whose name sorts last (the names embed a ``YYYYmmdd_HHMMSS`` timestamp,
    so that is the newest), loads it, and asks the user whether to continue
    from it.

    Returns:
        The loaded history (a list of message dicts) when the user answers
        ``y``; ``None`` when there is no log, the file cannot be read or
        parsed, its content is not a list of ``role``/``content`` dicts, or
        the user declines.
    """
    if not os.path.isdir(CHAT_LOG_DIR):
        return None

    # Only consider files following the chat_<timestamp>.json naming scheme;
    # an unrelated *.json file in the directory could otherwise sort last
    # and be mistaken for the latest log.
    files = sorted(
        (
            name
            for name in os.listdir(CHAT_LOG_DIR)
            if name.startswith("chat_") and name.endswith(".json")
        ),
        reverse=True,
    )
    if not files:
        return None

    latest_file = os.path.join(CHAT_LOG_DIR, files[0])
    try:
        with open(latest_file, "r", encoding="utf-8") as f:
            history = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"⚠️ 读取历史记录失败: {e}")
        return None

    # Reject content that is not a list of chat messages; it would only fail
    # later when the caller appends to it or renders the chat template.
    is_valid = isinstance(history, list) and all(
        isinstance(msg, dict) and "role" in msg and "content" in msg
        for msg in history
    )
    if not is_valid:
        print(f"⚠️ 读取历史记录失败: {latest_file} 格式不正确")
        return None

    print(f"📂 发现历史记录文件: {latest_file}")
    try:
        choice = input("是否加载该记录继续对话?(y/n): ").strip().lower()
    except EOFError:
        # No interactive stdin available: behave as if the user declined.
        return None
    return history if choice == 'y' else None


if __name__ == "__main__":
    # Interactive multi-turn chat loop: load the model, optionally resume a
    # saved conversation, then read commands/messages until the user exits.
    model, tokenizer = load_model()

    # Default system prompt that defines the assistant's persona.
    default_system_msg = {"role": "system",
                          "content": "你是 A Chen,由陈渝铨创造的川渝 AI。说话直爽,喜欢用老子吊毛等词汇。"}

    # Resume from a saved log if one was loaded; otherwise start a fresh
    # conversation containing only the system prompt.
    loaded_history = load_latest_history()
    if loaded_history:
        history = loaded_history
    else:
        history = [default_system_msg]

    print("\n🎉 模型已就绪!进入【多轮交互模式】")
    print("💡 指令提示:输入 'clear' 清空历史 | 'save' 手动保存 | 'exit' 退出并保存")
    print("=" * 50)

    try:
        while True:
            try:
                user_input = input("\n🧑 我: ").strip()
            except EOFError:
                # stdin was closed (Ctrl+D / Ctrl+Z, or piped input ran
                # out): treat it like 'exit' instead of dying with a
                # traceback.
                break
            if not user_input:
                continue

            command = user_input.lower()

            if command in ["exit", "quit", "0"]:
                break

            if command == "clear":
                # Reset in place and restore the default system prompt.
                history.clear()
                history.append(default_system_msg)
                print("🗑️ 历史记忆已清空!")
                continue

            if command == "save":
                save_history(history)
                continue

            # Regular turn: record the user message, generate a reply with
            # the full history as context, then record the reply.
            history.append({"role": "user", "content": user_input})
            response = chat(model, tokenizer, history)
            history.append({"role": "assistant", "content": response})
            print(f"🤖 A Chen: {response}")

    except KeyboardInterrupt:
        print("\n⚠️ 检测到强制中断 (Ctrl+C)")

    finally:
        # Save the conversation whether the loop ended normally or was
        # interrupted.
        save_history(history)
        print("👋 进程已安全退出。")