跳转至

核心代码

目录结构

.
├── BaiduAPI_OCR.py
├── ChineseErrorCorrector_7B_ft.zip
├── download.py
├── evaluation.py
├── Gemini_few_shot_learning.py
├── input
├── model_local
│   ├── T5
│   └── TWNLP-7B
├── models
│   ├── T5
│   └── ...
├── output
│   ├── predict.json
│   ├── predict_phase1.json
│   └── prediction.zip
├── preprocess.py
├── T5_ft.py
├── T5_infer.py
├── twnlp_7B_ft.py
├── twnlp_7B_no_ft.py
└── utils
    ├── train_data_transfer.py
    ├── visualize_OCR.py
    └── zip_predict.py

BaiduAPI_OCR.py

import base64
import requests
import time
import json
import os
# 百度官方实现:https://ai.baidu.com/ai-doc/OCR/hk3h7y2qq
API_KEY = "***"  # 替换为你的API Key
SECRET_KEY = "***" # 替换为你的Secret Key
access_token_cache = None

def get_access_token():
    global access_token_cache
    if access_token_cache:
        return access_token_cache
    auth_url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {
        "grant_type": "client_credentials", 
        "client_id": API_KEY, 
        "client_secret": SECRET_KEY
    }
    response = requests.post(auth_url, params = params)

    if response.status_code == 200:
        access_token_cache = response.json().get("access_token")
        return access_token_cache
    else:
        raise Exception(f"Get Access Token Error: {response.text}")

def BaiduOCR(image_path):
    with open(image_path, "rb") as f:
        img = base64.b64encode(f.read()).decode("utf8")

    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json'
    }
    params = {
        "image": img,
        "recognize_granularity": "small",  # 定位单字符位置
    }
    url = f"https://aip.baidubce.com/rest/2.0/ocr/v1/handwriting?access_token={get_access_token()}"

    while True:
        try:
            response = requests.post(
                url,
                headers = headers,
                data = params,
                timeout = 30
            )

            if response.status_code == 200:
                result = response.json()
                if "error_code" in result:
                    print(f"API Error: {result['error_msg']}")
                    return None

                source_text = ""
                bounding_box_list = []

                if "words_result" in result:
                    for words in result["words_result"]:
                        if "chars" in words:
                            for c in words["chars"]:
                                location = c["location"]
                                bounding_box_list.append({
                                    "char": c["char"],
                                    "box": {
                                        'start_x': location["left"],
                                        'start_y': location["top"],
                                        'end_x': location["left"] + location["width"],
                                        'end_y': location["top"] + location["height"]
                                    }
                                })

                source_text = "".join([c["char"] for c in bounding_box_list])
                return source_text, bounding_box_list
            elif response.status_code == 429:
                print("Too many requests, rest a few seconds.")
                time.sleep(10)
                continue
            else:
                print(f"Request Failed: {response.text}")
                return None
        except Exception as e:
            print(f"Error: {e}")
            return None

