環境セットアップ
前回の準備段階で必要な知識と環境構築が完了しました。本日はコード実装に焦点を当てます。Javaプログラミングの基礎知識があれば、Python構文の理解は難しくないでしょう。時間が経てば自然と習得できるはずです。それでは始めましょう!
必要なライブラリのインストール
使用前にいくつかの準備コマンドを実行する必要があります。Milvusの公式ドキュメントをよく読んでいれば、これらのステップは既に把握しているはずです。以下に実行すべきコマンド例を示します:
pip3 install langchain
pip3 install openai
pip3 install protobuf==3.20.0
pip3 install grpcio-tools
python3 -m pip install pymilvus==2.3.2
python3 -c "from pymilvus import Collection"
Milvus基本操作
まず、LangChainを統合せずに公式サンプルを使用して、挿入や検索操作を完了するために必要なコード量を確認してみましょう。公式サンプルはコメント内で全プロセスを詳細に説明しています。全体の流れは以下の通りです:
- データベースに接続
- コレクションを作成(ここではパーティションの概念もありますが、今回は深入りしません)
- ベクトルデータを挿入(公式ドキュメントでは簡単な数値を挿入しています)
- インデックスを作成(公式ドキュメントによると、一定量のデータが蓄積された場合にインデックスを作成することが一般的です)
- データを検索
- データを削除
- データベースとの接続を切断
これらのステップを通じて、MySQLデータベースに接続する操作と非常に似ていることがわかります。
# milvus_basic_demo.pyはPyMilvusの基本操作をデモンストレーションします。
# 1. Milvusに接続
# 2. コレクションを作成
# 3. データを挿入
# 4. インデックスを作成
# 5. エンティティに対する検索、クエリ、ハイブリッド検索
# 6. PKでエンティティを削除
# 7. コレクションを削除
import time
import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)
fmt = "\n=== {:30} ===\n"
search_latency_fmt = "検索レイテンシ = {:.4f}s"
num_entities, dim = 2500, 8
#################################################################################
# 1. Milvusに接続
# Milvusサーバーに新しい接続エイリアス'default'を追加します(localhost:19530)
# 実際には'default'エイリアスはPyMilvusに組み込まれています。
# Milvusのアドレスがlocalhost:19530と同じ場合は、すべてのパラメータを省略してメソッドを呼び出すことができます:`connections.connect()`。
#
# 注:以下のメソッドの'using'パラメータはデフォルトで"default"です。
print(fmt.format("Milvusへの接続を開始"))
connections.connect("default", host="localhost", port="19530")
has = utility.has_collection("basic_demo")
print(f"コレクションbasic_demoはMilvusに存在しますか: {has}")
#################################################################################
# 2. コレクションを作成
# 3つのフィールドを持つコレクションを作成します。
# +-+------------+------------+------------------+------------------------------+
# | | フィールド名 | フィールドタイプ | その他の属性 | フィールドの説明 |
# +-+------------+------------+------------------+------------------------------+
# |1| "id" | VarChar | is_primary=True | "プライマリフィールド" |
# | | | | auto_id=False | |
# +-+------------+------------+------------------+------------------------------+
# |2| "value" | Double | | "doubleフィールド" |
# +-+------------+------------+------------------+------------------------------+
# |3|"vectors"| FloatVector| dim=8 | "dim=8のfloatベクトル" |
# +-+------------+------------+------------------+------------------------------+
fields = [
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
FieldSchema(name="value", dtype=DataType.DOUBLE),
FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=dim)
]
schema = CollectionSchema(fields, "basic_demoはAPIを紹介する最もシンプルなデモ")
print(fmt.format("コレクション`basic_demo`を作成"))
basic_demo = Collection("basic_demo", schema, consistency_level="Strong")
################################################################################
# 3. データを挿入
# basic_demoに2500行のデータを挿入します
# 挿入するデータはフィールドごとに整理する必要があります。
#
# insert()メソッドは以下を返します:
# - スキーマでauto_id=Trueの場合はMilvusによって自動生成されたプライマリキー
# - スキーマでauto_id=Falseの場合はエンティティから既存のプライマリキーフィールド
print(fmt.format("エンティティの挿入を開始"))
rng = np.random.default_rng(seed=19530)
entities = [
# auto_idがFalseに設定されているためpkフィールドを提供
[str(i) for i in range(num_entities)],
rng.random(num_entities).tolist(), # valueフィールド、listのみサポート
rng.random((num_entities, dim)), # vectorsフィールド、numpy.ndarrayとlistをサポート
]
insert_result = basic_demo.insert(entities)
basic_demo.flush()
print(f"Milvus内のエンティティ数: {basic_demo.num_entities}") # エンティティ数を確認
################################################################################
# 4. インデックスを作成
# basic_demoコレクション用にIVF_FLATインデックスを作成します。
# create_index()はFloatVectorおよびBinaryVectorフィールドにのみ適用できます。
print(fmt.format("IVF_FLATインデックスの作成を開始"))
index = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {"nlist": 128},
}
basic_demo.create_index("vectors", index)
################################################################################
# 5. 検索、クエリ、ハイブリッド検索
# データがMilvusに挿入され、インデックスが作成された後、以下を実行できます:
# - ベクトル類似性に基づく検索
# - スカラーフィルタリング(ブール値、整数など)に基づくクエリ
# - ベクトル類似性とスカラーフィルタリングに基づくハイブリッド検索
# 検索またはクエリを実行する前に、basic_demoのデータをメモリにロードする必要があります。
print(fmt.format("ロードを開始"))
basic_demo.load()
# -----------------------------------------------------------------------------
# ベクトル類似性に基づく検索
print(fmt.format("ベクトル類似性に基づく検索を開始"))
vectors_to_search = entities[-1][-2:]
search_params = {
"metric_type": "L2",
"params": {"nprobe": 10},
}
start_time = time.time()
result = basic_demo.search(vectors_to_search, "vectors", search_params, limit=3, output_fields=["value"])
end_time = time.time()
for hits in result:
for hit in hits:
print(f"ヒット: {hit}, valueフィールド: {hit.entity.get('value')}")
print(search_latency_fmt.format(end_time - start_time))
# -----------------------------------------------------------------------------
# スカラーフィルタリングに基づくクエリ(ブール値、整数など)
print(fmt.format("random > 0.5でクエリを実行"))
start_time = time.time()
result = basic_demo.query(expr="value > 0.5", output_fields=["value", "vectors"])
end_time = time.time()
print(f"クエリ結果:\n-{result[0]}")
print(search_latency_fmt.format(end_time - start_time))
# -----------------------------------------------------------------------------
# ページネーション
r1 = basic_demo.query(expr="value > 0.5", limit=4, output_fields=["value"])
r2 = basic_demo.query(expr="value > 0.5", offset=1, limit=3, output_fields=["value"])
print(f"クエリページネーション(limit=4):\n\t{r1}")
print(f"クエリページネーション(offset=1, limit=3):\n\t{r2}")
# -----------------------------------------------------------------------------
# ハイブリッド検索
print(fmt.format("random > 0.5でハイブリッド検索を実行"))
start_time = time.time()
result = basic_demo.search(vectors_to_search, "vectors", search_params, limit=3, expr="value > 0.5",
output_fields=["value"])
end_time = time.time()
for hits in result:
for hit in hits:
print(f"ヒット: {hit}, valueフィールド: {hit.entity.get('value')}")
print(search_latency_fmt.format(end_time - start_time))
###############################################################################
# 6. PKでエンティティを削除
# ブール式を使用してPK値でエンティティを削除できます。
ids = insert_result.primary_keys
expr = f'id in ["{ids[0]}" , "{ids[1]}"]'
print(fmt.format(f"expr `{expr}`で削除を開始"))
result = basic_demo.query(expr=expr, output_fields=["value", "vectors"])
print(f"削除前のクエリ(expr=`{expr}`)-> 結果: \n-{result[0]}\n-{result[1]}\n")
basic_demo.delete(expr)
result = basic_demo.query(expr=expr, output_fields=["value", "vectors"])
print(f"削除後のクエリ(expr=`{expr}`)-> 結果: {result}\n")
###############################################################################
# 7. コレクションを削除
# 最後にbasic_demoコレクションを削除します
print(fmt.format("コレクション`basic_demo`を削除"))
utility.drop_collection("basic_demo")
LangChain統合版
次にLangChainバージョンのコードを見てみましょう。Milvusをラップしているため、埋め込みモデルが必要です。ここではHuggingFaceEmbeddingsのsensenova/piccolo-base-zhモデルを例として選択しましたが、もちろん他のモデルも選択可能です。LangChainで定義された関数呼び出しに変数として渡せるものであれば何でも使用できます。
以下はデータベース接続、データ挿入、クエリ、スコアリングを含む簡単な例です:
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Milvus
model_name = "sensenova/piccolo-base-zh"
embeddings = HuggingFaceEmbeddings(model_name=model_name)
print("データベースに接続")
vector_db = Milvus(
embeddings,
connection_args={"host": "localhost", "port": "19530"},
collection_name="langchain_demo",
)
print("いくつかの値を簡単に追加")
vector_db.add_texts(["テストデータ1","テストデータ2","AI開発者は常に新しい技術を学び、プロジェクトに応用します","こんにちは","おはよう"])
print("最も類似した3つの結果を検索")
docs = vector_db.similarity_search_with_score("こんにちは",3)
print("スコア状況を確認(スコアが低いほど類似度が高い)")
for text in docs:
print('テキスト:%s,スコア:%s'%(text[0].page_content,text[1]))
注意:上記のコードは単純な例であり、具体的な実装はニーズに応じて調整および最適化する必要があります。
LangChainバージョンのコードでは、Docker内のMilvusコンテナを起動するだけでなく、ネットワークプロキシも必要になる場合があります。ここでは詳しく説明しませんが、HuggingFaceコミュニティは国内にないためです。
カスタムQ&Aシステム
最後に、OpenAIモデルを呼び出して質問に回答する方法を詳しく学びましょう!
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate;
from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseOutputParser
# .env環境変数からkey値を読み込む
load_dotenv()
# 出力形式を整形
class CustomOutputParser(BaseOutputParser):
"""LLMコールの出力をカンマ区切りのリストに解析します。"""
def parse(self, text: str):
"""LLMコールの出力を解析します。"""
return text.strip().split(", ")
# まずデータベースから関連情報を検索
docs = vector_db.similarity_search("AI開発者について")
doc = docs[0].page_content
chat = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
template = "提供された情報に基づいて質問に回答してください。情報: {input_docs}"
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_template = "{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
chain = LLMChain(
llm=chat,
prompt=chat_prompt,
output_parser=CustomOutputParser()
)
chain.run(input_docs=doc, question="AI開発者とは?")
コードを実行すると、期待する回答が得られます。以下のように画面に表示されます。システムがこれらの質問の答えを知らない場合、正しい回答を提供することはできません。