websocket NLP

When I run my dialogue system, the websocket just keeps waiting on the connection and nothing happens.

import asyncio
import os
import socket
from datetime import datetime

import torch
import torch.nn.functional as F
import websockets
from transformers import BertTokenizerFast, GPT2LMHeadModel

# set_args, create_logger, top_k_top_p_filtering, filter_risk_words and
# risk_words are assumed to come from this project's own modules.


def main_websocket():

    args = set_args()
    logger = create_logger(args)
    # Use the GPU when it is available and not explicitly disabled
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    # tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    if args.save_samples_path:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
        samples_file.write("Chat log {}:\n".format(datetime.now()))
    # Chat history; every utterance is stored as a list of token ids
    history = []
    print('Waiting for a websocket connection...')

    async def server(websocket, path):
        # Called once for every new client connection
        # (note: recent versions of the websockets library call the handler
        # with only one argument, without `path`)

        try:
            async for message in websocket:
                # Listen for messages from the client in a loop
                text = message
                print("Received message:", text)

                filtered_text = filter_risk_words(text, risk_words)

                if text != filtered_text:
                    print("Risk filter triggered:", text)
                    await websocket.send("websocket service connection closed!")
                    await websocket.close()
                    break

                if args.save_samples_path:
                    samples_file.write("user:{}\n".format(text))
                text_ids = tokenizer.encode(text, add_special_tokens=False)
                history.append(text_ids)
                input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]

                for history_id, history_utr in enumerate(history[-args.max_history_len:]):
                    input_ids.extend(history_utr)
                    input_ids.append(tokenizer.sep_token_id)
                input_ids = torch.tensor(input_ids).long().to(device)
                input_ids = input_ids.unsqueeze(0)
                response = []  # the response generated from the context
                # Generate at most max_len tokens
                for _ in range(args.max_len):
                    outputs = model(input_ids=input_ids)
                    logits = outputs.logits
                    next_token_logits = logits[0, -1, :]
                    # Repetition penalty: lower the probability of every token already generated
                    for token_id in set(response):
                        next_token_logits[token_id] /= args.repetition_penalty
                    next_token_logits = next_token_logits / args.temperature
                    # Set the [UNK] logit to -inf so the model can never predict [UNK]
                    next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
                    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
                    # torch.multinomial draws num_samples elements without replacement,
                    # weighted by probability, and returns their indices
                    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                    if next_token.item() == tokenizer.sep_token_id:  # [SEP] marks the end of the response
                        break
                    response.append(next_token.item())
                    input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
                history.append(response)
                text = tokenizer.convert_ids_to_tokens(response)
                if args.save_samples_path:
                    samples_file.write("chatbot:{}\n".format("".join(text)))
                print("Reply:", "".join(text))
                await websocket.send("".join(text))
        except websockets.exceptions.ConnectionClosedError:
            print("Connection closed by client")
        except websockets.exceptions.ConnectionClosedOK:
            print("Connection closed normally")    
        finally:
            await websocket.close()

    async def start_server():
        try:
            async with websockets.serve(server, "localhost", 8765):
                print("Starting server at ws://localhost:8765")
                await asyncio.Future()  # run forever
        except OSError as e:
            print(f"Error starting server: {e}")
        except Exception as e:
            print(f"Unexpected error: {e}")

    asyncio.run(start_server())
    
# socket
def main_socket():

    args = set_args()
    logger = create_logger(args)
    # Use the GPU when it is available and not explicitly disabled
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    # tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    if args.save_samples_path:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
        samples_file.write("Chat log {}:\n".format(datetime.now()))
    # Chat history; every utterance is stored as a list of token ids
    history = []
    print('Start chatting with the chatbot (press CTRL + Z to exit)')

    HOST = 'localhost'
    PORT = 8765

    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.bind((HOST, PORT))
    server_socket.listen(1)
    
    print('Waiting for a client connection...')
    client_socket, address = server_socket.accept()


    while True:
        try:
            # Receive data sent by the client
            data = client_socket.recv(1024)
            if not data:
                break
            # Print the received data
            print(f"Received data: {data.decode('utf-8')}")
            text = data.decode('utf-8')
            if args.save_samples_path:
                samples_file.write("user:{}\n".format(text))
            text_ids = tokenizer.encode(text, add_special_tokens=False)
            history.append(text_ids)
            input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]

            for history_id, history_utr in enumerate(history[-args.max_history_len:]):
                input_ids.extend(history_utr)
                input_ids.append(tokenizer.sep_token_id)
            input_ids = torch.tensor(input_ids).long().to(device)
            input_ids = input_ids.unsqueeze(0)
            response = []  # the response generated from the context
            # Generate at most max_len tokens
            for _ in range(args.max_len):
                outputs = model(input_ids=input_ids)
                logits = outputs.logits
                next_token_logits = logits[0, -1, :]
                # Repetition penalty: lower the probability of every token already generated
                for token_id in set(response):
                    next_token_logits[token_id] /= args.repetition_penalty
                next_token_logits = next_token_logits / args.temperature
                # Set the [UNK] logit to -inf so the model can never predict [UNK]
                next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
                # torch.multinomial draws num_samples elements without replacement,
                # weighted by probability, and returns their indices
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                if next_token.item() == tokenizer.sep_token_id:  # [SEP] marks the end of the response
                    break
                response.append(next_token.item())
                input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
            history.append(response)
            text = tokenizer.convert_ids_to_tokens(response)
            if args.save_samples_path:
                samples_file.write("chatbot:{}\n".format("".join(text)))
            client_socket.send(("ChatGpt-2: " + "".join(text)).encode('utf-8'))
            # return json.dumps("chatbot:" + "".join(text), ensure_ascii=False)
        except (ConnectionResetError, BrokenPipeError):
            print("Connection closed by client")
            break

    client_socket.close()
    server_socket.close()


I waited a whole hour and nothing happened; no error was raised either.
Please help!

This program implements the backend service for a GPT-2-based chatbot, and it supports interacting with the bot over either websocket or a plain socket. What you are seeing is expected: the server is waiting for a frontend client to connect before it can do anything.
Think of it as running a consulting service: you are all set up and ready, but until a customer walks in, you just sit there quietly, prepared to receive them.

If you haven't solved this yet: the quickest way to confirm the server works is to connect a throwaway client to it yourself, as in the sketches below.
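
For the websocket mode, a minimal test client could look like the following sketch. It assumes the server from main_websocket() above is running and listening on ws://localhost:8765, as in your code:

import asyncio
import websockets

async def chat():
    # Connect to the chatbot server started by main_websocket()
    async with websockets.connect("ws://localhost:8765") as ws:
        await ws.send("你好")      # send one user utterance
        reply = await ws.recv()    # block until the generated response arrives
        print("chatbot:", reply)

asyncio.run(chat())

As soon as this client connects and sends a message, the server's wait ends and the generation loop starts running.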

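For the plain-socket mode in main_socket(), a similar throwaway client, again assuming localhost:8765:

import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(("localhost", 8765))
s.sendall("你好".encode("utf-8"))       # send one user utterance
print(s.recv(1024).decode("utf-8"))     # prints the "ChatGpt-2: ..." reply
s.close()

Note that server_socket.accept() in main_socket() also blocks until exactly this kind of client connects, which is why the program looks frozen without one.
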
If you have already solved the problem, it would be great if you could share your solution, write it up as a blog post, and leave the link in the comments to help more people ^-^