Spark ML

Рассмотрим что такое Spark ML и как с ним работать.

Apache Spark

Apache Spark – это распределенный фреймворк обработки данных, ставший де-факто стандартом в обработке больших данных.
Spark состоит из нескольких компонентов, в число, которых входит и библиотеки машинного обучения.

Spark stack

Spark ML предоставляет базовый набор инструментов машинного обучения:

  • Алгоритмы, такие как классификация, регрессия, кластеризация и совместная фильтрация.
  • Методы работы с признаками
  • Конвейеры (pipelines)
  • Сохранение и загрузка моделей и конвейеров
  • Утилиты: линейная алгебра, статистика, обработка данных и т.д.

По сравнению с другими библиотеками машинного обучения, такими как scikit-learn например, набор алгоритмов в Spark ML выглядит скромнее, но он содержит все основные методы. Кроме того, Spark ML позволяет добавлять свои методы и реализовывать недостающие алгоритмы.

Spark ML состоит из двух библиотек:

  • spark.ml – это библиотека машинного обучения, основанная на DataFrame API
  • spark.mllib – на RDD API

Начиная с версии 2.0 основной библиотекой является spark.ml, но библиотека spark.mllib содержит типы данных, используемые в библиотеке spark.ml

Оба варианта Spark ML хорошо описаны в документации. Но я не буду пересказывать документацию. Рассмотрим как работать со Spark ML на конкретном примере.

Загружаем Spark

Spark можно запустить в локальном режиме, без кластера. Это позволяет познакомится с API, посмотреть особенности работы с ним.

Spark работает на JVM. Поэтому для запуска заданий и разработки приложений на компьютере должен быть установлен JDK, путь к java должен находиться в переменной PATH, и должна быть установлена переменная JAVA_HOME.

Что запустить Spark в локальном режиме надо проделать следующее:

  1. Cкачать дистрибутив Spark на свой компьютер: http://spark.apache.org/downloads.html
    • Из списка версий надо выбрать ту, которая используется у вас на работе. Если на работе Spark не используется, а есть потребность в его изучении, то лучше скачивать последнюю версию.
    • Помимо версии самого Spark есть выбор предоставляемых библиотек Hadoop. Так как мы собираемся запускать Spark локально, то вариант “Pre-built with user-provided Apache Hadoop” нам не подходит, так как в этом случае придётся скачивать и устанавливать ещё и библиотеки Hadoop. Надо выбрать один из “Pre-built for Apache Hadoop …”
  2. Распаковать архив, например в папку /opt/spark
  3. При желании можно изменить параметры, установленные по-умолчанию. Они находятся в папке conf:
    • log4j.properties – параметры логирования (Например, заменить INFO на WARN)
    • spark-defaults.conf – параметры spark-submit (Например, увеличить память драйвера)
  4. Прописать переменную SPARK_HOME

Запускаем Spark

Прежде, чем писать и компилировать программу для Spark, желательно поработать с ним в интерактивном режиме (REPL).

Для этого есть несколько вариантов:

  • spark-shell (pyspark)
    Консольный Scala/Python REPL с настроенным Spark. Входит в дистрибутив Spark. Неудобен при длительной работе.
  • Apache Zeppelin
    Сервис ноутбуков в браузере. Поддерживает большое количество интерпретаторов, включая Spark, Scala и Python. Удобен тем, что как и стандартный консольный REPL предоставляет настроенный Spark.
  • Apache Livy
    REST сервис для Spark. Позволяет запускать задания и работать интерактивно.
  • Apache Toree
    Ядро для Jupyter Notebook для работы со Spark.
  • Almond
    Scala ядро для Jupyter. Поддерживает Spark.
  • JetBrains Big Data Tools
    Плагин для IntelliJ IDEA, DataGrip и PyCharm IDE от JetBrains. Позволяет прямо из IDE работать с ноутбуками Zeppelin, предоставляет доступ к мониторингу Spark и Kafka, доступ к HDFS и т.п.

Лично я предпочитаю использовать Apache Zeppelin вместе с JetBrains Big Data Tools.

Задача машинного обучения

В качестве примера возьмём задачу предсказания оттока клиентов банка.
Описание задачи и набор данных находится на сайте Kaggle: https://www.kaggle.com/sakshigoyal7/credit-card-customers

Этот набор данных состоит из 10 000 клиентов и содержит такие признаки, как возраст, зарплата, статус по состоянию здоровья, лимит кредитной карты, категорию кредитной карты и т.д., а также переменную Attrition_Flag с признаком оттока (перестал ли клиент пользоваться услугами банка).

Мы решаем задачу бинарной классификации. Нам надо построить модель, предсказывающую к какой группе относится клиент.

Этапы ML

Из каких же этапов должен состоять проект ML?

Есть несколько методологий. Будем использовать CRISP-DM

CRISP-DM

CRISP-DM (Cross-Industry Standard Process for Data Mining) — наиболее распространённая методология по исследованию данных.

Исследование данных по методологии CRISP-DM состоит из следующих фаз:

  1. Понимание бизнес-целей (Business Understanding)
  2. Понимание данных (Data Understanding)
  3. Подготовка данных (Data Preparation)
  4. Моделирование (Modeling)
  5. Оценка (Evaluation)
  6. Внедрение (Deployment)

Будем решать нашу задачу по этим шагам.

Понимание бизнес-целей

С бизнес-целями в нашем случае всё просто. Банк заинтересован в сохранении клиентов. Предсказав клиентов, которые относятся к группе, склонной к уходу из банка, можно сработать на опережение и предложить им выгодные условия, чтобы они остались клиентами банка.

Понимание данных

Давайте загрузим набор данных и посмотрим на него.

Данные находятся в файле в формате CSV. Загрузим его стандартным для Spark способом в переменную raw типа DataFrame:

val raw = spark
        .read
        .option("header", "true")
        .option("inferSchema", "true")
        .csv(s"$basePath/data/BankChurners.csv")

Переменная basePath содержит путь к рабочему каталогу этого проекта.

В описании этого набора сказано: “PLEASE IGNORE THE LAST 2 COLUMNS (NAIVE BAYES CLAS…)”. А первая колонка содержит уникальный идентификатор клиента, который для построения модели совершенно не нужен.

Подготовим список колонок, которые надо исключить из загруженного набора – это первая и две последние колонки. Получим список колонок из DataFrame, выделим последние два элемента и добавим первый.

val columns: Array[String] = raw.columns
val columnsLen: Int = columns.length
val colsToDrop: Array[String] = columns.slice(columnsLen - 2, columnsLen) :+ columns.head

Переменная colsToDrop – это массив имён колонок, которые надо исключить из загруженного набора данных.

Для удаления колонок из DataFrame используется метод drop, аргументами которого является одно или несколько названий колонок – аргументы переменной длины. Чтобы преобразовать массив в аргументы метода в Scala применяется конструкция array: _*

val df = raw.drop(colsToDrop: _*)

Итак, переменная df типа DataFrame содержит исходный набор данных без первой и двух последних колонок. Полезно посмотреть на несколько первых записей этого набора:

