spark 编写udaf函数求中位数

  • 2019 年 10 月 4 日
  • 笔记

package com.frank.sparktest.java;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Spark SQL UDAF that computes the median of an integer column.
 *
 * <p>The aggregation buffer is a single nullable String field holding the
 * comma-joined values seen so far; {@link #evaluate(Row)} parses, sorts and
 * takes the middle element(s). Returns a {@code Double} (the median of an
 * even-sized group is fractional), or {@code null} when every input was null.
 *
 * <p>NOTE(review): accumulating all values in a String buffer is O(n) memory
 * per group and O(n^2) in string copying — fine for a demo, not for large
 * groups.
 */
public class MedianUdaf extends UserDefinedAggregateFunction {

    private final StructType inputSchema;
    private final StructType bufferSchema;

    public MedianUdaf() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("nums", DataTypes.IntegerType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("datas", DataTypes.StringType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        // Double, because the median of an even-sized group is the mean of
        // the two middle values and may be fractional.
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        // BUG FIX: the original wrote an int 0 into the String slot and also
        // updated a non-existent slot 1; the buffer has exactly one String
        // field. Start from an empty accumulator string.
        buffer.update(0, "");
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (input.isNullAt(0)) {
            return; // ignore nulls, matching the original intent
        }
        String acc = buffer.getString(0);
        String value = String.valueOf(input.getInt(0));
        // BUG FIX: the original unconditionally prepended "," producing a
        // leading empty token that made Integer.valueOf("") throw in evaluate.
        buffer.update(0, acc.isEmpty() ? value : acc + "," + value);
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        // BUG FIX: the original read buffer2.getInt(0), but the partial
        // aggregation buffer's only field is the comma-joined String.
        String right = buffer2.getString(0);
        if (right == null || right.isEmpty()) {
            return; // nothing accumulated on the other partition
        }
        String left = buffer1.getString(0);
        buffer1.update(0, left.isEmpty() ? right : left + "," + right);
    }

    @Override
    public Object evaluate(Row buffer) {
        String joined = buffer.getString(0);
        if (joined == null || joined.isEmpty()) {
            return null; // no non-null input rows; result column is nullable
        }
        List<Integer> values = new ArrayList<>();
        for (String token : joined.split(",")) {
            values.add(Integer.valueOf(token));
        }
        Collections.sort(values);
        int size = values.size();
        // BUG FIX: the original indexed (size/2)+1 — off by one for the odd
        // case and off the upper-middle pair for the even case — and used
        // truncating integer division despite declaring a Double result.
        if (size % 2 == 1) {
            return (double) values.get(size / 2);
        }
        return (values.get(size / 2 - 1) + values.get(size / 2)) / 2.0;
    }
}

上面是完整的 UDAF 代码段,可以参考后在自己的工程中使用。

下面是测试程序

package com.frank.sparktest.java;

import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;

import java.io.IOException;
import java.util.stream.IntStream;

/**
 * Demo driver: starts a local Spark session, registers a range-generating
 * UDF and the median UDAF, then runs a sample query that exercises the UDF.
 */
public class DemoUDAF {

    public static void main(String[] args) throws IOException {
        SparkSession session = SparkSession.builder().master("local").getOrCreate();
        SQLContext sqlContext = session.sqlContext();

        // generate(start, end) -> inclusive integer range as an array column.
        sqlContext.udf().register(
                "generate",
                (Integer start, Integer end) -> IntStream.range(start, end + 1).boxed().toArray(),
                DataTypes.createArrayType(DataTypes.IntegerType));

        // Register the median UDAF under the name "media".
        sqlContext.udf().register("media", new MedianUdaf());

        sqlContext.sql("select generate(1,10)").show();
    }
}