if __name__ == "__main__":
    ocr_data = []
    with open('input/test_data.json', 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    for item in test_data:
        update_item = {
            "fk_homework_id": item["fk_homework_id"],
            "path": item["path"],
            "predict_text": "",
            "bounding_box_list": []
        }
        filename = item.get("path", "")
        image_path = os.path.join("input/preprocessed_test_images", filename)
        if not os.path.exists(image_path):
            print(f"image doesn't exist: {image_path}")
            continue

        source_text, bounding_box_list = BaiduOCR(image_path)
        if source_text:
            update_item["source_text"] = source_text
            update_item["char_bounding_box_list"] = bounding_box_list
        else:
            print(f"{image_path} failed.")
        ocr_data.append(update_item)

    with open('input/ocr_test_data.json', 'w', encoding='utf-8') as f:
        json.dump(ocr_data, f, ensure_ascii = False, indent = 2)

    print("OCR finished.")

ChineseErrorCorrector_7B_ft.zip

适配本实验的微调代码

ChineseErrorCorrector_7B_ft

download.py

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from huggingface_hub import snapshot_download

set_seed(42)
snapshot_download(
    repo_id = "shibing624/mengzi-t5-base-chinese-correction",
    cache_dir = None,
    local_dir = "model_local/T5",
    force_download = False,
    ignore_patterns=["*.ckpt", "*.bin"]
)

snapshot_download(
    repo_id = "twnlp/ChineseErrorCorrector2-7B",
    cache_dir=None,
    local_dir="model_local/TWNLP-7B",
    force_download = False,
    ignore_patterns = ["*.ckpt", "*.bin"]
)


evaluation.py

# 文件路径可自行修改
import json

# 计算IOU
def compute_iou(box1, box2):
    x_left = max(box1["start_x"], box2["start_x"])
    y_top = max(box1["start_y"], box2["start_y"])
    x_right = min(box1["end_x"], box2["end_x"])
    y_bottom = min(box1["end_y"], box2["end_y"])

    #如果没有重叠
    if x_right <= x_left or y_bottom <= y_top:
        return 0.0

    inter_area = (x_right - x_left) * (y_bottom - y_top)
    area1 = (box1["end_x"] - box1["start_x"]) * (box1["end_y"] - box1["start_y"])
    area2 = (box2["end_x"] - box2["start_x"]) * (box2["end_y"] - box2["start_y"])
    union_area = area1 + area2 - inter_area
    return inter_area / union_area if union_area > 0 else 0.0


# F0.5
def compute_f05_char_level(ref, pred):
    ref_chars = set(ref)
    pred_chars = set(pred)
    correct = len(ref_chars & pred_chars)
    pred_total = len(pred_chars)
    ref_total = len(ref_chars)
    if pred_total == 0 or ref_total == 0:
        return 0.0
    precision = correct / pred_total
    recall = correct / ref_total
    beta = 0.5
    return (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall) if (precision + recall) > 0 else 0.0


# 加载数据
with open('data/train_data_with_bounding_box.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

with open('data/train_predict.json', 'r', encoding='utf-8') as f:
    pred_data = json.load(f)

pred_map = {item["fk_homework_id"]: item for item in pred_data}

total_score = 0.0
total_count = len(test_data)

for gt in test_data:
    fkid = gt["fk_homework_id"]

    if fkid not in pred_map:
        f05 = 0.0
        iou_score = 0.0
    else:
        pred = pred_map[fkid]
        f05 = compute_f05_char_level(gt["target_text"], pred["predict_text"])

        gt_boxes = gt.get("bounding_box_list", [])
        pred_boxes = pred.get("bounding_box_list", [])
        iou_score = 0.0
        if pred_boxes:
            ious = []
            for pb in pred_boxes:
                #一对一计算IOU
                max_iou = max(compute_iou(pb, gb) for gb in gt_boxes)
                ious.append(max_iou)
            iou_score = sum(ious) / len(ious) if ious else 0.0
    #加权求和
    final = 0.5 * f05 + 0.5 * iou_score
    total_score += final

average = total_score / total_count if total_count > 0 else 0.0
print(f"平均得分: {average:.4f}")

Gemini_few_shot_learning.py

from google import genai
from google.genai.types import GenerateContentConfig, HttpOptions
import json

client = genai.Client(api_key="***") # 替换为你的 API key
with open("input/ocr_test_data.json", "r", encoding="utf-8") as f:
    data = json.load(f)

results = []
for num, item in enumerate(data):
    char_bounding_box_list = item["char_bounding_box_list"]
    src = item["source_text"]
    new_item = dict(item)
    corrected_text = ""

    response = client.models.generate_content(
        model="gemini-2.5-flash-preview-05-20",
        contents=f"你是一个中文文本纠错助手,请帮我纠正语法和拼写错误。\
                        现在你要进行语法纠错工作。\
                        你要去除文字前面的一些冗余信息,如学号、姓名、作文题目,只保留作文正文。\
                        有时作文题目和第一句话连在了一起,请你自行鉴别并删除这些信息。\
                        你不可以改变原始的表达,只可以改动语法错误,不可以使用更优美的表达来替换原本的简单表述,不可以有任何换行,不允许出现\\n这样的转义字符。\
                        如果原文是英文版的标点符号,比如逗号、分号、引号,不可以换成中文的。\
                        如果拼音或字母混入中文,也要进行相应的修改。\
                        除非必要情况,否则不允许拆分和合并原句,不可以调整语序。\
                        除非必要情况,不可以增加原文没有的形容词或副词。\
                        对于不确定对错的内容不要修改。\
                        你只需要给我返回正确的结果,不可以返回任何额外内容,比如解释。\
                        下面给你看一些例子:\
                        原文:\
                        12232740402-阮氏秋庄现在年轻人运动与养生在当今社会,保持健康的生活方式变得越来越重要,各种运动和养生方式应运而来。我个人特别喜欢瑜伽,它不仅是锻炼身体的方式更是提升灵平和专注力的有效途径。我是父高一的时候就开始炼瑜伽的,到现在也差不多四五年了。我觉得每次练习瑜伽后都让我感到身体的拉伸和放松,同时也能让我在繁忙的生活中找到一份宁静。练习瑜伽能有效提高我的注意力和自我意识。而且它还能增强身体的柔韧性,力量和平衡感,减轻肌肉的紧张和心理压力,帮助我更好地应对生活的挑战。总之,运动和养生方式的选择因人而异,但目标始终是保持身体健康与心理平衡。现在在我们国家年轻人更加的注重养生和运动了。他们一般会通过:“游泳、踢足球、打羽毛球等运动来锻炼身体。\
                        应该被纠错为:\
                        现在年轻人运动与养生的方式在当今社会,保持健康的生活方式变得越来越重要,各种运动和养生方式随之兴起。我个人特别喜欢瑜伽,它不仅是锻炼身体的方式,更是提升心灵平静和专注力的有效途径。我是从高一的时候就开始练习瑜伽的,到现在也差不多有四五年了。我觉得每次练习瑜伽后都让我感到身体的拉伸和放松,同时也能让我在繁忙的生活中找到一份宁静。练习瑜伽能有效提高我的注意力和自我意识。而且它还能增强身体的柔韧性、力量和平衡感,减轻肌肉的紧张和心理压力,帮助我更好地应对生活的挑战。总之,运动和养生方式的选择因人而异,但目标始终是保持身体健康与心理平衡。现在在我们国家年轻人更加注重养生和运动了。他们一般会通过游泳、踢足球、打羽毛球等运动来锻炼身体。\
                        原文:\
                        12222790915.赵吉.我与购物新方式☰二十一世纪带来了我们生活各方面巨大的变化。网络☰很多生活的活动变成互联网的活动。购物是这种变化的缩影。我从小喜欢跟母一起去市场购物。小时得买东西不是购物主要的行动。对童车的我这首先是去市场,☰叫朋友玩,去饭店。以前购物是一种很好玩的经验,购买会当过一个理由。初中和高中我很☰喜欢和同学们去市场买一些衣☰服,特别是新年时买礼品。毕业高中之前,我欣赏了☰☰好菜坞和广告产生的消费梦想。后来我到中国了。2019车中国的线上购物已经是非常发达。可那时我还是很乐意去市场。可是,疫情后状况变了:钱购物被淘汰。线上购买的两个最大的好处是便宜和方便。除的法外,现代的人没有时间和力气。线上☰我们能比线下便宜买东西,而是也不要出门。真完美乎?我则不同意。我肯定☰结上购物的好处,云不过对我购买☰这种活动失去了意义。\
                        应该被纠错为:\
                        二十一世纪给我们生活各方面带来了巨大的变化。网络使很多生活的活动变成互联网的活动。购物是这种变化的缩影。我从小喜欢跟父母一起去市场购物。小时候买东西不是购物主要的行动。对童年的我来说这首先是去市场,然后叫朋友一起玩,最后去饭店。以前购物是一种很好玩的体验,购买会成为一种理由。初中和高中我很喜欢和同学们去市场买一些衣服,特别是新年时买礼品。满足了我毕业高中之前欣赏好莱坞和广告产生的消费梦想。后来我到了中国,2019年中国的线上购物已经是非常发达。可那时我还是很乐意去市场。可是,疫情后状况变了:线下不再是购物的唯一选择。线上购买的两个最大优势是便宜和方便。除此之外,现代的人没有时间和力气。在线上我们能比线下买东西更便宜且也不用出门。真完美乎?我则不同意。我肯定线上购物的好处,但是这让我购买实体商品活动失去了意义。\
                        原文:\
                        全贤基12232720412我介绍的我喜欢☰动是足球。虽然在中国没有人踢足球,但是在韩国小学生到大学生我们都踢足球。我来中国之后我找过踢足球的阴友,但是几乎没有都打蓝球。我在韩国生活的时候也没看过有人打蓝球。可能和国家水平有关系的。韩国蓝球打得很菜,而且没有人对蓝球有兴趣。我们都看足球和棒球。但是中国足球踢得很菜,而且没有人打棒球。很多人规则都不知道。以前我以为蓝球就是个子高的人才能打的运动。但在中国不到一米七的人也爱打蓝球。中国去哪里都有打蓝球的,在公园也能打,但是足球场很少,而且有点贵。所以最近我开始打网球。这个运动以前我没打过,但是和朋友们打,慢慢学感觉也好玩。现在我的生话就是周中下课去健身,周末和朋友们打网球,以前有空就去打牌,现在努力改生活方式。最还努力早睡早起,成功的时候十一点前睡六点半起还吃早餐很自律的。但是周末如果喝酒睡到中午的话,很难回周中的生活,所以我要尽量戒语、戒麻将,坚持自律的健康的生活。我希望我能成功。\
                        应该被纠错为:\
                        我介绍的运动是足球。虽然在中国没有多少人踢足球,但是在韩国,从小学生到大学生,我们都踢足球。我来中国之后,我找过踢足球的朋友,但是几乎没有人在打篮球。我在韩国生活的时候也没看过有人打篮球。可能和国家水平有关系。韩国篮球打得很菜,而且似乎没有人对篮球有兴趣。我们都看足球和棒球。但是中国足球踢得很菜,而且没有人打棒球。很多人规则都不知道。以前我以为篮球就是个子高的人才能打的运动。但在中国不到一米七的人也爱打篮球。中国去哪里都有人打篮球,甚至在公园也能打,但是足球场很少,而且价格可能有点高。所以最近我开始打网球。这个运动以前我没打过,但是和朋友们打,慢慢学感觉也很好玩。现在我的生话就是周中下课去健身,周末和朋友们打网球,以前有空就去打牌,现在努力改变生活方式。最还努力早睡早起,尽量在十一点前睡,六点半起,并且还吃早餐,很自律。但是周末如果喝酒睡到中午的话,很难适应周中的生活,所以我要尽量戒酒、戒麻将,坚持自律的健康生活。我希望我能成功。\
                        原文:\
                        学号122121720420韩国D所以我觉得洪承铉我海用手机买东西、看电影、玩游戏等,手机在我的人生中是不可磨灭的从2018年到现在新☰型冠状病毒改变了很多日常生活。在公共场所与人们保持距离基本上成为了应☰该遵守的礼仪,尽量避免外出,呆在家里的日子越来越多。随着在家中度过的时间代替户外活动增加,家人一起吃饭的次数也增加了,正在努力寻找可以在室内进行☰新兴趣活动。另外,长时间机☰看电视或只玩智能手☰等坏可惯也发生了。我们的家人也不例外,随着逐☰渐适应新☰型冠状病毒的日常生活的同时,随着外出时间的减少,使用智能手机的时间也明显增加了。我认为新型冠状病毒路从金融业务到兴趣爱好,不仅使用智能手☰☰机,而且网上开学后为了我的学业日程,使用☰智能机器的时间比以前大幅增加☰280智能我一天中十个☰小时左右用手机我觉得手机有优点和缺点所以人们应该正确使用智能能手机。☰数字☰的现代我觉得在数学时代智能手机是最具代表性的。在是因为大部分人都有,而且接触起来最方便。\
                        应该被纠错为:\
                        我经常用手机买东西、看电影、玩游戏等,所以我觉得手机在我的人生中是不可或缺的。从2018年到现在疫情时期,我们的日常生活发生了很多改变。在公共场所与人们保持距离基本上成为了我们应该遵守的规范,应该尽量避免外出,因此我们呆在家里的日子越来越多了。随着户外活动的减少,取而代之的是在家中度过的时间逐渐增多,家人一起吃饭的次数也逐渐增多了,并且我们正在努力寻找可以在室内进行的新的娱乐活动。不过,长时间看电视还有只玩智能手机等坏可惯也产生了。我的家人们也不例外,随着逐渐适应疫情时期的同时,外出时间逐渐减少,使用智能手机的时间也明显增加了。我认为疫情对我们的影响小到个人,大到整个社会,我不仅使用智能手机,而且为了我的学业,我使用智能设备的时间比以前大大增加了。我一天中花十个小时左右的时间在手机上,我觉得手机既有优点又有缺点,所以人们应该正确使用智能手机。现在是数字现代我觉得在数字时代智能手机是最具代表性的。因为大部分人都有,而且使用起来最方便。\
                        原文:\
                        六写作(1×20)你觉得现在有压力吗?以前呢?请谈一谈你面对过什么样的压力,你又是如何解压和解决问题的。水晶qiu说明:1.题目根据上面的材料自己确定①owwexapp.E③BG0o-rumagrtransivion2.字数要求300字以上malaaygda kuursournal.我的压力经历我们现代的生活中越来越多的人亚健康因为他们不知道怎么解压。今天我要谈一谈我的以前的现在的压力,而且我如何解决了他们。”以前我在高中时压力特别大怎么考这个考试呢?去学习在哪里?哪个国家,在高三时我要考蒙古高考还有国际高考。比如IELTS,SAT,HSKk我都考了还有申请大学。时间管理特别重要那些考试不难但是☰那时候我特别疲倦。因为考国际考试不免费,比轻贵,我真怕,后来我得到了我的理想的成绩所以没有后悔,而且得镜我还考了蒙古高考,蒙方语,俄语,英语数学等等。我考了因为学校要求,但是压力很大。考了后生活就变轻松多了。现在的压力比高三少多了,但也有一些,以前跟父母又安全又舒服。现在只在手机上见面。这个然后我的现在的汉语水平需要提高。在大学生活比平时忙多了。每天说外语,有的时候我不会表达我完全意思让我很难过。但是我很高兴我有这样的经历,这样的压力,我只能☰努力就好多。现在我要建议大家怎解压。第一个就是锻炼身体。如果我很累我就去跑步就多好。第二个是乌日记我最喜欢的方式,你可以管理时间日常,你的想法。对都有好处。总的来说压力是正常的我们只需要怎解决他们就没问题了。希望没个人找到适合自己的方式。6/6\
                        应该被纠错为:\
                        我的压力经历现代人越来越多处于亚健康的生活状态,因为他们不知道怎么解压。今天我要谈一谈我以前,现在的压力,以及我如何解决了它们。”我在高中时压力特别大,怎么完成考试,去哪个国家学习?高三时,我要考蒙古高考,也要考国际高考。比如IELTS,SAT,HSKk还要申请大学。时间管理特别重要,那些考试不难,但是那时候我特别疲倦。因为考国际考试不免费,比较贵,我真害怕,后来我得到了我的理想的成绩所以不用后悔,甚至还可以得意。因为学校要求,我还考了蒙古高考,蒙方语,俄语,英语数学等等。压力很大。但是通过考试之后,生活就变轻松多了。现在的压力比高三小多了,但也有一些,以前跟父母在一起, 又安全又舒服。现在只在手机上与他们见面。而且我现在的汉语水平需要提高。在大学生活比之前忙多了。每天说外语,有的时候我不会表达我的意思,这让我很难过。但是我也很高兴能有这样的经历和这样的压力,我只能努力这样才能好一点。现在我要告诉大家怎么解压。第一,锻炼身体。如果我很累,我就去跑步。第二,写日记我最喜欢的方式,这对你管理时间和日常生活以及你的想法都有好处。总的来说,压力是正常的。我们需要的是去解决它们。希望每个人都能找到适合自己的方式。6/6\
                        请纠正以下文本中的错误:{src}\
                        ",
        config=GenerateContentConfig(
            temperature=0.5,
            response_mime_type="application/json",
            top_p=0.88,
            max_output_tokens=9216
        )
    )
    # print(response)
    corrected_text = response.text
    new_item["predict_text_phase1"] = corrected_text

    results.append(new_item)

    if (num+1) % 5 == 0:
        print(f"{num+1} finished.")

with open("output/predict_phase1.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent = 2)

print("Correct Phase 1 Finished.")

preprocess.py

# OCR之前的图像预处理
import cv2
import numpy as np
import json
import os
import matplotlib.pyplot as plt

# 增强对比度
def contrast(img):
    clahe = cv2.createCLAHE(clipLimit=4.0, tileGridSize=(12,12))
    return clahe.apply(img)

if __name__ == "__main__":
    with open('input/test_data.json', 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    os.makedirs("input/preprocessed_test_images", exist_ok=True)
    for item in test_data:
        filename = item.get("path", "")
        image_path = os.path.join("input/test_images", filename)
        if not os.path.exists(image_path):
            print(f"{image_path} No exist!")
            continue

        img = cv2.imread(image_path)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # 转灰度
        img1 = contrast(gray)
        processed_path = os.path.join("input/preprocessed_test_images", filename)
        cv2.imwrite(processed_path, img1)

    print("preprocess finished.")

T5_ft.py

import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer, AdamW, AutoTokenizer
from pycorrector.t5.t5_corrector import T5Corrector
import numpy as np
from tqdm import tqdm

class T5CorrectorDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=200):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)

        for item in json_data:
            self.data.append({
                'source_text': item['source_text'],
                'target_text': item['target_text']
            })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        source_text = item['source_text']
        target_text = item['target_text']
        # T5模型的输入格式为 "correct: [source_text]"
        source_encoding = self.tokenizer(
            f"correct: {source_text}",
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        target_encoding = self.tokenizer(
            target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        labels = target_encoding.input_ids.squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100  # 忽略pad token的损失
        return {
            'input_ids': source_encoding.input_ids.squeeze(),
            'attention_mask': source_encoding.attention_mask.squeeze(),
            'labels': labels
        }

def train_model( train_file, save_path='models/T5',
    batch_size=32,
    epochs=20,
    learning_rate=5e-5,
    max_length=200,
    device='cuda'
):
    if not torch.cuda.is_available() and device == 'cuda':
        print("CUDA is not available. Using CPU instead.")
        device = 'cpu'

    print("device:", device)
    os.makedirs(save_path, exist_ok=True)
    model = T5ForConditionalGeneration.from_pretrained(
        "model_local/T5",
        trust_remote_code=True,
        device_map=device
    )
    tokenizer = AutoTokenizer.from_pretrained("model_local/T5", trust_remote_code=True)

    train_dataset = T5CorrectorDataset(train_file, tokenizer, max_length=max_length)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            progress_bar.set_postfix({'loss': loss.item()})

    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    return model, tokenizer

if __name__ == "__main__":
    train_model(
        train_file='input/train_data.json',
        save_path='models/T5',
        batch_size=32,
        epochs=20,
        learning_rate=5e-5,
        max_length=200,
        device='cuda'
    )

T5_infer.py

import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer, AdamW, AutoTokenizer
from pycorrector.t5.t5_corrector import T5Corrector
import numpy as np
from tqdm import tqdm
from difflib import SequenceMatcher

if __name__ == "__main__":
    model = T5Corrector('models/T5')
    with open("input/ocr_test_data.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    results = []
    for i, item in enumerate(data):
        src = item["source_text"]
        char_bounding_box_list = item["char_bounding_box_list"]
        corrected_result = model.correct(src)
        corrected_text = corrected_result['target']
        new_item = dict(item)
        new_item["predict_text"] = corrected_text
        matcher = SequenceMatcher(None, src, corrected_text)
        bounding_box_list = []
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag != 'equal':
                for i in range(i1, i2):
                    if i < len(char_bounding_box_list):
                        box = char_bounding_box_list[i]["box"]
                        bounding_box_list.append(box)

        new_item["bounding_box_list"] = bounding_box_list 
        results.append(new_item)

    with open("output/predict.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print("Correct Finished.")

twnlp_7B_ft.py

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from huggingface_hub import snapshot_download
import os
import json
import numpy as np
import re
from difflib import SequenceMatcher
from tqdm import tqdm
set_seed(42)

model_path = "/root/autodl-tmp/merge2"  
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
prompt = "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子,输入句子为:"
with open("output/predict_phase1.json", "r", encoding="utf-8") as f: # input/ocr_test_data.json
    data = json.load(f) 

results = []
for num, item in enumerate(data):
    char_bounding_box_list = item["char_bounding_box_list"]
    src = item["predict_text_phase1"]
    source_text = item["source_text"]
    new_item = dict(item)

    corrected_text = ""
    messages = [
        {
            "role": "system", 
            "content": "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子。"
        },
        {
            "role": "user", 
            "content": prompt + src
        }
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=1024)
    generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    corrected_text = response.strip()

    new_item["predict_text"] = corrected_text
    matcher = SequenceMatcher(None, source_text, corrected_text)
    bounding_box_list = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != 'equal':
            for i in range(i1, i2):
                if i < len(char_bounding_box_list):
                    box = char_bounding_box_list[i]["box"]
                    bounding_box_list.append(box)

    new_item["bounding_box_list"] = bounding_box_list
    results.append(new_item)
    if (num+1) % 5 == 0:
        print(f"{num+1} finished.")

with open("output/predict.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print("Correct Phase 2 Finished.")

twnlp_7B_no_ft.py

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from huggingface_hub import snapshot_download
import os
import json
import numpy as np
import re
from difflib import SequenceMatcher
from tqdm import tqdm
set_seed(42)

model_path = "model_local/TWNLP-7B"  
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')

prompt = "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子,输入句子为:"

with open("input/ocr_test_data.json", "r", encoding="utf-8") as f:
    data = json.load(f)

results = []
for num, item in enumerate(data):
    char_bounding_box_list = item["char_bounding_box_list"]
    src = item["source_text"]
    new_item = dict(item)
    sentences = re.split(r'(?<=[。??!!])', src)
    sentences = [s for s in sentences if s]
    corrected_text = ""
    for s in sentences:
        messages = [
            {
                "role": "user", 
                "content": prompt + s
            }
        ]

        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        generated_ids = model.generate(**model_inputs, max_new_tokens=512)
        generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        corrected_text += response.strip()

    new_item["predict_text"] = corrected_text
    matcher = SequenceMatcher(None, src, corrected_text)
    bounding_box_list = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != 'equal':
            for i in range(i1, i2):
                if i < len(char_bounding_box_list):
                    box = char_bounding_box_list[i]["box"]
                    bounding_box_list.append(box)

    new_item["bounding_box_list"] = bounding_box_list
    results.append(new_item)
    if (num+1) % 5 == 0:
        print(f"{num+1} finished.")


with open("output/predict.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print("Correct Finished.")

utils/train_data_transfer.py

import json
# 为了微调twnlp-7B, 需要将数据集的格式进行一些转化
with open("input/train_data.json", "r", encoding="utf-8") as f:
    source_data = json.load(f)

converted_data = []
for item in source_data:
    converted_data.append({
        "conversations":[
            {
                "from": "human",
                "value": item["source_text"]
            },
            {
                "from": "gpt",
                "value": item["target_text"]
            }
        ]
    })


with open("input/converted_train_data_list.json", "w", encoding="utf-8") as f:
    json.dump(converted_data, f, ensure_ascii=False, indent=2)

with open("input/converted_train_data_list.json", "r", encoding="utf-8") as f:
    list_data = json.load(f)

with open("input/converted_train_data.json", "w", encoding="utf-8") as f:
    for item in list_data:
        # 将每个对象单独写入一行
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

print(f"transfer {len(source_data)} finished")

utils/visualize_OCR.py

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import json

def visualize_char_boxes(image_path, boxes, 
                         box_color='red', line_thickness=1):
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"{image_path} loaded error.")

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转换为RGB格式(用于matplotlib显示)
    img_with_boxes = img_rgb.copy()
    plt.figure(figsize=(12, 12))
    color_map = {
        'red': (0, 0, 255),
        'green': (0, 255, 0),
        'blue': (255, 0, 0),
        'yellow': (0, 255, 255),
        'purple': (255, 0, 255),
        'cyan': (255, 255, 0),
    }
    box_color_bgr = color_map.get(box_color.lower(), (0, 0, 255))  # 默认红色

    # 在图像上绘制边界框和文本
    for item in boxes:
        x1, y1, x2, y2 = item['box']['start_x'], item['box']['start_y'], item['box']['end_x'],  item['box']['end_y']
        # 绘制矩形边界框
        cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), 
                     box_color_bgr, line_thickness)

    # 使用matplotlib显示图像
    plt.imshow(img_with_boxes)
    plt.title('OCR visualization with bounding boxes')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    return img_with_boxes

image_path = "input/preprocessed_test_images/2110.jpg"
with open('input/ocr_test_data.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

result = next((item['char_bounding_box_list'] for item in test_data if item['fk_homework_id'] == 2110), None)
visualize_char_boxes(
    image_path=image_path,
    boxes = result,
    box_color='red'
)

utils/zip_predict.py

# 将预测结果压缩后再提交
import os
path=os.getcwd()
newpath=path+"/output/"
os.chdir(newpath)
os.system('zip prediction.zip predict.json')
os.chdir(path)