df.show(5, truncate = false)
+-----------------+------------+------+---------------+---------------+--------------+---------------+-------------+--------------+------------------------+----------------------+---------------------+------------+-------------------+---------------+--------------------+---------------+--------------+-------------------+---------------------+
|Attrition_Flag   |Customer_Age|Gender|Dependent_count|Education_Level|Marital_Status|Income_Category|Card_Category|Months_on_book|Total_Relationship_Count|Months_Inactive_12_mon|Contacts_Count_12_mon|Credit_Limit|Total_Revolving_Bal|Avg_Open_To_Buy|Total_Amt_Chng_Q4_Q1|Total_Trans_Amt|Total_Trans_Ct|Total_Ct_Chng_Q4_Q1|Avg_Utilization_Ratio|
+-----------------+------------+------+---------------+---------------+--------------+---------------+-------------+--------------+------------------------+----------------------+---------------------+------------+-------------------+---------------+--------------------+---------------+--------------+-------------------+---------------------+
|Existing Customer|45          |M     |3              |High School    |Married       |$60K - $80K    |Blue         |39            |5                       |1                     |3                    |12691.0     |777                |11914.0        |1.335               |1144           |42            |1.625              |0.061                |
|Existing Customer|49          |F     |5              |Graduate       |Single        |Less than $40K |Blue         |44            |6                       |1                     |2                    |8256.0      |864                |7392.0         |1.541               |1291           |33            |3.714              |0.105                |
|Existing Customer|51          |M     |3              |Graduate       |Married       |$80K - $120K   |Blue         |36            |4                       |1                     |0                    |3418.0      |0                  |3418.0         |2.594               |1887           |20            |2.333              |0.0                  |
|Existing Customer|40          |F     |4              |High School    |Unknown       |Less than $40K |Blue         |34            |3                       |4                     |1                    |3313.0      |2517               |796.0          |1.405               |1171           |20            |2.333              |0.76                 |
|Existing Customer|40          |M     |3              |Uneducated     |Married       |$60K - $80K    |Blue         |21            |5                       |1                     |0                    |4716.0      |0                  |4716.0         |2.175               |816            |28            |2.5                |0.0                  |
+-----------------+------------+------+---------------+---------------+--------------+---------------+-------------+--------------+------------------------+----------------------+---------------------+------------+-------------------+---------------+--------------------+---------------+--------------+-------------------+---------------------+
only showing top 5 rows

Определяем типы колонок

Для понимания данных полезно узнать кого типа колонки есть в наборе данных.

Чаще всего для вывода схемы DataFrame используется метод printSchema:

