«1. Обзор
TensorFlow — это библиотека с открытым исходным кодом для программирования потоков данных. Первоначально он был разработан Google и доступен для широкого спектра платформ. Хотя TensorFlow может работать на одном ядре, он также может легко получить выгоду от нескольких доступных процессоров, графических процессоров или TPU.
В этом руководстве мы рассмотрим основы TensorFlow и способы его использования в Java. Обратите внимание, что TensorFlow Java API — это экспериментальный API, поэтому на него не распространяется гарантия стабильности. Позже в этом руководстве мы рассмотрим возможные варианты использования TensorFlow Java API.
2. Основы
Вычисления TensorFlow в основном вращаются вокруг двух фундаментальных концепций: графа и сеанса. Давайте быстро пройдемся по ним, чтобы получить фон, необходимый для прохождения остальной части урока.
2.1. График TensorFlow
Для начала давайте разберемся с основными строительными блоками программ TensorFlow. Вычисления представлены в виде графиков в TensorFlow. Граф обычно представляет собой ориентированный ациклический граф операций и данных, например:
На приведенном выше рисунке представлен вычислительный граф для следующего уравнения:
f(x, y) = z = a*x + b*y
Вычислительный граф TensorFlow состоит из двух элементов:
- Tensor: These are the core unit of data in TensorFlow. They are represented as the edges in a computational graph, depicting the flow of data through the graph. A tensor can have a shape with any number of dimensions. The number of dimensions in a tensor is usually referred to as its rank. So a scalar is a rank 0 tensor, a vector is a rank 1 tensor, a matrix is a rank 2 tensor, and so on and so forth.
- Operation: These are the nodes in a computational graph. They refer to a wide variety of computation that can happen on the tensors feeding into the operation. They often result in tensors as well which emanate from the operation in a computational graph.
2.2 . Сеанс TensorFlow
Теперь график TensorFlow — это просто схема вычислений, которая на самом деле не содержит значений. Такой граф должен быть запущен внутри так называемого сеанса TensorFlow для оценки тензоров в графе. Сеанс может принимать кучу тензоров для оценки из графика в качестве входных параметров. Затем он движется назад по графу и запускает все узлы, необходимые для оценки этих тензоров.
Обладая этими знаниями, мы теперь готовы применить их к Java API!
3. Настройка Maven
Мы быстро настроим проект Maven для создания и запуска графа TensorFlow в Java. Нам просто нужна зависимость от tensorflow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.12.0</version>
</dependency>
4. Создание графа
Давайте теперь попробуем построить граф, который мы обсуждали в предыдущем разделе, используя Java API TensorFlow. Точнее, в этом уроке мы будем использовать TensorFlow Java API для решения функции, представленной следующим уравнением:
z = 3*x + 2*y
Первым шагом является объявление и инициализация графа:
Graph graph = new Graph()
Теперь мы должны определить все необходимые операции. Помните, что операции в TensorFlow потребляют и производят ноль или более тензоров. Более того, каждый узел в графе — это операция, включающая константы и заполнители. Это может показаться нелогичным, но потерпите немного!
Класс Graph имеет универсальную функцию opBuilder() для создания любых операций в TensorFlow.
4.1. Определение констант
Для начала давайте определим константные операции в нашем графе выше. Обратите внимание, что постоянной операции потребуется тензор для ее значения:
Operation a = graph.opBuilder("Const", "a")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(3.0, Double.class))
.build();
Operation b = graph.opBuilder("Const", "b")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(2.0, Double.class))
.build();
Здесь мы определили операцию постоянного типа, вводя тензор со значениями Double 2.0 и 3.0. Поначалу это может показаться немного ошеломляющим, но на данный момент именно так обстоит дело с Java API. Эти конструкции гораздо более лаконичны в таких языках, как Python.
4.2. Определение заполнителей
Несмотря на то, что нам нужно предоставлять значения нашим константам, заполнителям не нужно значение во время определения. Значения заполнителей необходимо указывать, когда граф запускается внутри сеанса. Мы рассмотрим эту часть позже в уроке.
А пока давайте посмотрим, как мы можем определить наши заполнители:
Operation x = graph.opBuilder("Placeholder", "x")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
Operation y = graph.opBuilder("Placeholder", "y")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
Обратите внимание, что нам не нужно указывать какое-либо значение для наших заполнителей. Эти значения будут переданы как тензоры при запуске.
4.3. Определение функций
Наконец, нам нужно определить математические операции нашего уравнения, а именно умножение и сложение, чтобы получить результат.
Это снова не что иное, как операции в TensorFlow, и Graph.opBuilder() снова удобен:
Operation ax = graph.opBuilder("Mul", "ax")
.addInput(a.output(0))
.addInput(x.output(0))
.build();
Operation by = graph.opBuilder("Mul", "by")
.addInput(b.output(0))
.addInput(y.output(0))
.build();
Operation z = graph.opBuilder("Add", "z")
.addInput(ax.output(0))
.addInput(by.output(0))
.build();
Здесь мы определили операцию, две для умножения наших входных данных и последнюю для суммирования промежуточных Результаты. Обратите внимание, что операции здесь получают тензоры, которые являются не чем иным, как результатом наших предыдущих операций.
«Обратите внимание, что мы получаем выходной тензор из операции, используя индекс «0». Как мы обсуждали ранее, операция может привести к одному или нескольким тензорам, и, следовательно, при получении дескриптора для него нам нужно упомянуть индекс. Поскольку мы знаем, что наши операции возвращают только один тензор, «0» работает просто отлично!
5. Визуализация графика
Трудно следить за графиком по мере его увеличения. Это делает важным визуализировать его каким-то образом. Мы всегда можем создать рисунок от руки, подобный маленькому графику, который мы создали ранее, но это нецелесообразно для больших графиков. Для этого TensorFlow предоставляет утилиту под названием TensorBoard.
К сожалению, Java API не имеет возможности генерировать файл событий, который используется TensorBoard. Но, используя API в Python, мы можем сгенерировать файл событий, например: продолжить остальную часть учебника.
writer = tf.summary.FileWriter('.')
......
writer.add_graph(tf.get_default_graph())
writer.flush()
Теперь мы можем загружать и визуализировать файл событий в TensorBoard следующим образом:
TensorBoard входит в состав установки TensorFlow.
tensorboard --logdir .
Обратите внимание на сходство между этим и ранее нарисованным вручную графиком!
6. Работа с сеансом
Теперь мы создали вычислительный граф для нашего простого уравнения в TensorFlow Java API. Но как мы его запускаем? Прежде чем обратиться к этому, давайте посмотрим, в каком состоянии находится Graph, который мы только что создали. Если мы попытаемся напечатать вывод нашей последней операции «z»:
Это приведет к чему-то вроде:
System.out.println(z.output(0));
Это не то, что мы ожидали! Но если вспомнить то, что мы обсуждали ранее, в этом действительно есть смысл. Граф, который мы только что определили, еще не запускался, поэтому тензоры в нем на самом деле не имеют никакого фактического значения. В приведенном выше выводе просто говорится, что это будет тензор типа Double.
<Add 'z:0' shape=<unknown> dtype=DOUBLE>
Давайте теперь определим сеанс для запуска нашего графа:
Наконец, теперь мы готовы запустить наш граф и получить ожидаемый результат:
Session sess = new Session(graph)
Итак, что мы здесь делаем ? Это должно быть довольно интуитивно понятно:
Tensor<Double> tensor = sess.runner().fetch("z")
.feed("x", Tensor.<Double>create(3.0, Double.class))
.feed("y", Tensor.<Double>create(6.0, Double.class))
.run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());
Получить Runner из сеанса Определить операцию для выборки по ее имени «z» Ввести тензоры для наших заполнителей «x» и «y» Запустить график в сеансе ~~ ~ И теперь мы видим скалярный вывод:
-
Это то, что мы ожидали, не так ли!
7. Пример использования Java API
21.0
На данный момент TensorFlow может показаться излишним для выполнения основных операций. Но, конечно же, TensorFlow предназначен для работы с графиками гораздо большего размера.
Кроме того, тензоры, с которыми он имеет дело в реальных моделях, намного больше по размеру и рангу. Это настоящие модели машинного обучения, в которых TensorFlow находит свое реальное применение.
Нетрудно заметить, что работа с основным API в TensorFlow может стать очень громоздкой по мере увеличения размера графа. С этой целью TensorFlow предоставляет высокоуровневые API, такие как Keras, для работы со сложными моделями. К сожалению, официальной поддержки Keras на Java пока практически нет.
Однако мы можем использовать Python для определения и обучения сложных моделей либо непосредственно в TensorFlow, либо с помощью высокоуровневых API, таких как Keras. Впоследствии мы можем экспортировать обученную модель и использовать ее в Java с помощью TensorFlow Java API.
Итак, зачем нам делать что-то подобное? Это особенно полезно в ситуациях, когда мы хотим использовать функции машинного обучения в существующих клиентах, работающих на Java. Например, рекомендовать подписи к пользовательским изображениям на устройстве Android. Тем не менее, есть несколько случаев, когда мы заинтересованы в выводе модели машинного обучения, но не обязательно хотим создавать и обучать эту модель на Java.
Именно здесь TensorFlow Java API находит большую часть своего использования. Мы рассмотрим, как этого можно достичь в следующем разделе.
8. Использование сохраненных моделей
«Теперь мы поймем, как мы можем сохранить модель в TensorFlow в файловой системе и загрузить ее обратно, возможно, на совершенно другом языке и платформе. TensorFlow предоставляет API-интерфейсы для создания файлов моделей в независимой от языка и платформы структуре, которая называется Protocol Buffer.
8.1. Сохранение моделей в файловой системе
Мы начнем с определения того же графа, который мы создали ранее в Python, и сохранения его в файловой системе.
Давайте посмотрим, что мы можем сделать на Python:
Так как это руководство посвящено Java, давайте не будем уделять много внимания деталям этого кода на Python, за исключением того факта, что он генерирует файл называется «saved_model.pb». Обратите внимание на краткость определения аналогичного графа по сравнению с Java!
8.2. Загрузка моделей из файловой системы
import tensorflow as tf
graph = tf.Graph()
builder = tf.saved_model.builder.SavedModelBuilder('./model')
with graph.as_default():
a = tf.constant(2, name='a')
b = tf.constant(3, name='b')
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
z = tf.math.add(a*x, b*y, name='z')
sess = tf.Session()
sess.run(z, feed_dict = {x: 2, y: 3})
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
builder.save()
Теперь мы загрузим файл «saved_model.pb» в Java. Java TensorFlow API имеет SavedModelBundle для работы с сохраненными моделями:
Теперь должно быть довольно интуитивно понятно, что делает приведенный выше код. Он просто загружает граф модели из буфера протокола и делает доступной сессию в нем. С этого момента мы можем делать с этим графом все, что угодно, как если бы мы делали это с локально определенным графом.
9. Заключение
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
Tensor<Integer> tensor = model.session().runner().fetch("z")
.feed("x", Tensor.<Integer>create(3, Integer.class))
.feed("y", Tensor.<Integer>create(3, Integer.class))
.run().get(0).expect(Integer.class);
System.out.println(tensor.intValue());
Подводя итог, в этом уроке мы рассмотрели основные понятия, связанные с вычислительным графом TensorFlow. Мы увидели, как использовать TensorFlow Java API для создания и запуска такого графа. Затем мы обсудили варианты использования Java API в отношении TensorFlow.
В процессе мы также поняли, как визуализировать график с помощью TensorBoard, а также сохранять и перезагружать модель с помощью Protocol Buffer.
Как всегда, код примеров доступен на GitHub.
«
As always, the code for the examples is available over on GitHub.