ONNX形式のRobertaおよびMacBERTを用いた文脈マスク推論と中国語誤字訂正

RobertaモデルのONNX変換と文脈マスク予測

モデルのONNX形式への変換

モデルはHugging Faceから入手可能です。
def setup_nlp_components():
    tokenizer = BertTokenizer.from_pretrained("../model/chinese_roberta_L-2_H-512")
    model = BertForMaskedLM.from_pretrained("../model/chinese_roberta_L-2_H-512")
    return model, tokenizer

def save_as_onnx(model, tokenizer, output_file):
    convert('pt', model, Path(output_file), 11, tokenizer)

実行テスト

def run_inference(onnx_file, topk=3):
    tokenizer = BertTokenizer.from_pretrained("../model/chinese_roberta_L-2_H-512")
    input_text = '北京是中国的' + tokenizer.mask_token + '。'
    tokenized = tokenizer(input_text)
    mask_pos = np.where(np.array(tokenized['input_ids']) == tokenizer.mask_token_id)[0][0]
    input_ids = np.array([tokenized['input_ids']], dtype=np.int64)
    attention_mask = np.array([tokenized['attention_mask']], dtype=np.int64)
    token_type_ids = np.array([tokenized['token_type_ids']], dtype=np.int64)

    onnx_model = onnx.load(onnx_file)
    inference_session = ort.InferenceSession(onnx_model.SerializeToString())
    prediction = inference_session.run(
        None,
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids
        }
    )[0]
    top_indices = np.argsort(-prediction[0][mask_pos])[:topk]
    for idx in top_indices:
        token = tokenizer.convert_ids_to_tokens([idx])[0]
        print(f'インデックス: {idx} -> トークン: {token}')

# 出力結果
# インデックス: 1266 -> トークン: 首都
# インデックス: 4242 -> トークン: 都市
# インデックス: 3307 -> トークン: 京

MacBERTモデルのONNX変換と中国語誤字訂正

モデルのONNX形式への変換

モデルはHugging Faceから入手可能です。
def initialize_correction_model():
    tokenizer = BertTokenizer.from_pretrained("../model/macbert4csc-base-chinese")
    model = BertForMaskedLM.from_pretrained("../model/macbert4csc-base-chinese")
    return model, tokenizer

def export_correction_model(model, tokenizer, output_path):
    convert('pt', model, Path(output_path), 11, tokenizer)

実行テスト

def correct_spelling(onnx_path):
    tokenizer = BertTokenizer.from_pretrained("../model/macbert4csc-base-chinese")
    input_sentence = '他经常去图书馆看书,非常开兴。'
    tokenized = tokenizer(input_sentence)
    input_ids = np.array([tokenized['input_ids']], dtype=np.int64)
    attention_mask = np.array([tokenized['attention_mask']], dtype=np.int64)
    token_type_ids = np.array([tokenized['token_type_ids']], dtype=np.int64)

    onnx_model = onnx.load(onnx_path)
    session = ort.InferenceSession(onnx_model.SerializeToString())
    result = session.run(
        None,
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids
        }
    )[0]
    corrected_tokens = np.argmax(result[0], axis=-1)
    corrected = tokenizer.decode(corrected_tokens, skip_special_tokens=True).replace(' ', '')
    print(f'元文: {input_sentence} -> 修正後: {corrected[:len(input_sentence)]}')

# 出力結果
# 元文: 他经常去图书馆看书,非常开兴。 -> 修正後: 他经常去图书馆看书,非常开心。

タグ: HuggingFace ONNX Runtime BERT Chinese Spelling Correction

7月5日 17:37 投稿