Alink漫談(一) : 從KMeans演算法實現不同看Alink設計思想
Alink漫談(一) : 從KMeans演算法實現不同看Alink設計思想
0x00 摘要
Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平台,是業界首個同時支援批式演算法、流式演算法的機器學習平台。本文將帶領大家從多重角度出發來分析推測Alink的設計思路。
因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。
0x01 Flink 是什麼
Apache Flink是由Apache軟體基金會開發的開源流處理框架,它通過實現了 Google Dataflow 流式計算模型實現了高吞吐、低延遲、高性能兼具實時流式計算框架。
其核心是用Java和Scala編寫的分散式流數據流引擎。Flink以數據並行和流水線方式執行任意流數據程式,Flink的流水線運行時系統可以執行批處理和流處理程式。此外,Flink的運行時本身也支援迭代演算法的執行。
0x02 Alink 是什麼
Alink 是阿里巴巴計算平台事業部PAI團隊從2017年開始基於實時計算引擎 Flink 研發的新一代機器學習演算法平台,提供豐富的演算法組件庫和便捷的操作框架,開發者可以一鍵搭建覆蓋數據處理、特徵工程、模型訓練、模型預測的演算法模型開發全流程。項目之所以定為Alink,是取自相關名稱(Alibaba, Algorithm, AI, Flink, Blink)的公共部分。
藉助Flink在批流一體化方面的優勢,Alink能夠為批流任務提供一致性的操作。在2017年初,阿里團隊通過調研團隊看到了Flink在批流一體化方面的優勢及底層引擎的優秀性能,於是基於Flink重新設計研發了機器學習演算法庫,即Alink平台。該平台於2018年在阿里集團內部上線,隨後不斷改進完善,在阿里內部錯綜複雜的業務場景中鍛煉成長。
0x03 Alink設計思路
因為目前關於Alink設計的公開資料比較少,我們手頭只有其源碼,看起來只能從程式碼反推。但是世界上的事物都不是孤立的,我們還有其他角度來幫助我們判斷推理。所以下面就讓我們來進行推斷。
1. 白手起家
FlinkML 是 Flink 社區現存的一套機器學習演算法庫,這一套演算法庫已經存在很久而且更新比較緩慢。
Alink團隊起初面臨的抉擇是:是否要基於 Flink ML 進行開發,或者對 Flink ML進行更新。
經過研究,Alink團隊發現,Flink ML 其僅支援10餘種演算法,支援的數據結構也不夠通用,在演算法性能方面做的優化也比較少,而且其程式碼也很久沒有更新。所以,他們放棄了基於舊版FlinkML進行改進、升級的想法,決定基於Flink重新設計研發機器學習演算法庫。
所以我們要分析的就是如何從無到有設計出一個新的機器學習平台/框架。
2. 替代品如何造成威脅
因為Alink是市場的新進入者,所以Alink的最大問題就是如何替代市場上的現有產品。
邁克爾·波特用 「替代品威脅」 來解釋用戶的整個替代邏輯,當新產品能牢牢掌握住這一點,就有可能在市場上獲得非常好的表現,打敗競爭對手。
假如現在想從0到1構建一個機器學習庫或者機器學習框架,那麼我們需要從商業意識和商業邏輯出發,來思考這個產品的價值所在,就能對這個產品做個比較精確的定義,從而能夠確定產品路線。
產品需要解決應用環境下的綜合性問題,產品的價值體現,可以分拆了三個維度。
- 用戶的角度:價值體現在用戶使用,獲取產品的意願。這個就是換用成本的問題,一旦換用成本過高,這個產品就很難成功。
- 競爭對手的角度: 產品的競爭力,最終都體現為用戶為了獲取該產品願意支付的最高成本上限,當一個替代品進入市場,必須有能給用戶足夠的洞理驅使用戶換用替代品。
- 企業的角度:站在企業的角度,實際就是成本結構和收益的規模性問題 。
下面就讓我們逐一分析。
3. 用戶角度看設計
這個就是換用成本的問題,一旦換用成本過高,這個產品就很難成功。
Alink大略有兩種用戶:演算法工程師,應用工程師。
Alink演算法工程師特指實現機器學習演算法的工程師。Alink應用工程師就是應用Alink AI演算法做業務的工程師。這兩類用戶的換用成本都是Alink需要考慮的。
新產品對於用戶來說,有兩個大的問題:產品底層邏輯和開發工具。一個優秀的新產品絕對不能在這兩個問題上增加用戶的換用成本。
底層邏輯Flink
Flink這個平台博大精深,無論是熟悉其API還是深入理解系統架構都不是容易的事情。如果Alink用戶還需要熟悉Flink,那勢必造成ALink用戶的換用成本,所以這點應該盡量避免。
-
對於演算法工程師,他們應該主要把思路集中在演算法上,而盡量不用關心Flink內部的細節,如果一定要熟悉Flink,那麼越少越好;
-
對於應用工程師,他們主要的需求就是API介面越簡單越好,他們最理想的狀態應該是:完全感覺不到Flink的存在。
綜上所述,Alink的原則之一應該是 :演算法的歸演算法,Flink的歸Flink,盡量屏蔽AI演算法和Flink之間的聯繫。
開發工具
開發工具就是究竟用什麼語言開發。Flink的開發語言主要是JAVA,SCALA,Python。而機器學習世界中主要還是Python。
-
首先要排除SCALA。因為Scala 是一門很難掌握的語言,它的規則是基於數學類型理論的,學習曲線相當陡峭。一個能夠領會規則和語言特性的優秀程式設計師,使用 Scala 會比使用 Java 更高效,但是一個普通程式設計師的生產力,從功能實現上來看,效率則會相反。
讓我們看看基於Flink的原生KMeans SCALA程式碼,很多人看了之後恐怕都會懵圈。
val finalCentroids = centroids.iterate(params.getInt("iterations", 10)) { currentCentroids => val newCentroids = points .map(new SelectNearestCenter).withBroadcastSet(currentCentroids, "centroids") .map { x => (x._1, x._2, 1L) }.withForwardedFields("_1; _2") .groupBy(0) .reduce { (p1, p2) => (p1._1, p1._2.add(p2._2), p1._3 + p2._3) }.withForwardedFields("_1") .map { x => new Centroid(x._1, x._2.div(x._3)) }.withForwardedFields("_1->id") newCentroids }
-
其次是選擇JAVA還是Python開發具體演算法。Alink內部肯定進行了很多權宜和抉擇。因為這個不單單是哪個語言本身更合適,也涉及到Alink團隊內部有哪些資源,比如是JAVA工程師更多還是Python更多。最終Alink選擇了JAVA來開發演算法。
-
最後是API。這個就沒有什麼疑問了,Alink提供了Python和JAVA兩種語言的API,直接可參見GitHub的介紹。
在 PyAlink 中,演算法組件提供的介面基本與 Java API 一致,即通過默認構造方法創建一個演算法組件,然後通過
setXXX
設置參數,通過link/linkTo/linkFrom
與其他組件相連。 這裡利用 Jupyter 的自動補全機制可以提供書寫便利。
另外,如果採用JAVA或者Python,肯定有大量現有程式碼可以修改復用。如果採用SCALA,就難以復用之前的積累。
綜上所述,Alink的原則之一應該是 :採用最簡單,最常見的開發語言和設計思維。
4. 競爭對手角度看設計
Alink的競爭對手大略可以認為是Spark ML, Flink ML, Scikit-learn。
他們是市場上的現有力量,擁有大量的用戶。用戶已經熟悉了這些競爭對手的設計思路,開發策略,基本概念和API。除非Alink能夠提供一種神奇簡便的API,否則Alink應該在設計上最大程度借鑒這些競爭對手。
比如機器學習開發中有如下常見概念:Transformer,Estimator,PipeLine,Parameter。這些概念 Alink 應該盡量提供。
綜上所述,**Alink的原則之一應該是 :盡量借鑒市面上通用的設計思路和開發模式,讓開發者無縫切換 **。
從 Alink的目錄結構中 ,我們可以看出,Alink確實提供了這些常見概念。
比如 Pipeline,Trainer,Model,Estimator。我們會在後續文章中再詳細介紹這些概念。
./java/com/alibaba/alink:
common operator params pipeline
./java/com/alibaba/alink/params:
associationrule evaluation nlp regression statistics
classification feature onlinelearning shared tuning
clustering io outlier similarity udf
dataproc mapper recommendation sql validators
./java/com/alibaba/alink/pipeline:
EstimatorBase.java ModelBase.java Trainer.java feature
LocalPredictable.java ModelExporterUtils.java TransformerBase.java nlp
LocalPredictor.java Pipeline.java classification recommendation
MapModel.java PipelineModel.java clustering regression
MapTransformer.java PipelineStageBase.java dataproc tuning
5. 企業角度看設計
這是成本結構和收益的規模性問題。從而決定了Alink在開發時候,必須盡量提高開發工程師的效率,提高生產力。前面提到的棄用SCALA,部分也出於這個考慮。
挑戰集中在:
- 如何在對開發者最大程度屏蔽Flink的情況下,依然利用好Flink的各種能力。
- 如何構建一套相應打法和戰術體系,即middleware或者adapter,讓用戶基於此可以快速開發演算法
舉個例子:
-
肯定有個別開發者,其對Flink特別熟悉,他們可以運用各種Flink API和函數編程思維開發出高效率的演算法。這種開發者,我們可以稱為是武松武都頭。他們類似特種兵,能上戰場衝鋒陷陣,也能吊打白額大蟲。
-
但是絕大多數開發者對Flink不熟悉,他們更熟悉AI演算法和命令式編程思路。這種開發者我們可以認為他們屬於八十萬禁軍或者是玄甲軍,北府兵,魏武卒,背嵬軍。這種才是實際開發中的主力部隊和常規套路。
我們需要針對八十萬禁軍,讓林沖林教頭設計出一套適合正規作戰的槍棒打法。或者針對背嵬軍,讓岳飛岳元帥設計一套馬軍沖陣機制。
因此,**Alink的原則之一應該是 :構建一套戰術打法(middleware或者adapter),即屏蔽了Flink,又可以利用好Flink,還可以讓用戶基於此可以快速開發演算法 **。
我們想想看大概有哪些基礎工作需要做:
- 如何初始化
- 如果通訊
- 如何分割程式碼,如何廣播程式碼
- 如果分割數據,如何廣播數據
- 如何迭代演算法
- ……
讓我們看看Alink做了哪些努力,這點從其目錄結構可以看出有queue,operator,mapper等等構建架構所必須的數據結構:
./java/com/alibaba/alink/common:
MLEnvironment.java linalg MLEnvironmentFactory.java mapper
VectorTypes.java model comqueue utils io
./java/com/alibaba/alink/operator:
AlgoOperator.java common batch stream
其中最重要的概念是BaseComQueue,這是把通訊或者計算抽象成ComQueueItem,然後把ComQueueItem串聯起來形成隊列。這樣就形成了面向迭代計算場景的一套迭代通訊計算框架。其他數據結構大多是圍繞著BaseComQueue來具體運作。
/**
* Base class for the com(Computation && Communicate) queue.
*/
public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {
/**
* All computation or communication functions.
*/
private final List<ComQueueItem> queue = new ArrayList<>();
/**
* sessionId for shared objects within this BaseComQueue.
*/
private final int sessionId = SessionSharedObjs.getNewSessionId();
/**
* The function executed to decide whether to break the loop.
*/
private CompareCriterionFunction compareCriterion;
/**
* The function executed when closing the iteration
*/
private CompleteResultFunction completeResult;
/**
* Max iteration count.
*/
private int maxIter = Integer.MAX_VALUE;
private transient ExecutionEnvironment executionEnvironment;
}
MLEnvironment 是另外一個重要的類。其封裝了Flink開發所必須要的運行上下文。用戶可以通過這個類來獲取各種實際運行環境,可以建立table,可以運行SQL語句。
/**
* The MLEnvironment stores the necessary context in Flink.
* Each MLEnvironment will be associated with a unique ID.
* The operations associated with the same MLEnvironment ID
* will share the same Flink job context.
*/
public class MLEnvironment {
private ExecutionEnvironment env;
private StreamExecutionEnvironment streamEnv;
private BatchTableEnvironment batchTableEnv;
private StreamTableEnvironment streamTableEnv;
}
6. 設計原則總結
下面我們可以總結下Alink部分設計原則
-
演算法的歸演算法,Flink的歸Flink,盡量屏蔽AI演算法和Flink之間的聯繫。
-
採用最簡單,最常見的開發語言。
-
盡量借鑒市面上通用的設計思路和開發模式,讓開發者無縫切換。
-
構建一套戰術打法(middleware或者adapter),即屏蔽了Flink,又可以利用好Flink,還可以讓用戶基於此可以快速開發演算法。
0x04 KMeans演算法實現看設計
Flink和Alink源碼中,都提供了KMeans演算法例子,所以我們就從KMeans入手看看Flink原生演算法和Alink演算法實現的區別。為了統一標準,我們都選用JAVA版本的演算法實現。
1. KMeans演算法
KMeans演算法的思想比較簡單,假設我們要把數據分成K個類,大概可以分為以下幾個步驟:
- 隨機選取k個點,作為聚類中心;
- 計算每個點分別到k個聚類中心的聚類,然後將該點分到最近的聚類中心,這樣就行成了k個簇;
- 再重新計算每個簇的質心(均值);
- 重複以上2~4步,直到質心的位置不再發生變化或者達到設定的迭代次數。
2. Flink KMeans例子
K-Means 是迭代的聚類演算法,初始設置K個聚類中心
- 在每一次迭代過程中,演算法計算每個數據點到每個聚類中心的歐式距離
- 每個點被分配到它最近的聚類中心
- 隨後每個聚類中心被移動到所有被分配的點
- 移動的聚類中心被分配到下一次迭代
- 演算法在固定次數的迭代之後終止(在本實現中,參數設置)
- 或者聚類中心在迭代中不在移動
- 本項目是工作在二維平面的數據點上
- 它計算分配給集群中心的數據點
- 每個數據點都使用其所屬的最終集群(中心)的id進行注釋。
下面給出部分程式碼,具體演算法解釋可以在注釋中看到。
這裡主要採用了Flink的批量迭代。其調用 DataSet 的 iterate(int)
方法創建一個 BulkIteration,迭代以此為起點,返回一個 IterativeDataSet,可以用常規運算符進行轉換。迭代調用的參數 int 指定最大迭代次數。
IterativeDataSet 調用 closeWith(DataSet)
方法來指定哪個轉換應該回饋到下一個迭代,可以選擇使用 closeWith(DataSet,DataSet)
指定終止條件。如果該 DataSet 為空,則它將評估第二個 DataSet 並終止迭代。如果沒有指定終止條件,則迭代在給定的最大次數迭代後終止。
public class KMeans {
public static void main(String[] args) throws Exception {
// Checking input parameters
final ParameterTool params = ParameterTool.fromArgs(args);
// set up execution environment
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.getConfig().setGlobalJobParameters(params); // make parameters available in the web interface
// get input data:
// read the points and centroids from the provided paths or fall back to default data
DataSet<Point> points = getPointDataSet(params, env);
DataSet<Centroid> centroids = getCentroidDataSet(params, env);
// set number of bulk iterations for KMeans algorithm
IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));
DataSet<Centroid> newCentroids = points
// compute closest centroid for each point
.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
// count and sum point coordinates for each centroid
.map(new CountAppender())
.groupBy(0).reduce(new CentroidAccumulator())
// compute new centroids from point counts and coordinate sums
.map(new CentroidAverager());
// feed new centroids back into next iteration
DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
DataSet<Tuple2<Integer, Point>> clusteredPoints = points
// assign points to final clusters
.map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");
// emit result
if (params.has("output")) {
clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
// since file sinks are lazy, we trigger the execution explicitly
env.execute("KMeans Example");
} else {
System.out.println("Printing result to stdout. Use --output to specify output path.");
clusteredPoints.print();
}
}
3. Alink KMeans示例
Alink中,Kmeans是分布在若干文件中,這裡我們提取部分程式碼來對照。
KMeansTrainBatchOp
這裡是演算法主程式,這裡倒是看起來十分清爽乾淨,但實際上是沒有這麼簡單,Alink在其背後做了大量的基礎工作。
可以看出,演算法實現的主要工作是:
- 構建了一個IterativeComQueue(BaseComQueue的預設實現)。
- 初始化數據,這裡有兩種辦法:initWithPartitionedData將DataSet分片快取至記憶體。initWithBroadcastData將DataSet整體快取至每個worker的記憶體。
- 將計算分割為若干ComputeFunction,比如KMeansPreallocateCentroid / KMeansAssignCluster / KMeansUpdateCentroids …,串聯在IterativeComQueue。
- 運用AllReduce通訊模型完成了數據同步。
public final class KMeansTrainBatchOp extends BatchOperator <KMeansTrainBatchOp>
implements KMeansTrainParams <KMeansTrainBatchOp> {
static DataSet <Row> iterateICQ(...省略...) {
return new IterativeComQueue()
.initWithPartitionedData(TRAIN_DATA, data)
.initWithBroadcastData(INIT_CENTROID, initCentroid)
.initWithBroadcastData(KMEANS_STATISTICS, statistics)
.add(new KMeansPreallocateCentroid())
.add(new KMeansAssignCluster(distance))
.add(new AllReduce(CENTROID_ALL_REDUCE))
.add(new KMeansUpdateCentroids(distance))
.setCompareCriterionOfNode0(new KMeansIterTermination(distance, tol))
.closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName))
.setMaxIter(maxIter)
.exec();
}
}
KMeansPreallocateCentroid
預先分配聚類中心
public class KMeansPreallocateCentroid extends ComputeFunction {
public void calc(ComContext context) {
if (context.getStepNo() == 1) {
List<FastDistanceMatrixData> initCentroids = (List)context.getObj("initCentroid");
List<Integer> list = (List)context.getObj("statistics");
Integer vectorSize = (Integer)list.get(0);
context.putObj("vectorSize", vectorSize);
FastDistanceMatrixData centroid = (FastDistanceMatrixData)initCentroids.get(0);
Preconditions.checkArgument(centroid.getVectors().numRows() == vectorSize, "Init centroid error, size not equal!");
context.putObj("centroid1", Tuple2.of(context.getStepNo() - 1, centroid));
context.putObj("centroid2", Tuple2.of(context.getStepNo() - 1, new FastDistanceMatrixData(centroid)));
context.putObj("k", centroid.getVectors().numCols());
}
}
}
KMeansAssignCluster
為每個點(point)計算最近的聚類中心,為每個聚類中心的點坐標的計數和求和
/**
* Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster.
*/
public class KMeansAssignCluster extends ComputeFunction {
@Override
public void calc(ComContext context) {
Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
Integer k = context.getObj(KMeansTrainBatchOp.K);
// get iterative coefficient from static memory.
Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
if (context.getStepNo() % 2 == 0) {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
} else {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
}
if (null == distanceMatrix) {
distanceMatrix = new DenseMatrix(k, 1);
}
double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
if (sumMatrixData == null) {
sumMatrixData = new double[k * (vectorSize + 1)];
context.putObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE, sumMatrixData);
}
Iterable<FastDistanceVectorData> trainData = context.getObj(KMeansTrainBatchOp.TRAIN_DATA);
if (trainData == null) {
return;
}
Arrays.fill(sumMatrixData, 0.0);
for (FastDistanceVectorData sample : trainData) {
KMeansUtil.updateSumMatrix(sample, 1, stepNumCentroids.f1, vectorSize, sumMatrixData, k, fastDistance,
distanceMatrix);
}
}
}
KMeansUpdateCentroids
基於點計數和坐標,計算新的聚類中心。
/**
* Update the centroids based on the sum of points and point number belonging to the same cluster.
*/
public class KMeansUpdateCentroids extends ComputeFunction {
@Override
public void calc(ComContext context) {
Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
Integer k = context.getObj(KMeansTrainBatchOp.K);
double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
if (context.getStepNo() % 2 == 0) {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
} else {
stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
}
stepNumCentroids.f0 = context.getStepNo();
context.putObj(KMeansTrainBatchOp.K,
updateCentroids(stepNumCentroids.f1, k, vectorSize, sumMatrixData, distance));
}
}
4. 區別
程式碼量
通過下面的分析可以看出,從實際業務程式碼量角度說,其實差別不大。
- Flink的程式碼量少;
- Alink的程式碼量雖然大,但其本質就是把Flink版本的一些用戶定義類分離到自己不同類中,並且有很多讀取環境變數的程式碼;
所以Alink程式碼只能說比Flink原生實現略大。
耦合度
這裡指的是與Flink的耦合度。能看出來Flink的KMeans演算法需要大量的Flink類。而Alink被最大限度屏蔽了。
- Flink 演算法需要引入的flink類如下
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
- Alink 演算法需要引入的flink類如下,可以看出來ALink使用的都是基本設施,不涉及運算元和複雜API,這樣就減少了用戶的負擔。
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
編程模式
這是一個主要的區別。
- Flink 使用的是函數式編程。這種範式相對新穎,很多工程師不熟悉。
- Alink 依然使用了命令式編程。這樣的好處在於,大量現有演算法程式碼可以復用,也更符合絕大多數工程師的習慣。
- Flink 通過Flink的各種運算元完成了操作,比如IterativeDataSet實現了迭代。但這種實現對於不熟悉Flink的工程師是個折磨。
- Alink 基於自己的框架,把計算程式碼總結成了若干ComputeFunction,然後通過IterativeComQueue完成了具體演算法的迭代。這樣用戶其實對Flink是不需要過多深入理解。
在下一期文章中,將從源碼角度分析驗證本文的設計思路。
0x05 參考
Spark ML簡介之Pipeline,DataFrame,Estimator,Transformer
斬獲GitHub 2000+ Star,阿里雲開源的 Alink 機器學習平台如何跑贏雙11數據「博弈」?|AI 技術生態論