RobertaモデルのONNX変換と文脈マスク予測
モデルのONNX形式への変換
モデルはHugging Faceから入手可能です。
def setup_nlp_components():
tokenizer = BertTokenizer.from_pretrained("../model/chinese_roberta_L-2_H-512")
model = BertForMaskedLM.from_pretrained("../model/chinese_roberta_L-2_H-512")
return model, tokenizer
def save_as_onnx(model, tokenizer, output_file):
convert('pt', model, Path(output_file), 11, tokenizer)
実行テスト
def run_inference(onnx_file, topk=3):
tokenizer = BertTokenizer.from_pretrained("../model/chinese_roberta_L-2_H-512")
input_text = '北京是中国的' + tokenizer.mask_token + '。'
tokenized = tokenizer(input_text)
mask_pos = np.where(np.array(tokenized['input_ids']) == tokenizer.mask_token_id)[0][0]
input_ids = np.array([tokenized['input_ids']], dtype=np.int64)
attention_mask = np.array([tokenized['attention_mask']], dtype=np.int64)
token_type_ids = np.array([tokenized['token_type_ids']], dtype=np.int64)
onnx_model = onnx.load(onnx_file)
inference_session = ort.InferenceSession(onnx_model.SerializeToString())
prediction = inference_session.run(
None,
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids
}
)[0]
top_indices = np.argsort(-prediction[0][mask_pos])[:topk]
for idx in top_indices:
token = tokenizer.convert_ids_to_tokens([idx])[0]
print(f'インデックス: {idx} -> トークン: {token}')
# 出力結果
# インデックス: 1266 -> トークン: 首都
# インデックス: 4242 -> トークン: 都市
# インデックス: 3307 -> トークン: 京
MacBERTモデルのONNX変換と中国語誤字訂正
モデルのONNX形式への変換
モデルはHugging Faceから入手可能です。
def initialize_correction_model():
tokenizer = BertTokenizer.from_pretrained("../model/macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("../model/macbert4csc-base-chinese")
return model, tokenizer
def export_correction_model(model, tokenizer, output_path):
convert('pt', model, Path(output_path), 11, tokenizer)
実行テスト
def correct_spelling(onnx_path):
tokenizer = BertTokenizer.from_pretrained("../model/macbert4csc-base-chinese")
input_sentence = '他经常去图书馆看书,非常开兴。'
tokenized = tokenizer(input_sentence)
input_ids = np.array([tokenized['input_ids']], dtype=np.int64)
attention_mask = np.array([tokenized['attention_mask']], dtype=np.int64)
token_type_ids = np.array([tokenized['token_type_ids']], dtype=np.int64)
onnx_model = onnx.load(onnx_path)
session = ort.InferenceSession(onnx_model.SerializeToString())
result = session.run(
None,
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids
}
)[0]
corrected_tokens = np.argmax(result[0], axis=-1)
corrected = tokenizer.decode(corrected_tokens, skip_special_tokens=True).replace(' ', '')
print(f'元文: {input_sentence} -> 修正後: {corrected[:len(input_sentence)]}')
# 出力結果
# 元文: 他经常去图书馆看书,非常开兴。 -> 修正後: 他经常去图书馆看书,非常开心。