Data Algorithms: Order Inversion | Counting Neighboring Words

  • February 10, 2020
  • Notes

This problem is somewhat similar to some of the search problems on LeetCode.

The problem we want to solve: for each word, count the words that appear within two positions before or after it (a neighbor window of 2). For example, given w1, w2, w3, w4, w5, w6, the neighbors of w3 are w1, w2, w4, and w5, while the neighbors of w1 are only w2 and w3.

The final output should be in the form (word, neighbor, frequency).
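As a concrete illustration of the windowing rule (the class and method names below are just for this sketch, not from the original code), the following standalone snippet prints every (word, neighbor) pair for one line with a window of 2; the same start/end computation appears in every implementation that follows.

import java.util.ArrayList;
import java.util.List;

public class NeighborDemo {
    // emit one "word TAB neighbor" pair for every neighbor within `window` positions
    static List<String> neighbors(String[] tokens, int window) {
        List<String> pairs = new ArrayList<>();
        for (int i = 0; i < tokens.length; i++) {
            int start = Math.max(0, i - window);
            int end = Math.min(tokens.length - 1, i + window);
            for (int j = start; j <= end; j++) {
                if (j != i) {
                    pairs.add(tokens[i] + "\t" + tokens[j]);
                }
            }
        }
        return pairs;
    }

    public static void main(String[] args) {
        // with window = 2, w3 pairs with w1, w2, w4 and w5
        neighbors("w1 w2 w3 w4 w5 w6".split("\\s+"), 2).forEach(System.out::println);
    }
}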

We implement this in five ways:

  • MapReduce
  • Spark
  • Spark SQL
  • Scala
  • Spark SQL in Scala

MapReduce

// map function
@Override
protected void map(LongWritable key, Text value, Context context)
        throws IOException, InterruptedException {
    String[] tokens = StringUtils.split(value.toString(), " ");
    //String[] tokens = StringUtils.split(value.toString(), "\\s+");
    if ((tokens == null) || (tokens.length < 2)) {
        return;
    }
    // the windowing rule for pairing a word with its neighbors
    for (int i = 0; i < tokens.length; i++) {
        tokens[i] = tokens[i].replaceAll("\\W+", "");
        if (tokens[i].equals("")) {
            continue;
        }

        pair.setWord(tokens[i]);

        int start = (i - neighborWindow < 0) ? 0 : i - neighborWindow;
        int end = (i + neighborWindow >= tokens.length) ? tokens.length - 1 : i + neighborWindow;
        for (int j = start; j <= end; j++) {
            if (j == i) {
                continue;
            }
            pair.setNeighbor(tokens[j].replaceAll("\\W", ""));
            context.write(pair, ONE);
        }
        // emit (word, "*") carrying the number of neighbors for this occurrence
        pair.setNeighbor("*");
        totalCount.set(end - start);
        context.write(pair, totalCount);
    }
}
// reduce function
@Override
protected void reduce(PairOfWords key, Iterable<IntWritable> values, Context context)
        throws IOException, InterruptedException {
    // a "*" neighbor marks the word itself; its count feeds the running totalCount
    if (key.getNeighbor().equals("*")) {
        if (key.getWord().equals(currentWord)) {
            totalCount += getTotalCount(values);
        } else {
            currentWord = key.getWord();
            totalCount = getTotalCount(values);
        }
    } else {
        // otherwise it is a single (word, neighbor) pair; sum its counts with getTotalCount
        int count = getTotalCount(values);
        relativeCount.set((double) count / totalCount);
        context.write(key, relativeCount);
    }
}
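One piece the post does not show is how the "*" totals reach the reducer in the right order. The reducer relies on the order-inversion pattern: every (word, "*") key must go to the same reducer as that word's (word, neighbor) keys, and must sort before them. As a rough sketch (assuming PairOfWords exposes getWord() as used above, and that its sort order already places "*" first), a word-only partitioner would look like:

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Partitioner;

// Partition on the word only, ignoring the neighbor, so that (word, "*")
// and every (word, neighbor) key land on the same reducer.
public class WordPartitioner extends Partitioner<PairOfWords, IntWritable> {
    @Override
    public int getPartition(PairOfWords key, IntWritable value, int numPartitions) {
        return (key.getWord().hashCode() & Integer.MAX_VALUE) % numPartitions;
    }
}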

Spark

public static void main(String[] args) {
    if (args.length < 3) {
        System.out.println("Usage: RelativeFrequencyJava <neighbor-window> <input-dir> <output-dir>");
        System.exit(1);
    }

    SparkConf sparkConf = new SparkConf().setAppName("RelativeFrequency");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    int neighborWindow = Integer.parseInt(args[0]);
    String input = args[1];
    String output = args[2];

    final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);

    JavaRDD<String> rawData = sc.textFile(input);

    /*
     * Transform the input to the format: (word, (neighbour, 1))
     */
    JavaPairRDD<String, Tuple2<String, Integer>> pairs = rawData.flatMapToPair(
            new PairFlatMapFunction<String, String, Tuple2<String, Integer>>() {
        private static final long serialVersionUID = -6098905144106374491L;

        @Override
        public java.util.Iterator<scala.Tuple2<String, scala.Tuple2<String, Integer>>> call(String line) throws Exception {
            List<Tuple2<String, Tuple2<String, Integer>>> list = new ArrayList<Tuple2<String, Tuple2<String, Integer>>>();
            String[] tokens = line.split("\\s");
            for (int i = 0; i < tokens.length; i++) {
                int start = (i - brodcastWindow.value() < 0) ? 0 : i - brodcastWindow.value();
                int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1 : i + brodcastWindow.value();
                for (int j = start; j <= end; j++) {
                    if (j != i) {
                        list.add(new Tuple2<String, Tuple2<String, Integer>>(tokens[i], new Tuple2<String, Integer>(tokens[j], 1)));
                    } else {
                        // do nothing
                        continue;
                    }
                }
            }
            return list.iterator();
        }
    });

    // (word, sum(word))
    // PairFunction<T, K, V>: T => Tuple2<K, V>
    JavaPairRDD<String, Integer> totalByKey = pairs.mapToPair(
            new PairFunction<Tuple2<String, Tuple2<String, Integer>>, String, Integer>() {
        private static final long serialVersionUID = -213550053743494205L;

        @Override
        public Tuple2<String, Integer> call(Tuple2<String, Tuple2<String, Integer>> tuple) throws Exception {
            return new Tuple2<String, Integer>(tuple._1, tuple._2._2);
        }
    }).reduceByKey(
            new Function2<Integer, Integer, Integer>() {
        private static final long serialVersionUID = -2380022035302195793L;

        @Override
        public Integer call(Integer v1, Integer v2) throws Exception {
            return (v1 + v2);
        }
    });

    JavaPairRDD<String, Iterable<Tuple2<String, Integer>>> grouped = pairs.groupByKey();

    // (word, (neighbour, 1)) -> (word, (neighbour, sum(neighbour)))
    // flatMapValues operates only on the values and leaves the keys unchanged
    JavaPairRDD<String, Tuple2<String, Integer>> uniquePairs = grouped.flatMapValues(
            // Function<T1, R> -> R call(T1 v1)
            new Function<Iterable<Tuple2<String, Integer>>, Iterable<Tuple2<String, Integer>>>() {
        private static final long serialVersionUID = 5790208031487657081L;

        @Override
        public Iterable<Tuple2<String, Integer>> call(Iterable<Tuple2<String, Integer>> values) throws Exception {
            Map<String, Integer> map = new HashMap<>();
            List<Tuple2<String, Integer>> list = new ArrayList<>();
            Iterator<Tuple2<String, Integer>> iterator = values.iterator();
            while (iterator.hasNext()) {
                Tuple2<String, Integer> value = iterator.next();
                int total = value._2;
                if (map.containsKey(value._1)) {
                    total += map.get(value._1);
                }
                map.put(value._1, total);
            }
            for (Map.Entry<String, Integer> kv : map.entrySet()) {
                list.add(new Tuple2<String, Integer>(kv.getKey(), kv.getValue()));
            }
            return list;
        }
    });

    // (word, ((neighbour, sum(neighbour)), sum(word)))
    JavaPairRDD<String, Tuple2<Tuple2<String, Integer>, Integer>> joined = uniquePairs.join(totalByKey);

    // ((key, neighbour), sum(neighbour)/sum(word))
    JavaPairRDD<Tuple2<String, String>, Double> relativeFrequency = joined.mapToPair(
            new PairFunction<Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>>, Tuple2<String, String>, Double>() {
        private static final long serialVersionUID = 3870784537024717320L;

        @Override
        public Tuple2<Tuple2<String, String>, Double> call(Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>> tuple) throws Exception {
            return new Tuple2<Tuple2<String, String>, Double>(new Tuple2<String, String>(tuple._1, tuple._2._1._1), ((double) tuple._2._1._2 / tuple._2._2));
        }
    });

    // For saving the output in tab separated format
    // ((key, neighbour), relative_frequency)
    // convert each result record to a single String
    JavaRDD<String> formatResult_tab_separated = relativeFrequency.map(
            new Function<Tuple2<Tuple2<String, String>, Double>, String>() {
        private static final long serialVersionUID = 7312542139027147922L;

        @Override
        public String call(Tuple2<Tuple2<String, String>, Double> tuple) throws Exception {
            return tuple._1._1 + "\t" + tuple._1._2 + "\t" + tuple._2;
        }
    });

    // save output
    formatResult_tab_separated.saveAsTextFile(output);

    // done
    sc.close();
}
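A small design note: the groupByKey plus hand-rolled HashMap step above can also be written with reduceByKey keyed on (word, neighbour), which combines partial counts on the map side before the shuffle. A minimal sketch, assuming the same `pairs` RDD and Java 8 lambdas:

// ((word, neighbour), 1) -> ((word, neighbour), sum) without materializing the groups
JavaPairRDD<Tuple2<String, String>, Integer> neighbourCounts = pairs
        .mapToPair(t -> new Tuple2<>(new Tuple2<>(t._1, t._2._1), t._2._2))
        .reduceByKey((a, b) -> a + b);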

Spark SQL

public static void main(String[] args) {
    if (args.length < 3) {
        System.out.println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>");
        System.exit(1);
    }

    SparkConf sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency");
    // create the SparkSession required by Spark SQL
    SparkSession spark = SparkSession
            .builder()
            .appName("SparkSQLRelativeFrequency")
            .config(sparkConf)
            .getOrCreate();

    JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
    int neighborWindow = Integer.parseInt(args[0]);
    String input = args[1];
    String output = args[2];

    final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);

    /*
     * Register a schema; the frequency column is used in the query later.
     * Schema: (word, neighbour, frequency)
     */
    StructType rfSchema = new StructType(new StructField[]{
        new StructField("word", DataTypes.StringType, false, Metadata.empty()),
        new StructField("neighbour", DataTypes.StringType, false, Metadata.empty()),
        new StructField("frequency", DataTypes.IntegerType, false, Metadata.empty())});

    JavaRDD<String> rawData = sc.textFile(input);

    /*
     * Transform the input to the format: (word, (neighbour, 1))
     */
    JavaRDD<Row> rowRDD = rawData
            .flatMap(new FlatMapFunction<String, Row>() {
                private static final long serialVersionUID = 5481855142090322683L;

                @Override
                public Iterator<Row> call(String line) throws Exception {
                    List<Row> list = new ArrayList<>();
                    String[] tokens = line.split("\\s");
                    for (int i = 0; i < tokens.length; i++) {
                        int start = (i - brodcastWindow.value() < 0) ? 0
                                : i - brodcastWindow.value();
                        int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1
                                : i + brodcastWindow.value();
                        for (int j = start; j <= end; j++) {
                            if (j != i) {
                                list.add(RowFactory.create(tokens[i], tokens[j], 1));
                            } else {
                                // do nothing
                                continue;
                            }
                        }
                    }
                    return list.iterator();
                }
            });
    // create the DataFrame
    Dataset<Row> rfDataset = spark.createDataFrame(rowRDD, rfSchema);
    // register rfDataset as a table so it can be queried
    rfDataset.createOrReplaceTempView("rfTable");

    String query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf "
            + "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a "
            + "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word";
    Dataset<Row> sqlResult = spark.sql(query);

    sqlResult.show(); // print first 20 records on the console
    sqlResult.write().parquet(output + "/parquetFormat"); // saves output in compressed Parquet format, recommended for large projects.
    sqlResult.rdd().saveAsTextFile(output + "/textFormat"); // to see output via cat command

    // done
    sc.close();
    spark.stop();
}
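For comparison, the same query can be expressed with the DataFrame API instead of an SQL string. A sketch, assuming the rfDataset built above and org.apache.spark.sql.functions imported as `functions`:

// per-(word, neighbour) totals
Dataset<Row> pairTotals = rfDataset.groupBy("word", "neighbour")
        .agg(functions.sum("frequency").alias("feq_total"));
// per-word totals
Dataset<Row> wordTotals = rfDataset.groupBy("word")
        .agg(functions.sum("frequency").alias("total"));
// relative frequency = feq_total / total
Dataset<Row> dfResult = pairTotals.join(wordTotals, "word")
        .withColumn("rf", functions.col("feq_total").divide(functions.col("total")))
        .select("word", "neighbour", "rf");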

Scala

def main(args: Array[String]): Unit = {
  if (args.size < 3) {
    println("Usage: RelativeFrequency <neighbor-window> <input-dir> <output-dir>")
    sys.exit(1)
  }

  val sparkConf = new SparkConf().setAppName("RelativeFrequency")
  val sc = new SparkContext(sparkConf)

  val neighborWindow = args(0).toInt
  val input = args(1)
  val output = args(2)

  val brodcastWindow = sc.broadcast(neighborWindow)

  val rawData = sc.textFile(input)

  /*
   * Transform the input to the format:
   * (word, (neighbour, 1))
   */
  val pairs = rawData.flatMap(line => {
    val tokens = line.split("\\s")
    for {
      i <- 0 until tokens.length
      start = if (i - brodcastWindow.value < 0) 0 else i - brodcastWindow.value
      end = if (i + brodcastWindow.value >= tokens.length) tokens.length - 1 else i + brodcastWindow.value
      j <- start to end if (j != i)
      // use yield to collect the transformed pairs (word, (neighbour, 1))
    } yield (tokens(i), (tokens(j), 1))
  })

  // (word, sum(word))
  val totalByKey = pairs.map(t => (t._1, t._2._2)).reduceByKey(_ + _)

  val grouped = pairs.groupByKey()

  // (word, (neighbour, sum(neighbour)))
  val uniquePairs = grouped.flatMapValues(_.groupBy(_._1).mapValues(_.unzip._2.sum))

  // join the two RDDs
  // (word, ((neighbour, sum(neighbour)), sum(word)))
  val joined = uniquePairs join totalByKey

  // ((key, neighbour), sum(neighbour)/sum(word))
  val relativeFrequency = joined.map(t => {
    ((t._1, t._2._1._1), (t._2._1._2.toDouble / t._2._2.toDouble))
  })

  // For saving the output in tab separated format
  // ((key, neighbour), relative_frequency)
  val formatResult_tab_separated = relativeFrequency.map(t => t._1._1 + "\t" + t._1._2 + "\t" + t._2)
  formatResult_tab_separated.saveAsTextFile(output)

  // done
  sc.stop()
}

Spark SQL in Scala

def main(args: Array[String]): Unit = {
  if (args.size < 3) {
    println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>")
    sys.exit(1)
  }

  val sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency")

  val spark = SparkSession
    .builder()
    .config(sparkConf)
    .getOrCreate()
  val sc = spark.sparkContext

  val neighborWindow = args(0).toInt
  val input = args(1)
  val output = args(2)

  val brodcastWindow = sc.broadcast(neighborWindow)

  val rawData = sc.textFile(input)

  /*
   * Schema
   * (word, neighbour, frequency)
   */
  val rfSchema = StructType(Seq(
    StructField("word", StringType, false),
    StructField("neighbour", StringType, false),
    StructField("frequency", IntegerType, false)))

  /*
   * Transform the input to the format:
   * Row(word, neighbour, 1)
   */
  // convert each line into the Row format required by the schema
  val rowRDD = rawData.flatMap(line => {
    val tokens = line.split("\\s")
    for {
      i <- 0 until tokens.length
      // the plain windowing rule; unlike the MapReduce version, tokens are not cleaned here
      start = if (i - brodcastWindow.value < 0) 0 else i - brodcastWindow.value
      end = if (i + brodcastWindow.value >= tokens.length) tokens.length - 1 else i + brodcastWindow.value
      j <- start to end if (j != i)
    } yield Row(tokens(i), tokens(j), 1)
  })

  val rfDataFrame = spark.createDataFrame(rowRDD, rfSchema)
  // register the rfTable view
  rfDataFrame.createOrReplaceTempView("rfTable")

  import spark.sql

  val query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf " +
    "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a " +
    "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word"

  val sqlResult = sql(query)
  sqlResult.show() // print first 20 records on the console
  sqlResult.write.save(output + "/parquetFormat") // saves output in compressed Parquet format, recommended for large projects.
  sqlResult.rdd.saveAsTextFile(output + "/textFormat") // to see output via cat command

  // done
  spark.stop()
}

Those are five ways to solve this problem.