df.printSchema
root
 |-- Attrition_Flag: string (nullable = true)
 |-- Customer_Age: integer (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Dependent_count: integer (nullable = true)
 |-- Education_Level: string (nullable = true)
 |-- Marital_Status: string (nullable = true)
 |-- Income_Category: string (nullable = true)
 |-- Card_Category: string (nullable = true)
 |-- Months_on_book: integer (nullable = true)
 |-- Total_Relationship_Count: integer (nullable = true)
 |-- Months_Inactive_12_mon: integer (nullable = true)
 |-- Contacts_Count_12_mon: integer (nullable = true)
 |-- Credit_Limit: double (nullable = true)
 |-- Total_Revolving_Bal: integer (nullable = true)
 |-- Avg_Open_To_Buy: double (nullable = true)
 |-- Total_Amt_Chng_Q4_Q1: double (nullable = true)
 |-- Total_Trans_Amt: integer (nullable = true)
 |-- Total_Trans_Ct: integer (nullable = true)
 |-- Total_Ct_Chng_Q4_Q1: double (nullable = true)
 |-- Avg_Utilization_Ratio: double (nullable = true)

Этот метод хорошо подходит для интерактивной работы, но для обработки результата лучше использовать метод dtypes

Выведем в удобном виде названия колонок и их тип:

df.dtypes.foreach { dt => println(f"${dt._1}%25s\t${dt._2}") }
           Attrition_Flag	StringType
             Customer_Age	IntegerType
                   Gender	StringType
          Dependent_count	IntegerType
          Education_Level	StringType
           Marital_Status	StringType
          Income_Category	StringType
            Card_Category	StringType
           Months_on_book	IntegerType
 Total_Relationship_Count	IntegerType
   Months_Inactive_12_mon	IntegerType
    Contacts_Count_12_mon	IntegerType
             Credit_Limit	DoubleType
      Total_Revolving_Bal	IntegerType
          Avg_Open_To_Buy	DoubleType
     Total_Amt_Chng_Q4_Q1	DoubleType
          Total_Trans_Amt	IntegerType
           Total_Trans_Ct	IntegerType
      Total_Ct_Chng_Q4_Q1	DoubleType
    Avg_Utilization_Ratio	DoubleType

И посмотрим сколько колонок каждого типа

df.dtypes.groupBy(_._2).mapValues(_.length).foreach(println)
(DoubleType,5)
(StringType,6)
(IntegerType,9)

Проверим числовые колонки

Выделим числовые колонки и применим к ним метод summary. Этот метод вычисляет такие статистики как:

  • count
  • mean
  • stddev
  • min
  • max
  • arbitrary approximate percentiles specified as a percentage (e.g. 75%)
val numericColumns: Array[String] = df.dtypes.filter(!_._2.equals("StringType")).map(_._1)
df.select(numericColumns.map(col): _*).summary().show
+-------+-----------------+------------------+------------------+------------------------+----------------------+---------------------+-----------------+-------------------+-----------------+--------------------+-----------------+-----------------+-------------------+---------------------+
|summary|     Customer_Age|   Dependent_count|    Months_on_book|Total_Relationship_Count|Months_Inactive_12_mon|Contacts_Count_12_mon|     Credit_Limit|Total_Revolving_Bal|  Avg_Open_To_Buy|Total_Amt_Chng_Q4_Q1|  Total_Trans_Amt|   Total_Trans_Ct|Total_Ct_Chng_Q4_Q1|Avg_Utilization_Ratio|
+-------+-----------------+------------------+------------------+------------------------+----------------------+---------------------+-----------------+-------------------+-----------------+--------------------+-----------------+-----------------+-------------------+---------------------+
|  count|            10127|             10127|             10127|                   10127|                 10127|                10127|            10127|              10127|            10127|               10127|            10127|            10127|              10127|                10127|
|   mean|46.32596030413745|2.3462032191172115|35.928409203120374|      3.8125802310654686|    2.3411671768539546|   2.4553174681544387|8631.953698034848| 1162.8140614199665|7469.139636614887|  0.7599406536980376|4404.086303939963|64.85869457884863| 0.7122223758269962|   0.2748935518909845|
| stddev|8.016814032549046|  1.29890834890379|  7.98641633087208|        1.55440786533883|    1.0106223994182844|   1.1062251426359249|9088.776650223148|  814.9873352357533|9090.685323679114|  0.2192067692307027|3397.129253557085|23.47257044923301|0.23808609133294137|  0.27569146925238736|
|    min|               26|                 0|                13|                       1|                     0|                    0|           1438.3|                  0|              3.0|                 0.0|              510|               10|                0.0|                  0.0|
|    25%|               41|                 1|                31|                       3|                     2|                    2|           2555.0|                357|           1322.0|               0.631|             2155|               45|              0.581|                0.022|
|    50%|               46|                 2|                36|                       4|                     2|                    2|           4549.0|               1276|           3472.0|               0.736|             3899|               67|              0.702|                0.175|
|    75%|               52|                 3|                40|                       5|                     3|                    3|          11067.0|               1784|           9857.0|               0.859|             4741|               81|              0.818|                0.503|
|    max|               73|                 5|                56|                       6|                     6|                    6|          34516.0|               2517|          34516.0|               3.397|            18484|              139|              3.714|                0.999|
+-------+-----------------+------------------+------------------+------------------------+----------------------+---------------------+-----------------+-------------------+-----------------+--------------------+-----------------+-----------------+-------------------+---------------------+

Видно, что в данных нет пропусков и выбросов.

Теперь давайте посмотрим на значения колонки Customer_Age

df.groupBy($"Customer_Age").count().show(100)

JetBrains Big Data Tools позволяет представлять вывод в виде графиков.

Видно, что значение колонки Customer_Age имеет практически нормальное распределение.

Целевая колонка

Колонка Attrition_Flag содержит признак оттока в ввиде текстового описания. Для моделирования надо привести его к числовому виду. Поэтому введём новую колонку target, которая будет равна 0, когда значение Attrition_Flag равно “Existing Customer”, и 1 в остальных случаях.

val dft = df.withColumn("target", when($"Attrition_Flag" === "Existing Customer", 0).otherwise(1))

dft – новый DataFrame с целевой колонкой target.

Проверка сбалансированности данных

Следующее, что надо сделать – проверить набор данных на сблансированность классов.

Мы решаем задачу бинарной классификации, у нас два класса. Проверим количество записей в каждом классе.

dft.groupBy("target").count.show
+------+-----+
|target|count|
+------+-----+
|     1| 1627|
|     0| 8500|
+------+-----+

Есть несколько методов решения проблемы несбалансированных данных. Чаще всего применят undersampling – уменьшение количества записей большего класса, и oversampling – увеличение количества записей меньшего класса.

Данных у нас не очень много, поэтому будем использовать oversampling.

Oversampling

Выделим в отдельные переменные данные разных классов и сохраним количество записей в каждом классе.

val df1 = dft.filter($"target" === 1)
val df0 = dft.filter($"target" === 0)

val df1count = df1.count
val df0count = df0.count

Нужно увеличить количество записей в наборе df1 в df0count / df1count раз:

val df1Over = df1
        .withColumn("dummy", explode(lit((1 to (df0count / df1count).toInt).toArray)))
        .drop("dummy")

Давайте рассмотим это подробнее.

Конструкция (1 to (df0count / df1count).toInt).toArray создаёт массив со значениями от 1 до (df0count / df1count)

(1 to (df0count / df1count).toInt).toArray

res77: Array[Int] = Array(1, 2, 3, 4, 5)

Функция lit создаёт колонки с определённым значением. Мы добавляем колонку с именем dummy, значением которой является массив:

df1
        .withColumn("dummy", lit((1 to (df0count / df1count).toInt).toArray))
        .select("Attrition_Flag", "Customer_Age", "dummy")
        .show(10)
+-----------------+------------+---------------+
|   Attrition_Flag|Customer_Age|          dummy|
+-----------------+------------+---------------+
|Attrited Customer|          62|[1, 2, 3, 4, 5]|
|Attrited Customer|          66|[1, 2, 3, 4, 5]|
|Attrited Customer|          54|[1, 2, 3, 4, 5]|
|Attrited Customer|          56|[1, 2, 3, 4, 5]|
|Attrited Customer|          48|[1, 2, 3, 4, 5]|
|Attrited Customer|          55|[1, 2, 3, 4, 5]|
|Attrited Customer|          47|[1, 2, 3, 4, 5]|
|Attrited Customer|          53|[1, 2, 3, 4, 5]|
|Attrited Customer|          48|[1, 2, 3, 4, 5]|
|Attrited Customer|          59|[1, 2, 3, 4, 5]|
+-----------------+------------+---------------+
only showing top 10 rows

Функция explode создаёт новую строку для каждого элемента массива:

df1
        .withColumn("dummy", explode(lit((1 to (df0count / df1count).toInt).toArray)))
        .select("Attrition_Flag", "Customer_Age", "dummy")
        .show(10)
+-----------------+------------+-----+
|   Attrition_Flag|Customer_Age|dummy|
+-----------------+------------+-----+
|Attrited Customer|          62|    1|
|Attrited Customer|          62|    2|
|Attrited Customer|          62|    3|
|Attrited Customer|          62|    4|
|Attrited Customer|          62|    5|
|Attrited Customer|          66|    1|
|Attrited Customer|          66|    2|
|Attrited Customer|          66|    3|
|Attrited Customer|          66|    4|
|Attrited Customer|          66|    5|
+-----------------+------------+-----+
only showing top 10 rows

Итак, df1Over – это набор, содержащий записи класса target = 1, увеличенный в df0count / df1count раз.

Объединим этот новый набор с набором записей второго класса и проверим сбалансированность исходного набора:

val data = df0.unionAll(df1Over)
data.groupBy("target").count.show
+------+-----+
|target|count|
+------+-----+
|     1| 8135|
|     0| 8500|
+------+-----+

DataFrame data – это сбалансированный набор данных, с которым мы будем дальше работать.

Подготовка данных (работа с признаками)

Для этапа подготовки данных в Spark ML есть следующие группы алгоритмов:

  • Extraction – извлечение объектов из “необработанных” данных
  • Transformation – масштабирование, преобразование или изменение объектов
  • Selection – выбор подмножества из большего набора объектов
  • Locality Sensitive Hashing (LSH) – этот класс алгоритмов сочетает в себе аспекты преобразования признаков с другими алгоритмами

Работают они похожим образом:

  • Создаём объект-преобразователь с нужными параметрами
  • Применяем этот объект к исходному набору данных
  • Получаем новый набор данных, с которым продолжаем работать

Перейдём к работе с признаками нашего набора данных.

Проверим корреляции числовых признаков

Надо проверять корреляцию числовых признаков между собой и исключать признаки с высокой корреляцией.

Составим список всех пар числовых признаков:

val numericColumnsPairs = numericColumns.flatMap(f1 => numericColumns.map(f2 => (f1, f2)))

Переменная numericColumns – это массив названий колонок с числовыми типами значений (целые или с плавающей точкой).

Список всех пар можно также получить таким способом:

for {
  x <- numericColumns
  y <- numericColumns
} yield (x, y)

Фактически, это разные способы записи одного и того же действия.

Проверить корреляцию в Spark можно двумя способами:

  • DataFrameStatFunctions – Статистические функции для DataFrame
  • Correlation – API для корреляционных функций в MLlib

Проверим корреляцию наших числовых признаков обоими способами.

Вариант 1: DataFrameStatFunctions

Составим список всех пар числовых признаков, уберём пары из одинаковых названий, отсортируем пары в лексиграфическом порядке, и оставим только уникальные комбинациии пар:

val pairs = numericColumnsPairs
        .filter { p => !p._1.equals(p._2) }
        .map { p => if (p._1 < p._2) (p._1, p._2) else (p._2, p._1) }
        .distinct

Для каждой пары применим статистическую функцию вычисления корреляции к сбалансированному набору данных и выделим пары с корреляцией больше 0.6:

val corr = pairs
        .map { p => (p._1, p._2, data.stat.corr(p._1, p._2)) }
        .filter(_._3 > 0.6)

Выведем результат в удобном виде:

corr.sortBy(_._3).reverse.foreach { c => println(f"${c._1}%25s${c._2}%25s\t${c._3}") }

          Avg_Open_To_Buy             Credit_Limit	0.9952040726156253
          Total_Trans_Amt           Total_Trans_Ct	0.8053901681243808
             Customer_Age           Months_on_book	0.7805047706891142
    Avg_Utilization_Ratio      Total_Revolving_Bal	0.6946855441968229

Вариант 2: Correlation

Чтобы воспользоваться вторым способом, надо собрать все числовые признаки в одну колонку типа Vector. Для этого используется преобразователь VectorAssembler. Применив VectorAssembler к нашему набору данных, получим новый набор данных numeric с колонкой features, содержащей вектор с числовыми признаками.

Применив метод corr объекта Correlation к новому набору данных numeric, можно получить матрицу корреляции числовых признаков:

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.stat.Correlation
import org.apache.spark.ml.linalg.Matrix
import org.apache.spark.sql.Row

val numericAssembler = new VectorAssembler()
  .setInputCols(numericColumns)
  .setOutputCol("features")

val numeric = numericAssembler.transform(data)
val Row(matrix: Matrix) = Correlation.corr(numeric, "features").head

Переменная matrix – это матрица корреляции числовых признаков:

matrix: org.apache.spark.ml.linalg.Matrix =
1.0                    -0.13575515707704905   ... (14 total)
-0.13575515707704905   1.0                    ...
0.780504770689084      -0.11728062823959522   ...
-0.026525310066416643  -0.032664177863511015  ...
0.13116552936201348    -0.0106575011505989...

Теперь сопоставим матрицу корреляции с названиями числовых признаков, уберём пары из одинаковых названий, выделим пары с корреляцией больше 0.6, отсортируем пары в лексиграфическом порядке, и оставим только уникальные комбинациии пар:

val corr2 = matrix.toArray
        .zip(numericColumnsPairs)
        .map(cnn => (cnn._2._1, cnn._2._2, cnn._1))
        .filter(_._3 < 1.0)
        .filter(_._3 > 0.6)
        .map { p => if (p._1 < p._2) (p._1, p._2, p._3) else (p._2, p._1, p._3) }
        .distinct

Выведем результат в удобном виде:

corr2.sortBy(_._3).reverse.foreach { c => println(f"${c._1}%25s${c._2}%25s\t${c._3}") }

          Avg_Open_To_Buy             Credit_Limit	0.9952040726156179
          Total_Trans_Amt           Total_Trans_Ct	0.8053901681243786
             Customer_Age           Months_on_book	0.780504770689084
    Avg_Utilization_Ratio      Total_Revolving_Bal	0.6946855441968222

Видно, что результат, полученный разными способами, совпадает.

Для проверки представим результаты в виже множеств и посмотрим на их пересечение:

corr.toSet.intersect(corr2.toSet)

res84: scala.collection.immutable.Set[(String, String, Double)] = Set()

Получили пустое множество, что подтверждает эквивалентность результатов, полученных разными способами.

Соберём список числовых признаков с низкой корреляцией в переменную numericColumnsFinal:

val numericColumnsFinal = numericColumns.diff(corr.map(_._2))

Категориальные признаки

Теперь займёмся категориальными признаками.

Категориальный признак – это признак, значения которого обозначают принадлежность объекта к какой-то категории. Значения категориальных признаков – это наборы дискретных значений.

Но подавляющее большинство методов классификации и регрессии сформулированы в терминах метрических пространств, то есть подразумевают представление данных в виде вещественных векторов одинаковой размерности.

Поэтому для использования категориальных признаков их надо кодировать – преобразовать в непрерывные. Вместо одной категориальной переменной создается несолько, по количеству уникальных значений категориальной переменной. Значениями новых переменных будут 1.0 и 0.0 в соответствии со значением категориальной переменной.

Для кодирования категориальных переменных в Spark ML используется преобразователь OneHotEncoder.

Но прежде, чем применять его к признакам, содержащим строки, их надо проиндексировать. Для этого используется преобразователь StringIndexer.

В нашем наборе данных категориальными являются только колонки, содержащие строки.

Иногда к категориальным относят такой признак, как возраст. Но, как мы видели, в нашем случае возраст имеет практически нормальное распределение. Вот если бы у нас была переменная с группами возрастов, тогда с такой переменной надо работать как с категориальной.

Индексируем строковые колонки

Составим список всех строковых колонок за исключением колонки Attrition_Flag, которая является целевой, и проиндесируем их, создав новые колонки, добавив _Indexed к названию исходных колонок.

import org.apache.spark.ml.feature.StringIndexer

val stringColumns = data
        .dtypes
        .filter(_._2.equals("StringType"))
        .map(_._1)
        .filter(!_.equals("Attrition_Flag"))

val stringColumnsIndexed = stringColumns.map(_ + "_Indexed")

val indexer = new StringIndexer()
        .setInputCols(stringColumns)
        .setOutputCols(stringColumnsIndexed)

val indexed = indexer.fit(data).transform(data)

indexed – это новый набор данных с проиндексированными строковыми колонками.

Кодируем категориальные признаки

Теперь можно перейти к кодированию категориальных признаков.

Кодированные признаки будут находится в новых колонках, к названию которых будет добавлено _Coded

import org.apache.spark.ml.feature.OneHotEncoder

val catColumns = stringColumnsIndexed.map(_ + "_Coded")
    
val encoder = new OneHotEncoder()
        .setInputCols(stringColumnsIndexed)
        .setOutputCols(catColumns)

val encoded = encoder.fit(indexed).transform(indexed)

encoded – это новый набор данных с кодированными категориальными признаками

Собираем признаки в вектор

После обработки категориальных признаков надо собрать все признаки в вектор.

Для этого используется преобразователь VectorAssembler, с которым мы уже встречались, когда вычисляли корреляцию числовых признаков вторым способом.

Применим его к списку числовых признаков с низкой корреляцией, объединному со списком кодированных категориальных переменных.

val featureColumns = numericColumnsFinal ++ catColumns

val assembler = new VectorAssembler()
  .setInputCols(featureColumns)
  .setOutputCol("features")

val assembled = assembler.transform(encoded)

assembled – это набо данных, содержий колонку features, значениями которой является вектор признаков.

Нормализация

Давайте посмотрим на вектор признаков, который получился у нас в итоге.

assembled.select("features").show(5, truncate = false)
+--------------------------------------------------------------------------------------------------------------------+
|features                                                                                                            |
+--------------------------------------------------------------------------------------------------------------------+
|(28,[0,1,2,3,4,5,6,7,8,9,12,17,23,25],[45.0,3.0,5.0,1.0,3.0,11914.0,1.335,1144.0,1.625,0.061,1.0,1.0,1.0,1.0])      |
|(28,[0,1,2,3,4,5,6,7,8,9,10,11,18,20,25],[49.0,5.0,6.0,1.0,2.0,7392.0,1.541,1291.0,3.714,0.105,1.0,1.0,1.0,1.0,1.0])|
|(28,[0,1,2,3,5,6,7,8,11,17,22,25],[51.0,3.0,4.0,1.0,3418.0,2.594,1887.0,2.333,1.0,1.0,1.0,1.0])                     |
|(28,[0,1,2,3,4,5,6,7,8,9,10,12,19,20,25],[40.0,4.0,3.0,4.0,1.0,796.0,1.405,1171.0,2.333,0.76,1.0,1.0,1.0,1.0,1.0])  |
|(28,[0,1,2,3,5,6,7,8,14,17,23,25],[40.0,3.0,5.0,1.0,4716.0,2.175,816.0,2.5,1.0,1.0,1.0,1.0])                        |
+--------------------------------------------------------------------------------------------------------------------+
only showing top 5 rows

Видна большая разница в значениях признаков.

Рекомендуется провести стандартизацию (удаление среднего и масштабирование дисперсии) или нормализацию (масштабирования отдельных образцов до единичной нормы) набора данных.

В Spark ML есть несколько методов, с помощью которых можно сделать такие преобразования:

  • Normalizer – нормализует вектор для получения единичной нормы
  • StandardScaler – нормализация каждого признака для получения единичного стандартного отклонения и/или нулевого среднего
  • RobustScaler – удаление медианы и масштабирование данных в соответствии с определенным диапазоном квантилей
  • MinMaxScaler – масштабирование каждого признака в определенном диапазоне (часто [0, 1])
  • MaxAbsScaler – масштабирование каждого признака в диапазоне [-1, 1] путем деления на максимальное абсолютное значение

Применим MinMaxScaler для нормализации набора данных:

import org.apache.spark.ml.feature.MinMaxScaler

val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")

val scaled = scaler.fit(assembled).transform(assembled)

scaled – это набор данных с вектором нормализованных признаков в колонке scaledFeatures

scaled.select("features", "scaledFeatures").show(5, truncate = false)

+--------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|features                                                                                                            |scaledFeatures                                                                                                                                                                                                                      |
+--------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|(28,[0,1,2,3,4,5,6,7,8,9,12,17,23,25],[45.0,3.0,5.0,1.0,3.0,11914.0,1.335,1144.0,1.625,0.061,1.0,1.0,1.0,1.0])      |(28,[0,1,2,3,4,5,6,7,8,9,12,17,23,25],[0.40425531914893614,0.6000000000000001,0.8,0.16666666666666666,0.5,0.3451163329759801,0.39299381807477185,0.03527317236007566,0.43753365643511044,0.061061061061061066,1.0,1.0,1.0,1.0])     |
|(28,[0,1,2,3,4,5,6,7,8,9,10,11,18,20,25],[49.0,5.0,6.0,1.0,2.0,7392.0,1.541,1291.0,3.714,0.105,1.0,1.0,1.0,1.0,1.0])|(28,[0,1,2,3,4,5,6,7,8,9,10,11,18,20,25],[0.48936170212765956,1.0,1.0,0.16666666666666666,0.3333333333333333,0.21409324022831977,0.4536355607889314,0.043451652386780906,1.0,0.10510510510510511,1.0,1.0,1.0,1.0,1.0])              |
|(28,[0,1,2,3,5,6,7,8,11,17,22,25],[51.0,3.0,4.0,1.0,3418.0,2.594,1887.0,2.333,1.0,1.0,1.0,1.0])                     |(28,[0,1,2,3,5,6,7,8,11,17,22,25],[0.5319148936170213,0.6000000000000001,0.6000000000000001,0.16666666666666666,0.09894822240894735,0.7636149543715043,0.07661065984199399,0.6281637049003771,1.0,1.0,1.0,1.0])                     |
|(28,[0,1,2,3,4,5,6,7,8,9,10,12,19,20,25],[40.0,4.0,3.0,4.0,1.0,796.0,1.405,1171.0,2.333,0.76,1.0,1.0,1.0,1.0,1.0])  |(28,[0,1,2,3,4,5,6,7,8,9,10,12,19,20,25],[0.2978723404255319,0.8,0.4,0.6666666666666666,0.16666666666666666,0.02297684930316113,0.41360023550191344,0.036775342160899074,0.6281637049003771,0.7607607607607608,1.0,1.0,1.0,1.0,1.0])|
|(28,[0,1,2,3,5,6,7,8,14,17,23,25],[40.0,3.0,5.0,1.0,4716.0,2.175,816.0,2.5,1.0,1.0,1.0,1.0])                        |(28,[0,1,2,3,5,6,7,8,14,17,23,25],[0.2978723404255319,0.6000000000000001,0.8,0.16666666666666666,0.13655723930113292,0.6402708272004709,0.017024591075998664,0.6731287022078623,1.0,1.0,1.0,1.0])                                   |
+--------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
only showing top 5 rows

Feature Selection (отбор признаков)

Вектор признаков нашего набора данных содержит 28 признаков. Это не очень много. Тем не менее рассмотрим процедуру отбора признаков – выделения наиболее важных из них

UnivariateFeatureSelector – это универсальный преобразователь, которые позволяет выделить наиболее важные признаки. Он работает с категориальными/непрерывными признаками и категориальными/непрерывными целевыми переменными. Функция оценки выбирается исходя из типа признаков и целевой переменной:

featureType |  labelType |score function
------------|------------|--------------
categorical |categorical | chi-squared (chi2)
continuous  |categorical | ANOVATest (f_classif)
continuous  |continuous  | F-value (f_regression)

Поддерживаются следующие методы отбора:

  • numTopFeatures – фиксированное число отбираемых признаков
  • percentile – выбор по перцентилю
  • fpr отбирает признаки, p-value которых ниже порогового значения
  • fdr использует процедуру Бенджамини-Хохберга для выбора признаков, частота ложных обнаружений которых ниже порогового значения
  • fwe отбирает признаки, p-value которых ниже порогового значения. Пороговое значение масштабируется по 1/numFeatures

Применим UnivariateFeatureSelector с выбором по перцентилю с пороговым значением 0.75

import org.apache.spark.ml.feature.UnivariateFeatureSelector

val selector = new UnivariateFeatureSelector()
  .setFeatureType("continuous")
  .setLabelType("categorical")
  .setSelectionMode("percentile")
  .setSelectionThreshold(0.75)
  .setFeaturesCol("scaledFeatures")
  .setLabelCol("target")
  .setOutputCol("selectedFeatures")

val dataF = selector.fit(scaled).transform(scaled)

dataF – это набор данных с вектором отобранных признаков в колонке selectedFeatures

dataF.select("scaledFeatures", "selectedFeatures").show(5, truncate = false)

+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|scaledFeatures                                                                                                                                                                                                                      |selectedFeatures                                                                                                                                                                                      |
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|(28,[0,1,2,3,4,5,6,7,8,9,12,17,23,25],[0.40425531914893614,0.6000000000000001,0.8,0.16666666666666666,0.5,0.3451163329759801,0.39299381807477185,0.03527317236007566,0.43753365643511044,0.061061061061061066,1.0,1.0,1.0,1.0])     |(21,[0,1,2,3,4,5,6,7,8,11,14,19],[0.40425531914893614,0.6000000000000001,0.8,0.16666666666666666,0.5,0.39299381807477185,0.03527317236007566,0.43753365643511044,0.061061061061061066,1.0,1.0,1.0])   |
|(28,[0,1,2,3,4,5,6,7,8,9,10,11,18,20,25],[0.48936170212765956,1.0,1.0,0.16666666666666666,0.3333333333333333,0.21409324022831977,0.4536355607889314,0.043451652386780906,1.0,0.10510510510510511,1.0,1.0,1.0,1.0,1.0])              |(21,[0,1,2,3,4,5,6,7,8,9,10,15,17],[0.48936170212765956,1.0,1.0,0.16666666666666666,0.3333333333333333,0.4536355607889314,0.043451652386780906,1.0,0.10510510510510511,1.0,1.0,1.0,1.0])              |
|(28,[0,1,2,3,5,6,7,8,11,17,22,25],[0.5319148936170213,0.6000000000000001,0.6000000000000001,0.16666666666666666,0.09894822240894735,0.7636149543715043,0.07661065984199399,0.6281637049003771,1.0,1.0,1.0,1.0])                     |(21,[0,1,2,3,5,6,7,10,14],[0.5319148936170213,0.6000000000000001,0.6000000000000001,0.16666666666666666,0.7636149543715043,0.07661065984199399,0.6281637049003771,1.0,1.0])                           |
|(28,[0,1,2,3,4,5,6,7,8,9,10,12,19,20,25],[0.2978723404255319,0.8,0.4,0.6666666666666666,0.16666666666666666,0.02297684930316113,0.41360023550191344,0.036775342160899074,0.6281637049003771,0.7607607607607608,1.0,1.0,1.0,1.0,1.0])|(21,[0,1,2,3,4,5,6,7,8,9,11,16,17],[0.2978723404255319,0.8,0.4,0.6666666666666666,0.16666666666666666,0.41360023550191344,0.036775342160899074,0.6281637049003771,0.7607607607607608,1.0,1.0,1.0,1.0])|
|(28,[0,1,2,3,5,6,7,8,14,17,23,25],[0.2978723404255319,0.6000000000000001,0.8,0.16666666666666666,0.13655723930113292,0.6402708272004709,0.017024591075998664,0.6731287022078623,1.0,1.0,1.0,1.0])                                   |(21,[0,1,2,3,5,6,7,14,19],[0.2978723404255319,0.6000000000000001,0.8,0.16666666666666666,0.6402708272004709,0.017024591075998664,0.6731287022078623,1.0,1.0])                                         |
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
only showing top 5 rows

Мы сократили количество признаков с 28 до 21.

На этом заканчивается этап подготовки данных и можно переходить к следующему этапу – моделирование

Моделирование

Для построения моделей Spark ML предлагает такой набор алгоритмов:

Продемонстрируем этап моделирования на нашем примере.

Обучающая и тестовая выборки

Перед тем, как перейти к построению модели, необходимо разбить набор данных на обучающую и тестовую выборки. Для этого в Spark есть стандартный метод randomSplit, аргументом которого является массив с пропорциями разделения.

val tt = dataF.randomSplit(Array(0.7, 0.3))
val training = tt(0)
val test = tt(1)

training – это обучающая выборка с 70% записей, а test – это тестовая выборка с, соответственно, 30% записей.

Логистическая регрессия

Мы решаем задачу бинарной классификации. Будем использовать логистическую регрессию, как хорошо зарекомендовавший себя алгоритм.

Для этого надо использовать объект LogisticRegression, основными параметрами которого являются:

  • elasticNetParamα
  • regParamλ

Выберем для начала эти параметры произвольным образом.

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
        .setMaxIter(1000)
        .setRegParam(0.2)
        .setElasticNetParam(0.8)
        .setFamily("binomial")
        .setFeaturesCol("selectedFeatures")
        .setLabelCol("target")

val lrModel = lr.fit(training)

lrModel – это обученная модель.

Training Summary

Мы можем получить основную информацию об обученной модели. Для этого используется объект BinaryLogisticRegressionTrainingSummary:

val trainingSummary = lrModel.binarySummary

println(s"accuracy: ${trainingSummary.accuracy}")
println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")
accuracy: 0.6986124278203912
areaUnderROC: 0.7455570759572957

Мы получили AUROC примерно 0.75, что, в принципе, неплохо.

Оценка

Проверяем модель на тестовой выборке

Применим обученную модель к тестовой выборке и посмотрим на результат.

val predicted = lrModel.transform(test)

Набор predicted содержит новые колонки: rawPrediction, probability и prediction:

predicted.select("target", "rawPrediction", "probability", "prediction").show(10, truncate = false)

+------+----------------------------------------------+----------------------------------------+----------+
|target|rawPrediction                                 |probability                             |prediction|
+------+----------------------------------------------+----------------------------------------+----------+
|0     |[0.040262722641592585,-0.040262722641592585]  |[0.5100643211022606,0.48993567889773937]|0.0       |
|0     |[-0.009994173386193073,0.009994173386193073]  |[0.4975014774501823,0.5024985225498177] |1.0       |
|0     |[0.18939904012242004,-0.18939904012242004]    |[0.547208721739737,0.452791278260263]   |0.0       |
|0     |[0.057015021317521175,-0.057015021317521175]  |[0.5142498953455751,0.4857501046544249] |0.0       |
|0     |[-0.030423805917813296,0.030423805917813296]  |[0.4923946351436886,0.5076053648563115] |1.0       |
|0     |[-0.023886323507694818,0.023886323507694818]  |[0.49402870303387675,0.5059712969661232]|1.0       |
|0     |[-0.05167062375069831,0.05167062375069831]    |[0.4870852173158024,0.5129147826841975] |1.0       |
|0     |[0.0026721987834114613,-0.0026721987834114613]|[0.5006680492983275,0.49933195070167247]|0.0       |
|0     |[-0.05085343844943349,0.05085343844943349]    |[0.48728937948478424,0.5127106205152158]|1.0       |
|0     |[-0.026746472062121662,0.026746472062121662]  |[0.4933137805752158,0.5066862194247842] |1.0       |
+------+----------------------------------------------+----------------------------------------+----------+
only showing top 10 rows

В идеале значения в колонках target и prediction должны совпадать. Но, как мы видим, разница есть даже в первых десяти записях.

Для оценки применения модели к тестовой выборке можно воспользоваться объектом BinaryClassificationEvaluator:

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val evaluator = new BinaryClassificationEvaluator().setLabelCol("target")

println(s"areaUnderROC: ${evaluator.evaluate(predicted)}\n")
areaUnderROC: 0.7445924078797251

AUROC на тестовой выборке тоже примерно 0.75

Confusion Matrix (матрица ошибок)

Полезным способом оценки модели является Матрица ошибок.

  • True Positive (TP) – label is positive and prediction is also positive
  • True Negative (TN) – label is negative and prediction is also negative
  • False Positive (FP) – label is negative but prediction is positive
  • False Negative (FN) – label is positive but prediction is negative

В Spark ML нет методов, вычисляющих матрицу ошибок непосредственно, но её легко вычислить непосредственно:

val tp = predicted.filter(($"target" === 1) and ($"prediction" === 1)).count
val tn = predicted.filter(($"target" === 0) and ($"prediction" === 0)).count
val fp = predicted.filter(($"target" === 0) and ($"prediction" === 1)).count
val fn = predicted.filter(($"target" === 1) and ($"prediction" === 0)).count

println(s"Confusion Matrix:\n$tp\t$fp\n$fn\t$tn\n")
Confusion Matrix:
1272	309
1198	2253

Желательно, чтобы значения на главной диагонали матрицы были большими, а на побочной – маленькими.

Accuracy, Precision, Recall

Следующими широко используемыми метриками оценки качества являются:

  • Accuracy (доля правильных ответов) = TP + TN / TP + TN + FP + FN
  • Precision (точность) = TP / TP + FP
  • Recall (полнота) = TP / TP + FN

Их легко вычислить по матрице ошибок:

val accuracy = (tp + tn) / (tp + tn + fp + fn).toDouble
val precision = tp / (tp + fp).toDouble
val recall = tp / (tp + fn).toDouble

println(s"Accuracy = $accuracy")
println(s"Precision = $precision")
println(s"Recall = $recall\n")
Accuracy = 0.700516693163752
Precision = 0.8045540796963947
Recall = 0.5149797570850202

Настройка моделей

Подбор гиперпараметров

При обучении нашей модели мы выбирали регуляризационные параметры произвольным образом. Давайте теперь посмотрим как можно подобрать оптимальные параметры для модели.

Для подбора гиперпараметров (выбора модели) Spark ML предлагает два инструмента: CrossValidator и TrainValidationSplit.

В обоих случаях требуется предоставить:

  • Estimator – алгоритм, который надо настроить
  • Набор параметров: параметры для выбора (“сетка параметров”)
  • Evaluator – объект для оценки модели

В общем случае процесс подбора гиперпараметров выглядит так:

  • Набор данных разбивается на обучающую и тестовую выборки
  • Для каждой пары (training, test) перебираются параметры из сетки параметров
  • Для каждого набора парметров применяется Estimator для построения модели
  • Evaluator оценивает каждую модель
  • Выбирается модель с лучшими показателями

В качестве Evaluator может использоваться:

Для построения сетки параметров используется объект ParamGridBuilder.

CrossValidator разбивает набор данных на набор folds, сочетания которых используются для обучения и тестирования. Оценка модели проходит для всех сочетаний folds.

TrainValidationSplit разбивает набор на обучающую и тестовую выборку и оценивает модель на этом разбиение.

Для подбора гиперпараметров будем использовать TrainValidationSplit:

import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1, 0.5))
  .addGrid(lr.fitIntercept)
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.7)
  .setParallelism(2)

