Генеративно-состязательная сеть
Генеративно-состязательная сеть (Generative adversarial network) это методика обучения пары моделей:
-
генеративной $G$, которая по случайному шуму генерирует данные, подчинённые некоторому вероятностному распределению, и
-
дискриминативной $D$, которая приписывает, поступающим на её вход данным, вероятность того, что они получены из тренировочного датасета или сгенерированы моделью $G$.
Т.е. модель $G$ обучается таким образом, чтобы максимизировать вероятность того, что $D$ ошибётся. А модель $D$ тренируется отличать подмену реальных данных от тех, которые сгенерировала сеть $G$. Таким образом мы получаем как бы игру для двух игроков.
Методика тренировки описана в статье [1] и схематично выглядит так:
В качестве примера того, почему этот подход позволяет получить интересные результы, рассмотрим известный набор данных MNIST (см., например, MNIST for ML beginners). Этот набор представляет из себя картинки 28 х 28 пикселей в градациях серого, на которых изображены рукописные цифры от 0 до 9. Воспользовавшись схемой GAN мы можем взять две сети, одна будет по случайному шуму (например, вещественному вектору размерности 100) генерировать картинку 28 х 28 в градациях серого, а вторая пытаться отличить эти сгенерированные картинки от реальных данных взятых из MNIST датасета. Проведя достаточно циклов тренировок мы в результате получим генеративную сеть $G$, которая может по случайному входному вектору выдавать картинку с рукописной цифрой. Пример кода тренировки можно посмотреть вот здесь.
В качестве датасета можно использовать не обязательно MNIST, а, например, базу данных лиц, тогда в результате мы обучим сеть, которая будет уметь создавать лица людей. В принципе, мы можем обучать компьютер генерировать всё на что хватит нашей фантазии и для чего у нас собран более менее приличный датасет. Мне кажется это довольно интересно, поэтому начнём разбираться. И начнём с того, что поймем что такое генеративная и дискриминативная модели и чем они отличаются.
Генеративная и дискриминативная модели
Предположим, у нас есть два множества:
-
$O$ - множество наблюдаемых (observed) величин,
-
$H$ - множество скрытых (hidden) величин.
Тогда наличие генеративной (generative) модели означает, что мы знаем вероятность совместного распределения $P(o, h), o \in O, h \in H$. Таким образом мы можем генерировать события при помощи данной модели.
Дискриминативная (discriminative) же модель означает, что мы знаем условную вероятность $P(h | o)$ и соответственно, можем, например, классифицировать явление (найти наиболее вероятное значение скрытой переменной $h$), наблюдая нечто (зная значение наблюдаемой переменно $o \in O$)
Если генеративная модель известна, то используя формулу Байеса, можно получить дискриминативную модель.
Обычно по теме генеративных vs дискриминативных моделей ссылаются на статью [2].
Примеры
-
Если множества наблюдаемых и скрытых переменных конечны, то генеративную модель можно описать в виде таблицы:
$h_1$ $h_2$ … $h_N$ $o_1$ $P(o_1, h_1) = p_{11}$ $P(o_1, h_2) = p_{12}$ … $P(o_1, h_N) = p_{1N}$ $o_2$ $P(o_2, h_1) = p_{21}$ $P(o_2, h_2) = p_{22}$ … $P(o_2, h_N) = p_{2N}$ … … … … … $o_M$ $P(o_M, h_1) = p_{M1}$ $P(o_M, h_2) = p_{M2}$ … $P(o_M, h_N) = p_{MN}$ Дискриминативная модель получается из данной таблицы, применением формулы Байеса. Обозначим
$Z_i = \sum_{j=1}^N p_{ij}$
и получим:
$h_1$ $h_2$ … $h_N$ $o_1$ $P(h_1 \vert o_1) = p_{11} / Z_1$ $P(h_2 \vert o_1) = p_{12} / Z_1$ … $P(h_N \vert o_1) = p_{1N} / Z_1$ $o_2$ $P(h_1 \vert o_2) = p_{21} / Z_2$ $P(h_2 \vert o_2) = p_{22} / Z_2$ … $P(h_N \vert o_2) = p_{2N} / Z_2$ … … … … … $o_M$ $P(h_1 \vert o_M) = p_{M1} / Z_M$ $P(h_2 \vert o_M) = p_{M2} / Z_M$ … $P(h_N \vert o_M) = p_{MN} / Z_M$ -
Еще один пример описан раньше - это картинки рукописных цифр. Наблюдаемая величина это собственно картинка: 28 х 28, а скрытая величина - это нарисована ли на картинке цифра.
-
Хороший пример, это тексты на разных языках. Допустим, что в качестве наблюдаемой величины у нас выступает текст на каком-то языке, а в качестве скрытой - собственно язык на котором текст написан.
Генеративным подходом в данном случае будет: выучить все возможные языки и тогда мы будем иметь для каждого языка и текста вероятность их совместимости. Дискриминативным же подходом будет только понять как языки отличаются, и приписывать текст к языку на базе этого знания. Можно, например, ограничиться только текстами на двух языках: английском и русском и тогда для генеративного подхода надо всё равно изучить оба языка, а вот для дискриминативного достаточно будет научиться отличать латиницу от кирилицы.
Состязательная сеть
Итак формально мы имеем некоторое множество $X$ примеров, и подмножество $Y \subset X$ позитивных примеров (например, $X$ множество всех картинок размера 28 x 28, а подмножество $Y$ - набор MNIST картинок с рукописными цифрами), таким образом на множестве $X$ определено вероятностное распределение $p_{data}(x)$. Мы хотим отыскать две функции (обе функции мы аппроксимируем при помощи многослойной нейронной сети):
-
$G(z, \theta_g)$. $z$ - случайный шум с каким-то наперёд заданным распределением $p_z(z)$. $\theta_g$ - параметры, которые будем тренировать. На выходе функция $G$ будет выдавать элемент из множества $X$. Таким образом функция $G$ задаёт распределение $p_g(x)$.
-
$D(x, \theta_d)$. $x \in X$, а $\theta_d$ - параметры дискриминативной сети. $D(x)$ - вещественное число - вероятность того, что $x$ взято из подмножество $Y$, а не сгенерировано сетью $G$.
Мы одновременно тренируем сети представляющие функции $D$ и $G$. При этом мы ищем такие параметры $\theta_d$, чтобы максимизировать вероятность правильного разделения функцией $D$ позитивных примеров и примеров сгенерированных функцией $G$. И такие параметры $\theta_g$, которые минимизируют функцию $\log(1 - D(G(z)))$.
Можно переформулировать задачу как минимаксную игру для двух игроков с целевой функцией $V(G, D)$:
\[\min_G \max_D V(D, G) = \mathbb{E}_{data}\left(\log D(x))\right) + \mathbb{E}_z\left(\log(1-D(G(z)))\right)\]В работе [1] теоретически обосновывается, что данная задача имеет глобальный оптимум $-\log(4)$ при $p_g(x) = p_{data}(x)$. Т.е. мы можем натренировать такую функцию $G$, которая будет выдавать исходное распределение, и функция $D$ уже не сможет отличить примеры из тестового набора от примеров генерируемых функцией $G$, а значит будет выдавать вероятность $0.5$ для всех примеров, что собственно и даёт нам в результате $-\log(4)$.
Алгоритм решающий задачу оптимизации, попеременно оптимизирует функцию $G$ при фиксированной $D$, а затем фиксируя $D$ улучшает функцию $G$. При этом, так как на начальном этапе функция $G$ еще не достаточно хорошо умеет генерировать примеры похожие на тестовые, и $D$ назначает им очень маленькую вероятность, то $\log(1-D(G(z)))$ близок к нулю, соответственно обучение параметров функции $G$ будет происходить медленно, поэтому авторы [1] предлагают вместо того, чтобы минимизировать $\log(1-D(G(z)))$ искать максимум $\log(D(G(z)))$. Это эквивалентная задача, но при этом мы будем иметь, на начальном этапе обучения, большие по величине градиенты.
Итак в конечном итоге приходим к следующему алгоритму обучения:
Алгоритм тренировки генеративно-состязательных сетей
-
Для каждого шага тренировки:
-
Делаем $k$ итераций, оптимизируя дискриминатор $D$.
-
Генерируем $m$ примеров $\{z^{(1)}, z^{(2)}, …, z^{(m)}\}$, при помощи текущего варианта функции $G$.
-
Набираем $m$ позитивных примеров $\{x^{(1)}, x^{(2)}, …, x^{(m)}\}$ из тренировочного набора.
-
Обновляем параметры дискриминатора используя градиент:
-
-
Генерируем $m$ примеров $\{z^{(1)}, z^{(2)}, …, z^{(m)}\}$, при помощи текущего варианта функции $G$.
-
Обновляем параметры генератора используя градиент:
-
Гиперпараметр $k$ - количество шагов оптимизации дискриминатора на каждый шаг оптимизации генератора. Авторы [1] предлагают использовать $k=1$
Результаты выдаваемые генеративной сетью, натренированной по описанному выше алгоритму на наборе MNIST код для тренировки (не мой):
Здесь и дискриминативная и генеративная сети содержат два полносвязных слоя. А шум, подаваемый на вход генеративной сети, представляет из себя случайный вектор размерности $100$, элементы которого соответствуют равномерному распределению на отрезке $[-1, 1]$.
Литература
-
Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio, “Generative Adversarial Nets” arXiv:1406.2661 2014
-
Andrew Y. Ng, Michael I. Jordan, “On Discriminative vs. Generative classifiers: A comparison of logistic regression and naive Bayes” NIPS