import os
import socket
import asyncio
from datetime import datetime

import torch
import torch.nn.functional as F
import websockets
from transformers import BertTokenizerFast, GPT2LMHeadModel

# set_args, create_logger, top_k_top_p_filtering, filter_risk_words and
# risk_words are helpers defined elsewhere in this project.

def main_websocket():
    args = set_args()
    logger = create_logger(args)
    # Restrict the visible GPUs before any CUDA call, otherwise the mask may be ignored
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # Use the GPU when the user asks for it and one is available
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    # tokenizer = BertTokenizer(vocab_file=args.voca_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    if args.save_samples_path:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
        samples_file.write("Chat log {}:\n".format(datetime.now()))
    # Chat history; each utterance is stored as a list of token ids
    history = []
    print('Waiting for a websocket connection...')
    async def server(websocket, path):
        # Called once for each new client connection
        try:
            async for message in websocket:
                # Handle every message the client sends
                text = message
                print("Received message:", text)
                filtered_text = filter_risk_words(text, risk_words)
                if text != filtered_text:
                    print("Risk filter blocked:", text)
                    await websocket.send("websocket service connection closed!")
                    await websocket.close()
                    break
                if args.save_samples_path:
                    samples_file.write("user:{}\n".format(text))
                text_ids = tokenizer.encode(text, add_special_tokens=False)
                history.append(text_ids)
                input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]
                for history_id, history_utr in enumerate(history[-args.max_history_len:]):
                    input_ids.extend(history_utr)
                    input_ids.append(tokenizer.sep_token_id)
                input_ids = torch.tensor(input_ids).long().to(device)
                input_ids = input_ids.unsqueeze(0)
                response = []  # the response generated from the context
                # Generate at most max_len tokens
                for _ in range(args.max_len):
                    outputs = model(input_ids=input_ids)
                    logits = outputs.logits
                    next_token_logits = logits[0, -1, :]
                    # Repetition penalty: lower the probability of every token already generated
                    for token_id in set(response):
                        next_token_logits[token_id] /= args.repetition_penalty
                    next_token_logits = next_token_logits / args.temperature
                    # Set the [UNK] logit to -inf so the model can never predict [UNK]
                    next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
                    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
                    # torch.multinomial draws num_samples elements without replacement;
                    # the higher the weight, the more likely a token is drawn; it returns indices
                    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                    if next_token.item() == tokenizer.sep_token_id:  # [SEP] marks the end of the response
                        break
                    response.append(next_token.item())
                    input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
                    # his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist())
                    # print("his_text:{}".format(his_text))
                history.append(response)
                text = tokenizer.convert_ids_to_tokens(response)
                if args.save_samples_path:
                    samples_file.write("chatbot:{}\n".format("".join(text)))
                print("Reply:", "".join(text))
                await websocket.send("".join(text))
        except websockets.exceptions.ConnectionClosedError:
            print("Connection closed by client")
        except websockets.exceptions.ConnectionClosedOK:
            print("Connection closed normally")
        finally:
            await websocket.close()

    async def start_server():
        try:
            async with websockets.serve(server, "localhost", 8765):
                print("Starting server at ws://localhost:8765")
                await asyncio.Future()  # run forever
        except OSError as e:
            print(f"Error starting server: {e}")
        except Exception as e:
            print(f"Unexpected error: {e}")

    asyncio.run(start_server())
# socket version
def main_socket():
    args = set_args()
    logger = create_logger(args)
    # Restrict the visible GPUs before any CUDA call, otherwise the mask may be ignored
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # Use the GPU when the user asks for it and one is available
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    # tokenizer = BertTokenizer(vocab_file=args.voca_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    if args.save_samples_path:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
        samples_file.write("Chat log {}:\n".format(datetime.now()))
    # Chat history; each utterance is stored as a list of token ids
    history = []
    print('Start chatting with the chatbot (press CTRL + Z to exit)')
    HOST = 'localhost'
    PORT = 8765
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.bind((HOST, PORT))
    server_socket.listen(1)
    print('Waiting for a client connection...')
    client_socket, address = server_socket.accept()
    while True:
        try:
            # Receive data sent by the client
            data = client_socket.recv(1024)
            if not data:
                break
            # Print the received data
            print(f"Received data: {data.decode('utf-8')}")
            # text = request.values.get('text')
            text = data.decode('utf-8')
            if args.save_samples_path:
                samples_file.write("user:{}\n".format(text))
            text_ids = tokenizer.encode(text, add_special_tokens=False)
            history.append(text_ids)
            input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]
            for history_id, history_utr in enumerate(history[-args.max_history_len:]):
                input_ids.extend(history_utr)
                input_ids.append(tokenizer.sep_token_id)
            input_ids = torch.tensor(input_ids).long().to(device)
            input_ids = input_ids.unsqueeze(0)
            response = []  # the response generated from the context
            # Generate at most max_len tokens
            for _ in range(args.max_len):
                outputs = model(input_ids=input_ids)
                logits = outputs.logits
                next_token_logits = logits[0, -1, :]
                # Repetition penalty: lower the probability of every token already generated
                for token_id in set(response):
                    next_token_logits[token_id] /= args.repetition_penalty
                next_token_logits = next_token_logits / args.temperature
                # Set the [UNK] logit to -inf so the model can never predict [UNK]
                next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
                # torch.multinomial draws num_samples elements without replacement;
                # the higher the weight, the more likely a token is drawn; it returns indices
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                if next_token.item() == tokenizer.sep_token_id:  # [SEP] marks the end of the response
                    break
                response.append(next_token.item())
                input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
                # his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist())
                # print("his_text:{}".format(his_text))
            history.append(response)
            text = tokenizer.convert_ids_to_tokens(response)
            if args.save_samples_path:
                samples_file.write("chatbot:{}\n".format("".join(text)))
            client_socket.send("ChatGpt-2: ".encode('utf-8') + "".join(text).encode('utf-8'))
            # return json.dumps("chatbot:" + "".join(text), ensure_ascii=False)
        except (ConnectionResetError, BrokenPipeError):
            # Stop the loop when the client drops the connection
            print("Connection closed by client")
            break
I've waited an hour and nothing happens, and no error is raised either.
Please help!
This program implements the backend service of a GPT-2 based chatbot; it supports interacting with the bot in two ways, over websocket or over a raw socket. It is simply waiting for a frontend to connect before it does any work.
Think of it as running a consulting service: everything is set up and you are waiting for customers to walk in. Until one arrives, you sit there quietly, ready to serve. You can confirm the server is alive by connecting a minimal client, as sketched below.
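As a quick sanity check (this is not part of the original post), here is a minimal client for the websocket variant. It assumes the server above is already running and listening on ws://localhost:8765, the address hard-coded in the code, and it reuses the same websockets package:

import asyncio
import websockets

async def chat():
    # Connect to the server and run a simple send/receive loop;
    # the server treats each message as one user utterance
    async with websockets.connect("ws://localhost:8765") as ws:
        while True:
            text = input("user: ")
            await ws.send(text)
            print("chatbot:", await ws.recv())

asyncio.run(chat())

Run it in a second terminal; as soon as it connects and you type a message, the server should print it and reply. For the socket variant, any plain TCP client works the same way, e.g. nc localhost 8765.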