val model = trainValidationSplit.fit(dataF)

Лучшая модель находтся в bestmodel:

model.bestModel.extractParamMap()

res89: org.apache.spark.ml.param.ParamMap =
{
	logreg_2eef3ae8c923-aggregationDepth: 2,
	logreg_2eef3ae8c923-elasticNetParam: 0.0,
	logreg_2eef3ae8c923-family: binomial,
	logreg_2eef3ae8c923-featuresCol: selectedFeatures,
	logreg_2eef3ae8c923-fitIntercept: true,
	logreg_2eef3ae8c923-labelCol: target,
	logreg_2eef3ae8c923-maxBlockSizeInMB: 0.0,
	logreg_2eef3ae8c923-maxIter: 1000,
	logreg_2eef3ae8c923-predictionCol: prediction,
	logreg_2eef3ae8c923-probabilityCol: probability,
	logreg_2eef3ae8c923-rawPredictionCol: rawPrediction,
	logreg_2eef3ae8c923-regParam: 0.01,
	logreg_2eef3ae8c923-standardization: true,
	logreg_2eef3ae8c923-threshold: 0.5,
	logreg_2eef3ae8c923-tol: 1.0E-6
}

Сохраним лучшую модель для дальнейшего использования:

val bestML = model.bestModel

Внедрение

