跳转至

核心代码

其实只有一个代码文件:main.py。

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import random
import json
import os
from torch.utils.data import TensorDataset, DataLoader, random_split

def read_json(file_path):
    """Read and parse a JSON file.

    The file is decoded as UTF-8 explicitly (the same encoding ``write_json``
    writes with), so the result does not depend on the platform's locale
    encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON value.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def write_json(data, path):
    """Serialize ``data`` as JSON into the file at ``path``.

    The file is written as UTF-8 and non-ASCII characters are emitted
    verbatim rather than as ``\\uXXXX`` escapes.
    """
    serialized = json.dumps(data, ensure_ascii=False)
    with open(path, mode="w", encoding="utf-8") as out_file:
        out_file.write(serialized)

def calculate_recall_at_k(result, gold, k):
    """Compute the mean Recall@k of retrieved evidence against gold evidence.

    For each query, a gold fact counts as recalled if its tokens occur as a
    contiguous run of whole tokens inside one of the first ``k`` retrieved
    facts. The per-query score is (#recalled gold facts) / (#gold facts).
    Queries without gold facts contribute 0 but still count in the average,
    which is taken over all queries in ``result``.

    Args:
        result: Path to the prediction JSON file, or the already-loaded list.
            Each entry needs an ``evidence_list`` of
            ``{'fact_input_list': [...]}`` dicts.
        gold: Path to the gold JSON file, or the already-loaded list, in the
            same format and query order as ``result``.
        k: Number of top retrieved facts considered per query.

    Returns:
        Recall@k as a float; 0.0 when ``result`` is empty.
    """
    if isinstance(result, (str, bytes, os.PathLike)):
        result = read_json(result)
    if isinstance(gold, (str, bytes, os.PathLike)):
        gold = read_json(gold)
    total_queries = len(result)
    if total_queries == 0:
        return 0.0

    def to_strings(evidence_list):
        # Drop the first and last tokens (start/end markers), then join the ids
        # with single spaces. Padding both ends with a space makes ``in`` match
        # whole tokens only, so "1 2" no longer matches inside "11 2".
        return [' ' + ' '.join(map(str, e['fact_input_list'][1:-1])) + ' '
                for e in evidence_list]

    recall = 0.0
    for r, g in zip(result, gold):
        relevant_facts = to_strings(g['evidence_list'])
        if not relevant_facts:
            continue  # no gold facts: this query contributes 0
        top_k_facts = to_strings(r['evidence_list'][:k])
        hits = sum(any(answer in fact for fact in top_k_facts)
                   for answer in relevant_facts)
        recall += hits / len(relevant_facts)
    return recall / total_queries

def calculate_mrr_at_k(result, gold, k):
    """Compute the mean reciprocal rank (MRR@k) of gold facts among retrieved facts.

    For each gold fact of a query, find the 1-based rank of the first of the
    top ``k`` retrieved facts that contains it as a run of whole tokens, and
    take 1/rank (0 if none does). The per-query score is the sum of those
    reciprocal ranks divided by the number of gold facts. Queries without
    gold facts contribute 0 but still count in the average, which is taken
    over all queries in ``result``.

    Args:
        result: Path to the prediction JSON file, or the already-loaded list.
            Each entry needs an ``evidence_list`` of
            ``{'fact_input_list': [...]}`` dicts.
        gold: Path to the gold JSON file, or the already-loaded list, in the
            same format and query order as ``result``.
        k: Number of top retrieved facts considered per query.

    Returns:
        MRR@k as a float; 0.0 when ``result`` is empty.
    """
    if isinstance(result, (str, bytes, os.PathLike)):
        result = read_json(result)
    if isinstance(gold, (str, bytes, os.PathLike)):
        gold = read_json(gold)
    total_queries = len(result)
    if total_queries == 0:
        return 0.0

    def to_strings(evidence_list):
        # Drop the first and last tokens (start/end markers), then join the ids
        # with single spaces. Padding both ends with a space makes ``in`` match
        # whole tokens only, so "1 2" no longer matches inside "11 2".
        return [' ' + ' '.join(map(str, e['fact_input_list'][1:-1])) + ' '
                for e in evidence_list]

    mrr = 0.0
    for r, g in zip(result, gold):
        relevant_facts = to_strings(g['evidence_list'])
        if not relevant_facts:
            continue  # no gold facts: this query contributes 0
        top_k_facts = to_strings(r['evidence_list'][:k])
        reciprocal_rank_sum = 0.0
        for answer in relevant_facts:
            # Rank of the first retrieved fact containing this gold fact.
            rank = next((i for i, fact in enumerate(top_k_facts, start=1)
                         if answer in fact), None)
            if rank is not None:
                reciprocal_rank_sum += 1 / rank
        mrr += reciprocal_rank_sum / len(relevant_facts)
    return mrr / total_queries

def zip_fun():
    """Pack ``output/result.json`` into ``output/prediction.zip``.

    Both paths are relative to the current working directory. The standard
    library ``zipfile`` module is used instead of shelling out to an external
    ``zip`` binary, for three reasons:
    - it works where ``zip`` is not installed;
    - failures raise instead of being silently ignored;
    - the process's working directory is never changed.

    The archive holds the file as ``result.json`` at its root and is rewritten
    on every call.

    Raises:
        FileNotFoundError: If ``output/result.json`` does not exist.
    """
    import zipfile

    output_dir = os.path.join(os.getcwd(), "output")
    source = os.path.join(output_dir, "result.json")
    # Check first so a missing input does not leave an empty archive behind.
    if not os.path.isfile(source):
        raise FileNotFoundError(f"nothing to archive: {source} does not exist")
    with zipfile.ZipFile(os.path.join(output_dir, "prediction.zip"), "w",
                         compression=zipfile.ZIP_DEFLATED) as archive:
        archive.write(source, arcname="result.json")

def preprocess_tokens(data):
    """Turn each item's ``fact_input_list`` into a space-separated token string.

    The first and last entries of every list are dropped before joining.
    """
    token_strings = []
    for entry in data:
        inner_tokens = entry['fact_input_list'][1:-1]
        token_strings.append(' '.join(str(tok) for tok in inner_tokens))
    return token_strings

def query2docBytoken(query, document):
    """Find (query index, document index) pairs linked by shared evidence tokens.

    A pair ``(i, j)`` is emitted when some evidence fact of ``query[i]``
    occurs as a contiguous run of whole tokens inside ``document[j]``'s fact.
    Facts are compared after dropping their first and last tokens.

    The pairs are written one per line as ``(i, j)`` to ``query2doc.txt`` in
    the current working directory. The main script parses that format back.

    Args:
        query: Query records, each with an ``evidence_list`` of
            ``{'fact_input_list': [...]}`` dicts.
        document: Document records, each with a ``fact_input_list``.
    """
    def padded(items):
        # Join the inner tokens with spaces and pad both ends with a space so
        # that ``in`` respects token boundaries: evidence "1 2" must not match
        # a document containing "11 2".
        return [' ' + ' '.join(map(str, it['fact_input_list'][1:-1])) + ' '
                for it in items]

    # Built once and reused for every query.
    document_strings = padded(document)

    query_doc_pairs = []
    for i, q in enumerate(query):
        evidence_strings = padded(q['evidence_list'])
        if not evidence_strings:
            continue
        for j, doc in enumerate(document_strings):
            if any(ev in doc for ev in evidence_strings):
                query_doc_pairs.append((i, j))

    with open('query2doc.txt', 'w', encoding='utf-8') as f:
        f.writelines(f'{pair}\n' for pair in query_doc_pairs)

class ModelTransformer(nn.Module):
    """Encoder-decoder ``nn.Transformer`` that maps an input sequence to one vector.

    Each input position is linearly projected to ``hidden_dim`` and passed
    through ``nn.Transformer``, with the same tensor as source and target. The
    output at the final sequence position is then projected to ``output_dim``.

    Args:
        input_dim: Feature size of each input position.
        hidden_dim: Transformer width (``d_model``); the feed-forward width is
            ``4 * hidden_dim``.
        output_dim: Size of the returned vector.
        num_layers: Number of encoder layers, and also of decoder layers.
        num_heads: Attention heads per layer.
        dropout: Dropout probability inside the transformer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, num_heads, dropout):
        super().__init__()
        # Hyper-parameters kept as plain attributes (not part of state_dict).
        self.input_dim, self.hidden_dim, self.output_dim = input_dim, hidden_dim, output_dim
        self.num_layers, self.num_heads, self.dropout = num_layers, num_heads, dropout

        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_dim) to (batch, output_dim)."""
        # The transformer is built without batch_first, so it works on
        # (seq_len, batch, hidden_dim): swap the first two axes around it.
        seq_first = self.embedding(x).transpose(0, 1)
        decoded = self.transformer(seq_first, seq_first).transpose(0, 1)
        return self.fc(decoded[:, -1, :])

class ModelMLP(nn.Module):
    """Feed-forward classifier: ReLU-activated linear layers followed by a linear head.

    Args:
        input_dim: Size of the input feature vector.
        hidden_dim: Width of every hidden layer.
        num_classes: Number of output logits.
        num_layers: Number of hidden ``nn.Linear`` layers, each followed by
            ReLU. At least one hidden layer is always created.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_layers):
        super().__init__()
        first = nn.Linear(input_dim, hidden_dim)
        rest = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 1)]
        self.layers = nn.ModuleList([first, *rest])
        self.output_layer = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return raw class logits (no softmax) for input ``x``."""
        hidden = x
        for linear in self.layers:
            hidden = torch.relu(linear(hidden))
        return self.output_layer(hidden)

    def initialize_weights(self):
        """Re-initialize the hidden layers with Kaiming-normal weights and zero biases.

        ``output_layer`` is not touched and keeps PyTorch's default initialization.
        """
        linear_layers = (m for m in self.layers if isinstance(m, nn.Linear))
        for linear in linear_layers:
            nn.init.kaiming_normal_(linear.weight.data)
            nn.init.zeros_(linear.bias.data)

def retrieve_top_k_documents(query_embedding, document_embeddings, k=3):
    """Return indices of the documents most cosine-similar to a query.

    Args:
        query_embedding: Query vector, expected to be 1-D (it is unsqueezed to
            one row and compared along the last dimension).
        document_embeddings: Tensor with one document vector per row.
        k: Number of indices to return. It is clamped to the number of
            documents, so asking for more than are available does not raise.

    Returns:
        A list of row indices into ``document_embeddings``, most similar first.
    """
    # Adapted from the baseline: the (1, D) query broadcasts against (N, D).
    similarities = torch.nn.functional.cosine_similarity(
        query_embedding.unsqueeze(0), document_embeddings, dim=-1)
    k = min(k, similarities.numel())
    _, top_document_indices = similarities.topk(k)
    return top_document_indices.tolist()

def search_relevant_doc_indices(model, query_embedding, doc_embeddings, k, device):
    """Rank documents by the relevance classifier's confidence for one query.

    Each document embedding is concatenated after a copy of the query
    embedding. All rows are scored by ``model`` in one batch, and the indices
    of the ``k`` highest-scoring documents are returned, best first.

    Args:
        model: Two-class classifier over ``[query ; document]`` rows. Column 1
            of its output is treated as the "relevant" class.
        query_embedding: 1-D query vector.
        doc_embeddings: 2-D tensor with one document vector per row.
        k: Number of indices to return; clamped to the number of documents.
        device: Device on which the scoring runs.

    Returns:
        A list of row indices into ``doc_embeddings``, most relevant first.
    """
    # Repeat the query once per document so each row is [query ; document].
    query_rows = query_embedding.unsqueeze(0).repeat(len(doc_embeddings), 1).to(device)
    doc_embeddings = doc_embeddings.to(device)
    pairs = torch.cat([query_rows, doc_embeddings], dim=1)

    model = model.to(device)
    logits = model(pairs)
    # Rank by the margin of the "relevant" logit over the "not relevant" one.
    # For two classes, softmax(logits)[:, 1] == sigmoid(l1 - l0), so this
    # ordering is exactly the ordering by P(relevant). Unlike the probability,
    # the margin does not saturate to 1.0 and create ties.
    # Ranking by l1 alone would ignore l0 entirely.
    relevance = logits[:, 1] - logits[:, 0]
    k = min(k, relevance.numel())
    _, top_k_indices = relevance.topk(k)
    return top_k_indices.tolist()

def inference(model, test_query, doc_embeddings, document, device, model_predict, k1=10, k2=3):
    """Retrieve evidence for every test query with a two-stage pipeline.

    Stage 1: the relevance classifier ``model`` selects the ``k1`` candidate
    documents via ``search_relevant_doc_indices``.
    Stage 2: ``model_predict`` maps the query embedding toward the document
    embedding space. The ``k2`` candidates most cosine-similar to that
    prediction are kept, most similar first.

    Args:
        model: Relevance classifier used for candidate selection.
        test_query: Query records with ``query_embedding`` and ``query_input_list``.
        doc_embeddings: Tensor with one document embedding per row, in the same
            order as ``document``.
        document: Document records with ``fact_input_list``.
        device: Device used for the models and tensors.
        model_predict: Query-to-document mapping model used for re-ranking.
        k1: Number of candidates kept after stage 1.
        k2: Number of evidence facts kept per query after re-ranking
            (default 3, the value previously hard-coded).

    Returns:
        One dict per query, holding ``query_input_list`` and an
        ``evidence_list`` of ``{'fact_input_list': ...}`` entries copied from
        ``document``.
    """
    # Both models are only evaluated here. Switch to eval mode once and
    # disable autograd for the whole loop; previously model_predict and the
    # cosine re-ranking ran with gradient tracking enabled.
    model.eval()
    model_predict.eval()
    results = []
    with torch.no_grad():
        for item in tqdm(test_query):
            query_embedding = torch.tensor(item['query_embedding']).to(device)
            candidate_indices = search_relevant_doc_indices(
                model, query_embedding, doc_embeddings, k1, device)

            # model_predict receives the query as a (1, 1, dim) tensor:
            # a batch of one sequence of length one.
            predicted = model_predict(query_embedding.unsqueeze(0).unsqueeze(1)).squeeze(0)
            candidate_embeddings = doc_embeddings[candidate_indices]
            reranked = retrieve_top_k_documents(predicted, candidate_embeddings, k=k2)

            evidence = [{'fact_input_list': document[candidate_indices[i]]['fact_input_list']}
                        for i in reranked]
            results.append({'query_input_list': item['query_input_list'],
                            'evidence_list': evidence})
    return results

# Main program starts here.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
# When True, saved weights (model_relevant.pth / model_predict.pth) are reused if present.
is_load = True

# Step 1: load data.
# The author notes query_trainset has 2044 entries and document has 26599.
query = read_json('input/query_trainset.json')
query_embeddings = torch.tensor([entry['query_embedding'] for entry in query], device=device)
document = read_json('input/document.json')
doc_embeddings = torch.tensor([entry['facts_embedding'] for entry in document], device=device)

# Step 2: train model_relevant, which classifies whether a concatenated
# [query embedding ; document embedding] pair is relevant.
# Step 2.1: build the positive (label 1) pairs.
# Step 2.1.1: positive source 1. Each query's embedding is concatenated with
# the fact_embedding of every entry in its evidence_list.
positive_num1 = sum(len(q['evidence_list']) for q in query)
label1 = torch.ones((positive_num1, 1), device=device)
data1 = torch.stack([
    torch.cat((
        query_embeddings[qi],
        torch.tensor(e['fact_embedding'], device=device),
    ))
    for qi, q in enumerate(query)
    for e in q['evidence_list']
])

# Step 2.1.2: positive source 2. These are (query, document) pairs found by
# matching the tokens of a query's evidence_list against the documents
# (query2docBytoken). The pairs are computed once and cached in query2doc.txt.
if not os.path.exists('query2doc.txt'):
    query2docBytoken(query, document)

with open('query2doc.txt', 'r', encoding='utf-8') as f:
    # Each non-blank line has the form "(i, j)".
    query_doc_pairs = [tuple(map(int, line.strip().strip('()').split(',')))
                       for line in f if line.strip()]

positive_num2 = len(query_doc_pairs)
label2 = torch.ones((positive_num2, 1), device=device)
if query_doc_pairs:
    data2 = torch.stack([
        torch.cat((query_embeddings[i], doc_embeddings[j]))
        for i, j in query_doc_pairs
    ])
else:
    # torch.stack([]) raises; an empty (0, width) tensor still concatenates cleanly.
    data2 = data1.new_empty((0, data1.shape[1]))

# Step 2.2: model_relevant hyper-parameters (an MLP over [query ; document]).
input_dim = 2048
num_classes = 2
hidden_dim = 512
num_layers = 3
criterion = nn.CrossEntropyLoss().to(device)
learning_rate = 5e-4
model = ModelMLP(input_dim, hidden_dim, num_classes, num_layers).to(device)
model.initialize_weights()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

if is_load and os.path.exists('model_relevant.pth'):
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load("model_relevant.pth", map_location=device))
    print('model_relevant is loaded.')
else:
    # Step 2.3: train model_relevant.
    epochs = 1000
    batch_size = 256
    positive_num = positive_num1 + positive_num2
    # Two random (negative) pairs are sampled per positive pair.
    negative_num = positive_num * 2
    train_size = int(0.9 * (negative_num + positive_num))
    val_size = (negative_num + positive_num) - train_size
    # Positives and negative labels do not change between epochs; build them once.
    positive_data = torch.cat((data1, data2), dim=0)
    positive_label = torch.cat((label1, label2), dim=0)
    label3 = torch.zeros((negative_num, 1), device=device)

    for e in range(epochs):
        # Validation below switches to eval mode; go back to training mode each epoch.
        model.train()

        # Step 2.3.1: fresh negatives every epoch. Each one pairs a uniformly
        # random query with a uniformly random document, sampled with replacement
        # over the full query and document sets.
        # NOTE(review): a random pair can occasionally be a true match.
        neg_q = torch.randint(len(query_embeddings), (negative_num,), device=device)
        neg_d = torch.randint(len(doc_embeddings), (negative_num,), device=device)
        data3 = torch.cat((query_embeddings[neg_q], doc_embeddings[neg_d]), dim=1)

        data = torch.cat((positive_data, data3), dim=0)
        label = torch.cat((positive_label, label3), dim=0)
        dataset = TensorDataset(data, label)
        # NOTE(review): the split is redrawn every epoch while the positives stay
        # fixed, so validation positives have been trained on in earlier epochs.
        # The reported validation accuracy is therefore optimistic.
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        for d, l in tqdm(train_dataloader):
            optimizer.zero_grad()
            out = model(d)
            true_label = torch.flatten(l).long()

            loss = criterion(out, true_label)
            loss.backward()
            optimizer.step()

        if e % 50 == 0:
            print("Epoch" + str(e))
            model.eval()
            val_acc = 0
            val_loss = 0
            with torch.no_grad():
                for vd, vl in tqdm(val_dataloader):
                    val_out = model(vd)
                    val_true_label = torch.flatten(vl).long()

                    val_loss += criterion(val_out, val_true_label).item()
                    val_predicted = val_out.argmax(dim=1)
                    val_acc += (val_predicted == val_true_label).sum().item()

            val_loss /= len(val_dataloader)
            val_acc /= len(val_dataset)
            print(f'Validation Loss: {val_loss:.6f}, Validation Accuracy: {val_acc:.6f}')

    # Step 2.4: save model_relevant.
    torch.save(model.state_dict(), "model_relevant.pth")

# Step 3: train model_predict for re-ranking.
# A generative model that maps from the query embedding space to the document
# embedding space; its prediction is used to re-rank candidates.
# Step 3.1: model_predict hyper-parameters.
input_dim_transformer = 1024
hidden_dim_transformer = 512
output_dim_transformer = 1024
num_layers_transformer = 3
num_heads_transformer = 8
dropout_transformer = 0.1

model_predict = ModelTransformer(input_dim_transformer, hidden_dim_transformer, output_dim_transformer, num_layers_transformer, num_heads_transformer, dropout_transformer).to(device)
learning_rate_transformer = 5e-4
criterion = nn.MSELoss()
optimizer = optim.Adam(model_predict.parameters(), lr=learning_rate_transformer)
batch_size_transformer = 256

if is_load and os.path.exists('model_predict.pth'):
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model_predict.load_state_dict(torch.load("model_predict.pth", map_location=device))
    print('model_predict is loaded.')
else:
    # Step 3.2: one (query embedding, evidence fact embedding) training pair
    # per evidence entry. Queries with an empty evidence_list contribute nothing.
    query_embeddings_transformer = [entry['query_embedding'] for entry in query for _ in entry['evidence_list']]
    document_embeddings_transformer = [evidence['fact_embedding'] for entry in query for evidence in entry['evidence_list']]

    query_embeddings_transformer = torch.tensor(query_embeddings_transformer, dtype=torch.float32, device=device)
    document_embeddings_transformer = torch.tensor(document_embeddings_transformer, dtype=torch.float32, device=device)

    dataset_transformer = TensorDataset(query_embeddings_transformer, document_embeddings_transformer)
    dataloader_transformer = DataLoader(dataset_transformer, batch_size=batch_size_transformer, shuffle=True)

    # Step 3.3: train model_predict.
    epochs_transformer = 1600
    model_predict.train()
    for e in range(epochs_transformer):
        # Accumulate on-device to avoid a host sync per batch.
        epoch_loss = torch.zeros((), device=device)
        for queries, documents in dataloader_transformer:
            # Each query is fed as a sequence of length 1: (batch, 1, dim).
            outputs = model_predict(queries.unsqueeze(1))
            loss = criterion(outputs, documents)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach()

        # Report the mean batch loss of the epoch, not just the last batch's loss.
        print(f'Epoch {e}, loss: {epoch_loss.item() / len(dataloader_transformer):.6f}')

    # Step 3.4: save model_predict.
    torch.save(model_predict.state_dict(), "model_predict.pth")

# Inference stage: model_relevant proposes k1 candidates per test query, and
# model_predict re-ranks them.
k1 = 7
test_query = read_json('input/query_testset.json')
results = inference(model, test_query, doc_embeddings, document, device, model_predict, k1)

# write_json opens the path directly and does not create directories, so
# make sure output/ exists first.
os.makedirs('output', exist_ok=True)
write_json(results, 'output/result.json')
zip_fun()