K近傍法によるワインデータセットのクラスタリング実装

K近傍法の基礎概念

K近傍法(K-Nearest Neighbors)は、教師あり学習の一種で分類と回帰の両方に利用できるアルゴリズムです。特徴空間内で最も近いK個の学習サンプルに基づいて予測を行います。

K近傍法の基本要素

  • K値: 近傍点の数を決定します。小さい値ではノイズの影響を受けやすく、大きい値ではクラス境界が曖昧になります
  • 距離指標: サンプル間の類似度を測定します。ユークリッド距離、マンハッタン距離などが一般的です
  • 分類決定規則: 多数決または距離に基づく重み付き投票が用いられます

実験環境の設定

必要な知識

  • Pythonプログラミングの基礎
  • 機械学習の基本概念(K近傍法、教師あり学習、距離計算)

環境要件

  • MindSpore 2.0以降
  • CPU/GPU/Ascendでの実行が可能

データセットの準備と前処理

ワインデータセットの特徴

ワインデータセットは、イタリア産の3種類のワインの化学分析結果を含み、13の属性を持っています:

  1. アルコール度数
  2. リンゴ酸含有量
  3. 灰分
  4. 灰分のアルカリ度
  5. マグネシウム含有量
  6. 総フェノール含有量
  7. フラボノイド含有量
  8. 非フラボノイドフェノール含有量
  9. プロアントシアニジン含有量
  10. 色の強度
  11. 色調
  12. OD280/OD315値
  13. プロリン含有量
from download import download

# データセットのダウンロード
dataset_url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MachineLearning/wine.zip"
download_path = download(dataset_url, "./", kind="zip", replace=True)

データの読み込みと前処理

import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import nn, ops

# 実行環境の設定
ms.set_context(device_target="CPU")

# データの読み込み
with open('wine.data') as data_file:
    wine_data = list(csv.reader(data_file, delimiter=','))

特徴量とラベルの抽出

# 特徴量データの抽出
features = np.array([[float(val) for val in sample[1:]] for sample in wine_data[:178]], np.float32)
# ラベルデータの抽出
labels = np.array([sample[0] for sample in wine_data[:178]], np.int32)

データの可視化

attribute_names = ['Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium', 'Total phenols',
                  'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins', 'Color intensity', 'Hue',
                  'OD280/OD315 of diluted wines', 'Proline']

plt.figure(figsize=(12, 10))
for plot_num in range(4):
    plt.subplot(2, 2, plot_num + 1)
    attr1, attr2 = 2 * plot_num, 2 * plot_num + 1
    plt.scatter(features[:59, attr1], features[:59, attr2], label='Class 1', alpha=0.7)
    plt.scatter(features[59:130, attr1], features[59:130, attr2], label='Class 2', alpha=0.7)
    plt.scatter(features[130:, attr1], features[130:, attr2], label='Class 3', alpha=0.7)
    plt.xlabel(attribute_names[attr1])
    plt.ylabel(attribute_names[attr2])
    plt.legend()
plt.tight_layout()
plt.show()

データセットの分割

# 訓練データとテストデータの分割
train_indices = np.random.choice(178, 128, replace=False)
test_indices = np.array([idx for idx in range(178) if idx not in train_indices])

train_features, train_labels = features[train_indices], labels[train_indices]
test_features, test_labels = features[test_indices], labels[test_indices]

K近傍法モデルの実装

class NearestNeighborsModel(nn.Cell):
    def __init__(self, neighbor_count):
        super(NearestNeighborsModel, self).__init__()
        self.neighbor_count = neighbor_count

    def construct(self, query_point, reference_points):
        # クエリ点を参照点の数だけ複製
        expanded_query = ops.tile(query_point, (128, 1))
        # ユークリッド距離の計算
        squared_differences = ops.square(expanded_query - reference_points)
        sum_squared = ops.sum(squared_differences, 1)
        distances = ops.sqrt(sum_squared)
        # 距離の逆数でソート(近い点ほど値が大きい)
        top_values, top_indices = ops.topk(-distances, self.neighbor_count)
        return top_indices

def predict_classification(knn_model, sample_point, ref_features, ref_labels):
    sample_tensor = ms.Tensor(sample_point)
    ref_tensor = ms.Tensor(ref_features)
    neighbor_indices = knn_model(sample_tensor, ref_tensor)
    
    class_votes = [0] * 3
    for idx in neighbor_indices.asnumpy():
        class_votes[ref_labels[idx]] += 1
    
    predicted_class = np.argmax(class_votes)
    return predicted_class

モデルの評価

correct_predictions = 0
knn_model = NearestNeighborsModel(5)

for test_sample, true_label in zip(test_features, test_labels):
    predicted_label = predict_classification(knn_model, test_sample, train_features, train_labels)
    correct_predictions += (predicted_label == true_label)
    print(f'正解ラベル: {true_label}, 予測ラベル: {predicted_label}')

accuracy = correct_predictions / len(test_labels)
print(f'検証精度: {accuracy:.3f}')

タグ: K近傍法 MindSpore ワインデータセット 機械学習 分類アルゴリズム

5月29日 02:17 投稿