ML Pipelines

Что важно для внедрения моделей? Безошибочная повторяемость.

Давайте вспомним все этапы подготовки и расчёта моделей:

  1. Отобрали числовые признаки (numericColumnsFinal)
  2. Проиндексировали строковые признаки (indexer)
  3. Закодировали категориальные признки (encoder)
  4. Собрали признаки в вектор (assembler)
  5. Нормализовали признаки (scaler)
  6. Провели отбор признаков (selector)
  7. Рассчитали модель (bestML)

Прежде, чем применять расчитанную модель, мы должны применить весь набор преобразований к набору данных. При повторении расчётов легко ошибиться в этих этапах или, даже, пропустить какой-нибуть из них.

Хорошо бы построить модель, включающую в себя все необходимые преобразования.

ML Pipelines позволяют объединить все преобразования и алгоритмы в один конвейер или рабочий процесс:

import org.apache.spark.ml.Pipeline

val pipeline = new Pipeline().setStages(Array(indexer, encoder, assembler, scaler, selector, bestML))

Теперь, используя Pipeline, мы можем построить модель, включающую все необходимые преобразования.

val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

val pipelineModel = pipeline.fit(trainingData)

ML Persistence

Чтобы переиспользовать подготовленную модель нужна возможность сохранять и загружать их. Это обеспечивает ML persistence.

