金倉データを用いたLoRAファインチューニングとSpring AIによるカスタムQwenモデル構築

LoRA技術の概要

大規模言語モデルのファインチューニングには、高い計算リソースと専門知識が必要とされます。LoRA(Low-Rank Adaptation)技術は、モデル全体を再トレーニングすることなく、特定タスクに特化した効率的な調整を可能にします。この手法により、既存の事前学習済みモデルを基盤として、企業固有のデータでカスタマイズすることが現実的になります。

環境構築

データ収集

データベース移行に関する専門知識をモデルに学習させるため、金倉コミュニティから関連文書を収集しました。HTMLとPDF形式のドキュメントを対象とし、データ品質と多様性を確保するため、複数の情報源から資料を入手しました。

データセット構築

大規模言語モデルのファインチューニングには、指示応答形式のデータセットが必要です。以下のような構造でデータを準備します:

[{
    "instruction": "KESとは何ですか?",
    "input": "",
    "output": "KESは人大金倉の略称です。"
}]

ドキュメント解析

Spring Bootプロジェクトを構築し、Spring AIを活用してドキュメント解析を行います。必要な依存関係を設定します:

<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-bom</artifactId>
            <version>1.0.0-M8</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

<dependencies>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-tika-document-reader</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-pdf-document-reader</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-jsoup-document-reader</artifactId>
    </dependency>
</dependencies>

HTMLドキュメント解析の実装例:

public List<Document> extractHtmlContent() throws IOException {
    ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
    Resource[] htmlFiles = resolver.getResources("classpath:/documents/*.html");
    
    return Arrays.stream(htmlFiles)
            .flatMap(file -> {
                JsoupDocumentReader htmlReader = new JsoupDocumentReader(file,
                        JsoupDocumentReaderConfig.builder()
                                .selector("div.content")
                                .charset("UTF-8")
                                .build());
                return htmlReader.read().stream();
            })
            .collect(Collectors.toList());
}

PDFドキュメント解析の実装例:

public List<Document> extractPdfContent() throws IOException {
    ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
    Resource[] pdfFiles = resolver.getResources("classpath:/documents/*.pdf");
    
    return Arrays.stream(pdfFiles)
            .flatMap(file -> {
                try {
                    ParagraphPdfDocumentReader pdfReader = new ParagraphPdfDocumentReader(file,
                            PdfDocumentReaderConfig.builder()
                                    .withPageTopMargin(10)
                                    .withPagesPerDocument(1)
                                    .build());
                    return pdfReader.read().stream();
                } catch (Exception e) {
                    PagePdfDocumentReader fallbackReader = new PagePdfDocumentReader(file);
                    return fallbackReader.read().stream();
                }
            })
            .collect(Collectors.toList());
}

データ変換

抽出したドキュメントをAIモデルが処理可能な形式に変換します:

public List<TrainingExample> prepareTrainingData() throws IOException {
    List<Document> htmlDocs = extractHtmlContent();
    List<Document> pdfDocs = extractPdfContent();
    
    List<Document> allDocuments = new ArrayList<>();
    allDocuments.addAll(htmlDocs);
    allDocuments.addAll(pdfDocs);
    
    TokenTextSplitter textSplitter = new TokenTextSplitter();
    List<Document> segmentedDocs = textSplitter.apply(allDocuments);
    
    ExecutorService processingPool = Executors.newFixedThreadPool(4);
    List<Future<List<TrainingExample>>> processingTasks = segmentedDocs.stream()
            .map(doc -> processingPool.submit(() -> convertToTrainingFormat(doc)))
            .collect(Collectors.toList());
    
    List<TrainingExample> trainingData = new ArrayList<>();
    for (Future<List<TrainingExample>> task : processingTasks) {
        try {
            trainingData.addAll(task.get());
        } catch (Exception e) {
            logger.error("データ変換エラー", e);
        }
    }
    
    return trainingData;
}

モデルファインチューニング

環境設定

GPUメモリ20GB以上のサーバー環境を準備します。Qwen2.5-7B-Instructモデルをベースとして、LoRAによるファインチューニングを実施します。

データセット統合

import pandas as pd
from datasets import Dataset
import os

training_files = [f for f in os.listdir('training_data') if f.endswith('.json')]
combined_data = pd.concat([pd.read_json(f) for f in training_files])
training_dataset = Dataset.from_pandas(combined_data)

ファインチューニング設定

def format_training_example(example):
    MAX_SEQUENCE_LENGTH = 512
    system_prompt = "金倉データベース専門アシスタントとして応答してください"
    
    formatted_input = tokenizer(
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{example['instruction']}<|im_end|>\n"
        f"<|im_start|>assistant\n",
        add_special_tokens=False
    )
    
    return {
        'input_ids': formatted_input['input_ids'],
        'attention_mask': formatted_input['attention_mask']
    }

モデル評価

ファインチューニング後のモデル性能を評価します:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_path = '/models/qwen2.5-7b-instruct'
lora_weights_path = './training_output/checkpoint-500'

tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

fine_tuned_model = PeftModel.from_pretrained(base_model, lora_weights_path)

test_prompt = "KDTS移行ツールが対応するデータベースを教えてください"
input_data = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": "金倉データベース専門アシスタント"},
        {"role": "user", "content": test_prompt}
    ],
    return_tensors="pt"
).to('cuda')

generated_output = fine_tuned_model.generate(
    input_data,
    max_new_tokens=256,
    temperature=0.7
)
response = tokenizer.decode(generated_output[0], skip_special_tokens=True)

Webインターフェース実装

Streamlitを使用して対話型インターフェースを構築します:

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

@st.cache_resource
def initialize_model():
    base_model_path = '/models/qwen2.5-7b-instruct'
    adapter_path = './training_output/checkpoint-500'
    
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        device_map="auto",
        torch_dtype=torch.float16
    )
    fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_path)
    return tokenizer, fine_tuned_model

tokenizer, model = initialize_model()

st.title("金倉データベースアシスタント")
user_input = st.text_input("質問を入力してください")

if user_input:
    input_sequence = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": "金倉データベース専門アシスタントです"},
            {"role": "user", "content": user_input}
        ],
        return_tensors="pt"
    )
    
    generated_response = model.generate(input_sequence, max_new_tokens=300)
    response_text = tokenizer.decode(generated_response[0], skip_special_tokens=True)
    st.write(response_text)

タグ: LoRa Spring AI Qwen 微細調整 データベース移行

5月11日 13:06 投稿