誰說搞Java的不能玩機器學習?
- 2019 年 11 月 5 日
- 筆記
簡介
機器學習在全球範圍內越來越受歡迎和使用。 它已經徹底改變了某些應用程式的構建方式,並且可能會繼續成為我們日常生活中一個巨大的(並且正在增加的)部分。
沒有什麼包裝且機器學習並不簡單。 它對許多人來說似乎非常複雜並常常令人生畏。
像Google這樣的公司將自己的機器學習概念與開發人員聯繫起來,在Google幫助下讓他們逐漸邁出第一步,故TensorFlow的框架誕生了。
TensorFlow為何物?
TensorFlow是由Google使用Python和C++開發的開源機器學習框架。
它可以幫助開發人員輕鬆獲取數據,準備和訓練模型,預測未來狀態,以及執行大規模機器學習。
有了它,我們可以訓練和運行深度神經網路的內容,諸如光學字元識別,影像識別/分類,自然語言處理等。
張量與操作
TensorFlow基於計算圖,你可以將其想像為具有節點和邊的經典圖。
每個節點被稱為操作,它們將零個或多個張量輸入併產生零個或多個張量輸出。 操作可以非常簡單,例如基本的添加,但它們也可以非常複雜。
張量被描繪為圖的邊緣,並且是核心數據單元。 當我們將它們提供給操作時,我們在這些張量上執行不同的功能。 它們可以具有單個或多個維度,有時也稱為它們的等級(標量:等級0,向量:等級1,矩陣:等級2)。
這些數據受到操作的影響通過張量傳遞到計算圖中,故而稱為TensorFlow。
張量可以以任意數量的維度存儲數據,並且有三種主要類型的張量:佔位符,變數和常量。
安裝TensorFlow
使用Maven,安裝TensorFlow就像包含依賴項一樣簡單:
<dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.13.1</version> </dependency>
如果你的設備支援GPU功能,可以添加以下依賴:
<dependency> <groupId>org.tensorflow</groupId> <artifactId>libtensorflow</artifactId> <version>1.13.1</version> </dependency> <dependency> <groupId>org.tensorflow</groupId> <artifactId>libtensorflow_jni_gpu</artifactId> <version>1.13.1</version> </dependency>
你可以使用TensorFlow對象來檢查當前操作的TensorFlow的版本。
System.out.println(TensorFlow.version());
TensorFlow的JavaAPI
Java API TensorFlow提供包含在org.tensorflow包中。 它目前是實驗性的,因此不能保證其穩定性。
需要注意的是TensorFlow唯一完全支援的語言是Python,Java API幾乎沒有什麼功能。
API向我們介紹了新的類,介面,枚舉和異常。
類
通過API引入的新類是:
- Graph:表示TensorFlow計算的數據流圖;
- Operation:在Tensors上執行計算的Graph節點;
- OperationBuilder:Operations的構建器類;
- Output
:操作產生的張量的符號句柄; - SavedModelBundle:表示從存儲載入的模型;
- SavedModelBundle.Loader:提供載入SavedModel的選項;
- Server:進程內TensorFlow伺服器,用於分散式訓練;
- Session:圖形執行的驅動程式;
- Session.Run:輸出執行會話時獲得的張量和元數據;
- Session.Runner:運行操作並評估張量;
- Shape:由操作產生的可能部分已知的張量形狀;
- Tensor
:靜態類型的多維數組,其元素是由T描述的類型; - TensorFlow:描述TensorFlow運行時的靜態實用程式方法;
- Tensors:用於創建張量對象的類型安全工廠方法;
枚舉
- DataType:將張量中的元素類型表示為枚舉;
介面
- Operand
:由TensorFlow操作的操作數實現的介面;
異常
- TensorFlowException:執行TensorFlow圖時拋出的未經檢查的異常
如果我們將所有這些與Python中的tf模組進行比較將發現存在明顯的區別。 Java API沒有幾乎相同的功能,至少目前如此。
圖(Graphs)
如前所述,TensorFlow基於計算圖 – 其中org.tensorflow.Graph是Java的實現。
注意:它的實例是執行緒安全的,儘管我們需要在完成它之後顯式釋放Graph使用的資源。
讓我們從一個空圖開始:
Graph graph = new Graph();
該對象是空的,所以這個圖表意義不大。 要對它做任何操作,我們首先需要使用Operations載入它。
我們使用opBuilder()方法來載入它,它返回一個OperationBuilder對象,一旦我們調用.build()方法,它就會將操作添加到我們的圖形中。
常量
讓我們在圖表中添加一個常量:
Operation x = graph.opBuilder("Const", "x") .setAttr("dtype", DataType.FLOAT) .setAttr("value", Tensor.create(3.0f)) .build();
佔位符
佔位符是變數的「類型」,聲明時沒有賦值,他們的值將在之後進行分配。 這允許我們使用沒有任何實際數據的操作來構建圖形:
Operation y = graph.opBuilder("Placeholder", "y") .setAttr("dtype", DataType.FLOAT) .build();
函數
最後為了解決這個問題,我們需要添加某些函數。 這些可以像乘法,除法或加法一樣簡單,也可以像矩陣乘法一樣複雜。 和之前一樣,我們使用.opBuilder()方法定義函數:
Operation xy = graph.opBuilder("Mul", "xy") .addInput(x.output(0)) .addInput(y.output(0)) .build();
注意:我們使用input(0)作為張量可以有多個輸出。
圖形可視化
遺憾的是,Java API還沒有包含任何允許像Python中一樣可視化圖形的工具。
會話(Sessions)
如前所述,Session是Graph的驅動程式。 它封裝了執行Operation和Graph計算張量(tensors)的環境。
這意味著我們構建的圖(graph)中的張量(tensors)實際上並沒有任何值,因為我們沒有在會話(session)中運行圖形(graph)。
我們首先將圖表添加到會話(session)中:
Session session = new Session(graph);
我們的操作知識簡單地將x於y相乘,為了運行我們的圖(graph)並得到計算結果,我們需要使用fetch()獲取到xy的操作並為其提供x和y的值:
Tensor tensor = session.runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0); System.out.println(tensor.floatValue());
運行這段程式碼將產生的結果如下:
10.0f
Java當中載入Python中Saving模組
這可能聽起來有點奇怪,但由於Python是唯一受到良好支援的語言,因此Java API仍然沒有保存模型的功能。
這意味著Java API僅用於服務用例,至少在TensorFlow完全支援之前。 目前至少我們可以使用SavedModelBundle類在Python中訓練和保存模型,然後使用Java載入它們來為它們提供服務:
SavedModelBundle model = SavedModelBundle.load("./model", "serve"); Tensor tensor = model.session().runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0); System.out.println(tensor.floatValue());
結論
TensorFlow是一個功能強大且廣泛使用的框架。 它不斷得到改進,並最近被引入新語言:包括Java和JavaScript。
儘管Java API還沒有像TensorFlow在Python中那麼多的功能,但它仍然可以作為向Java開發人員介紹TensorFlow的一個很好的開始。
原文鏈接:https://stackabuse.com/how-to-use-tensorflow-with-java/
作 者:David Landup
譯 者:klein
——
9月福利,關注公眾號
後台回復:004,領取8月翻譯集錦!
往期福利回復:001,002, 003即可領取!