Сохраним конвейерную модель (PipelineModel):

pipelineModel.write.overwrite().save(s"$basePath/pipelineModel")

Spark ML Production

Сохранённую модель можно загружать и использовать отдельно от исследовательского проекта, в котором мы её подготовили.

Загрузим набор данных (мы будем использовать тот же самый набор данных, но на практике обученную модель применяют к новому набору данных), загрузим конвейерную модель и применим её к набору данных:

val data = spark
        .read
        .option("header", "true")
        .option("inferSchema", "true")
        .csv(s"$basePath/data/BankChurners.csv")

import org.apache.spark.ml.PipelineModel

val model = PipelineModel.load(s"$basePath/pipelineModel")

val prediction = model.transform(data)

prediction – это набор данных, который содержит исходные данные, данные, полученные в результате преобразований, и результат применения модели – предсказание.

prediction.show(5)

+---------+-----------------+------------+------+---------------+---------------+--------------+---------------+-------------+--------------+------------------------+----------------------+---------------------+------------+-------------------+---------------+--------------------+---------------+--------------+-------------------+---------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------+-----------------------+--------------+-----------------------+---------------------+-----------------------------+----------------------------+--------------------+-----------------------------+---------------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
|CLIENTNUM|   Attrition_Flag|Customer_Age|Gender|Dependent_count|Education_Level|Marital_Status|Income_Category|Card_Category|Months_on_book|Total_Relationship_Count|Months_Inactive_12_mon|Contacts_Count_12_mon|Credit_Limit|Total_Revolving_Bal|Avg_Open_To_Buy|Total_Amt_Chng_Q4_Q1|Total_Trans_Amt|Total_Trans_Ct|Total_Ct_Chng_Q4_Q1|Avg_Utilization_Ratio|Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_1|Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_2|Marital_Status_Indexed|Income_Category_Indexed|Gender_Indexed|Education_Level_Indexed|Card_Category_Indexed|Income_Category_Indexed_Coded|Marital_Status_Indexed_Coded|Gender_Indexed_Coded|Education_Level_Indexed_Coded|Card_Category_Indexed_Coded|            features|      scaledFeatures|    selectedFeatures|       rawPrediction|         probability|prediction|
+---------+-----------------+------------+------+---------------+---------------+--------------+---------------+-------------+--------------+------------------------+----------------------+---------------------+------------+-------------------+---------------+--------------------+---------------+--------------+-------------------+---------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------+-----------------------+--------------+-----------------------+---------------------+-----------------------------+----------------------------+--------------------+-----------------------------+---------------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
|768805383|Existing Customer|          45|     M|              3|    High School|       Married|    $60K - $80K|         Blue|            39|                       5|                     1|                    3|     12691.0|                777|        11914.0|               1.335|           1144|            42|              1.625|                0.061|                                                                                                                         9.3448E-5|                                                                                                                           0.99991|                   0.0|                    3.0|           1.0|                    1.0|                  0.0|                (5,[3],[1.0])|               (3,[0],[1.0])|           (1,[],[])|                (6,[1],[1.0])|              (3,[0],[1.0])|(28,[0,1,2,3,4,5,...|(28,[0,1,2,3,4,5,...|(21,[0,1,2,3,4,5,...|[3.80023517501469...|[0.97812376182671...|       0.0|
|818770008|Existing Customer|          49|     F|              5|       Graduate|        Single| Less than $40K|         Blue|            44|                       6|                     1|                    2|      8256.0|                864|         7392.0|               1.541|           1291|            33|              3.714|                0.105|                                                                                                                         5.6861E-5|                                                                                                                           0.99994|                   1.0|                    0.0|           0.0|                    0.0|                  0.0|                (5,[0],[1.0])|               (3,[1],[1.0])|       (1,[0],[1.0])|                (6,[0],[1.0])|              (3,[0],[1.0])|(28,[0,1,2,3,4,5,...|(28,[0,1,2,3,4,5,...|(21,[0,1,2,3,4,5,...|[10.8273709791065...|[0.99998015167282...|       0.0|
|713982108|Existing Customer|          51|     M|              3|       Graduate|       Married|   $80K - $120K|         Blue|            36|                       4|                     1|                    0|      3418.0|                  0|         3418.0|               2.594|           1887|            20|              2.333|                  0.0|                                                                                                                         2.1081E-5|                                                                                                                           0.99998|                   0.0|                    2.0|           1.0|                    0.0|                  0.0|                (5,[2],[1.0])|               (3,[0],[1.0])|           (1,[],[])|                (6,[0],[1.0])|              (3,[0],[1.0])|(28,[0,1,2,3,5,6,...|(28,[0,1,2,3,5,6,...|(21,[0,1,2,3,5,6,...|[7.25204091946045...|[0.99929177547928...|       0.0|
|769911858|Existing Customer|          40|     F|              4|    High School|       Unknown| Less than $40K|         Blue|            34|                       3|                     4|                    1|      3313.0|               2517|          796.0|               1.405|           1171|            20|              2.333|                 0.76|                                                                                                                         1.3366E-4|                                                                                                                           0.99987|                   2.0|                    0.0|           0.0|                    1.0|                  0.0|                (5,[0],[1.0])|               (3,[2],[1.0])|       (1,[0],[1.0])|                (6,[1],[1.0])|              (3,[0],[1.0])|(28,[0,1,2,3,4,5,...|(28,[0,1,2,3,4,5,...|(21,[0,1,2,3,4,5,...|[5.90527865832718...|[0.99728238324157...|       0.0|
|709106358|Existing Customer|          40|     M|              3|     Uneducated|       Married|    $60K - $80K|         Blue|            21|                       5|                     1|                    0|      4716.0|                  0|         4716.0|               2.175|            816|            28|                2.5|                  0.0|                                                                                                                         2.1676E-5|                                                                                                                           0.99998|                   0.0|                    3.0|           1.0|                    3.0|                  0.0|                (5,[3],[1.0])|               (3,[0],[1.0])|           (1,[],[])|                (6,[3],[1.0])|              (3,[0],[1.0])|(28,[0,1,2,3,5,6,...|(28,[0,1,2,3,5,6,...|(21,[0,1,2,3,5,6,...|[7.90379412318342...|[0.99963079680242...|       0.0|
+---------+-----------------+------------+------+---------------+---------------+--------------+---------------+-------------+--------------+------------------------+----------------------+---------------------+------------+-------------------+---------------+--------------------+---------------+--------------+-------------------+---------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------+----------------------+-----------------------+--------------+-----------------------+---------------------+-----------------------------+----------------------------+--------------------+-----------------------------+---------------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
only showing top 5 rows

Проверим результат

Для проверки результата вычислим Матрицу ошибок:

val tp = prediction.filter(($"Attrition_Flag" === "Attrited Customer") and ($"prediction" === 1)).count
val tn = prediction.filter(($"Attrition_Flag" === "Existing Customer") and ($"prediction" === 0)).count
val fp = prediction.filter(($"Attrition_Flag" === "Existing Customer") and ($"prediction" === 1)).count
val fn = prediction.filter(($"Attrition_Flag" === "Attrited Customer") and ($"prediction" === 0)).count

println(s"Confusion Matrix:\n$tp\t$fp\n$fn\t\t$tn\n")
Confusion Matrix:
1199	1893
428	6607

Вычислим также Accuracy, Precision, Recall:

val accuracy = (tp + tn) / (tp + tn + fp + fn).toDouble
val precision = tp / (tp + fp).toDouble
val recall = tp / (tp + fn).toDouble

println(s"Accuracy = $accuracy")
println(s"Precision = $precision")
println(s"Recall = $recall\n")
Accuracy = 0.7708107040584576
Precision = 0.38777490297542044
Recall = 0.7369391518131531

Предварительный расчёт (Precompute)

Разумеется никто не использует для Production. Для этого пишется код, собираемый в исполняемый файл, запускаемый на кластере.

Одним из способов использования ML в Production является Предварительный расчёт (Precompute). В пакетном режиме, по расписанию, обученная модель применяется к набору данных. Идентификаторы клиентов, для которых модель предсказывает отток, сохраняются для использования в дальнейших бизнес-процессах.

Исходный код использования способа Предварительный расчёт выглядит так:

package ru.otus.sparkml

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.ml.PipelineModel

object ProdML {
  def main(args: Array[String]): Unit = {
    if (args.length != 3) {
      println("Usage: SparkML <path-to-model> <path-to-input> <path-to-output>")
      sys.exit(-1)
    }

    val spark = SparkSession.builder
      .appName("SparkML")
      .config("spark.sql.debug.maxToStringFields", 100)
      .getOrCreate()

    import spark.implicits._

    try {
      val model = PipelineModel.load(args(0))

      val data = spark.read
        .option("header", "true")
        .option("inferSchema", "true")
        .csv(args(1))

      val prediction = model.transform(data)

      prediction
        .filter($"prediction" === 1)
        .select("CLIENTNUM")
        .repartition(1)
        .write
        .mode(SaveMode.Overwrite)
        .csv(args(2))

    } catch {
      case e: Exception =>
        println(s"ERROR: ${e.getMessage}")
        sys.exit(-1)
    } finally {
      spark.stop()
    }
  }
}

Весь проект находится здесь: https://github.com/vzaigrin/otus/tree/main/SparkML

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out /  Change )

Twitter picture

You are commenting using your Twitter account. Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. Log Out /  Change )

Connecting to %s