初探Numpy中的花式索引

前言

Numpy中对数组索引的方式有很多(为了方便介绍文中的数组如不加特殊说明指的都是Numpy中的ndarry数组),比如:

  • 基本索引:通过单个整数值来索引数组
import numpy as np    arr = np.arange(9) # 构造一维数组  print(arr) # array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])  # 通过整数值索引一维数组中的单个元素值  print(arr[2]) # 2  print(arr[8]) # 8

使用基本索引方式索引二维数组。

import numpy as np    arr2d = np.arange(9).reshape(3, 3) # 构建二维数组  print(arr2d) # [[0 1 2] [3 4 5] [6 7 8]]  # 通过整数值索引二维数组中的数组子集  print(arr2d[0]) # [0 1 2]  # 通过整数值索引二维数组中的单个元素值  print(arr2d[0, 2]) # 2
  • 切片索引:通过[start: end: step](起始位置为start,终止位置为end,步长为steps)的方式索引连续的数组子集
import numpy as np    arr2d = np.arange(9).reshape(3, 3)  print(arr2d) # [[0 1 2] [3 4 5] [6 7 8]]  print(arr2d[:, 0]) # [0 3 6]  print(arr2d[::2, :]) # [[0 1 2] [6 7 8]] 
  • 布尔索引:通过布尔类型的数组进行索引
import numpy as np    names = np.array(['Bob', 'Joe', 'Will'])  scores = np.random.randint(0, 100, (3, 4)) # 3名学生的4科成绩    print(names == 'Bob')  print(scores[names == 'Bob']) # 获取Bob的四科成绩
  • 花式索引:通过整型数组进行索引

本文将重点介绍通过整型数组进行索引的花式索引。

a

什么是花式索引?

花式索引(Fancy indexing)是指利用整数数组进行索引,这里的整数数组可以是Numpy数组也可以是Python中列表、元组等可迭代类型。

花式索引根据索引整型数组的值作为目标数组的某个轴的下标来取值。这句话对于理解花式索引非常关键,而核心就是"轴"以及"下标",既然是整数数组作为下标,这就要求如果设置多个整数数组来索引的话,这些整数数组的元素个数要相等,这样才能够将整数数组映射成下标。比如对于[0, 1]和[1, 1]两个整型数组,可以拼接成arr[0, 1]和arr[1, 1]的下标来取值,而对于[1, 2, 3]和[3, 4]两个元素个数不等的情况下,是不能拼接成对应的下标的。

import numpy as np    arr3d = np.arange(12).reshape(2, 2, 3)    # 使用两个整数数组来对axis= 0,1两个轴进行花式索引  print(arr3d[[0, 1], [1, 1]])  print(arr3d[[0, 1], [0, 1, 2]]) #error    [[ 3  4  5]   [ 9 10 11]]  Traceback (most recent call last):    File "D:/code/PycharmProjects/Python_base/text01.py", line 147, in <module>      print(arr3d[[0, 1], [0, 1, 2]])  IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (3,) 

当然得益于Numpy中的广播机制,如果其中的一个整型数组只有一个元素可以广播到与之其它整型数组相同的元素个数,比如[0, 1]和[2]两个整数数组,Numpy的广播机制先将[2]变成[2, 2],然后再拼接成相应的下标arr[0, 2]和arr[1, 2]。当基本索引和缓释索引组合的时候,基本索引会被广播成整数数组,形成花式索引。

import numpy as np    arr3d = np.arange(12).reshape(3,4)    print(arr3d[[0, 1], [2]])  print(arr3d[[0, 1], [2, 2]])    # 花式索引和基本索引组合  print(arr3d[[0, 1], 2])  print(arr3d[[0, 1], [2, 2]])    print(arr3d[0, [0]])  print(arr3d[[0], [0]])    [2 6]  [2 6]  [2 6]  [2 6]  [0]  [0]

下面先来利用一维数组来举例,花式索引利用整数数组来索引,那么就先来一个整数数组,这里的整数数组可以为Numpy数组以及Python中可迭代类型,这里为了方便使用Python中的list列表。

import numpy as np    arr = np.arange(9)  print(arr)    arr2 = arr[[0, 2]] # 使用花式索引  print(arr2)    print(arr2[0])  print(arr2[1])    [0 1 2 3 4 5 6 7 8]  [0 2]  0  2

前面提到对于理解花式索引非常关键的"轴"和"下标":

  1. 对于一维数组只有一个轴axis = 0,因此我们只能设置一个整型数组并且整型数组只能作用在axis = 0这个轴上;
  2. 下标其实也很好理解,对于整数数组为[0, 2],可以简单理解0和2分别是arr数组的下标,即arr[0]和arr[2],花式索引arr[[0, 2]]结果中的元素值和单独对arr[0]以及arr[2]进行索引的元素值是一致的。

一维数组还比较简单,下面来看一个二维数组要如何理解?

import numpy as np    arr2d = np.arange(9).reshape(3, 3)  print(arr2d)    arr2d2 = arr2d[[0, 2]] # 使用花式索引  print(arr2d2)    print(arr2d[0])  print(arr2d[2])    [[0 1 2]   [3 4 5]   [6 7 8]]  [[0 1 2]   [6 7 8]]  [0 1 2]  [6 7 8]

继续使用花式索引中的"轴"和"下标"来理解花式索引下的二维数组:

  1. 对于二维数组来说一共有两个维度两个轴axis = 0、axis = 1,由于此时整数数组只有一个,此时由于花式索引中只有一个数组,所以此时的索引数组只能作用在axis = 0的这个轴上;
  2. 由于这里只有一个数组所以下标的理解和在一维数组中类似,对于[0, 2]来说,对应的下标索引为arr2d[0]、arr2d[2],对于二维数组相应的索引结果为二维数组arr2中的第一行和第三行;

一个整数数组能够索引一个轴,那么对于二维数组来说,如果有两个整数数组的话肯定能够索引两个轴。接下来我们再为二维数组添加一个整数数组[1, 2]。

import numpy as np    arr2d = np.arange(9).reshape(3, 3)  print(arr2d)  arr2d2 = arr2d[[0, 2], [1, 2]] # 使用花式索引  print(arr2d2)    [[0 1 2] [3 4 5] [6 7 8]]  [1 8]
  1. 二维数组一共有两个轴,此时的整数数组刚好有两个,所以两个整数数组会作用在二维数组中的两个轴上;
  2. 由于二维数组的两个轴都被索引了,所以此时的下标和上面的稍有不同,对于[0, 2]和[1, 2]两个整数数组来说,相应的下标先在第一个整数数组中选择0,然后再在第二个整数数组中选择1,即为arr2d[0][0]等价arr2d[0, 0],同理对于第二个索引来说先在第一个整数数组中选择2,然后再第二个整数数组中选择2,即为arr2d[2][2]等价arr2d[2, 2]。这也从侧面证明了为什么花式索引会要求在给定轴上的整数数组元素个数要相等;

简单总结一下,一个整数数组作用在待索引数组中的一个轴上,因此整数数组的个数要小于等于待索引数组的维度个数,对于下标来说,花式索引本质上可以转换为基本索引,所以要求整数数组中的元素值不能超过对应待索引数组的最大索引。

b

花式索引的使用

通过上面的例子你可能会觉得花式索引完全可以被其它的索引方式所替代,并没有存在的必要。花式索引擅长一些不规则的索引,这些不规则的索引使用其它的索引方式可能也可以实现,但是相比于花式索引实现会比较复杂。

比如现在有一个二维数组,二维数组的形状为(3, 4),表示3名学生的4课成绩。

import numpy as np    np.random.seed(666) # 设置随机种子  scores = np.random.randint(0, 100, (3, 4))  print(scores)    [[ 2 45 30 62]   [70 73 30 36]   [61 91 94 51]]

现在比如想要获取第1名学生以及第3名学生的成绩。

import numpy as np    np.random.seed(666) # 设置随机种子  scores = np.random.randint(0, 100, (3, 4))  print(scores)  print(scores[[0, 2]]) # 通过花式索引第1名学生以及第3名学生    [[ 2 45 30 62]   [70 73 30 36]   [61 91 94 51]]  [[ 2 45 30 62]   [61 91 94 51]]

如果使用其它的索引方式会比较复杂,比如使用基本索引需要使用concat将arr[0]和arr[1]合并起来,而切片索引只能索引连续的位置。

还可以通过负值倒叙进行花式索引。比如现在想要索引最后一名学生以及第一名学生的4课成绩。

import numpy as np    np.random.seed(666) # 设置随机种子  scores = np.random.randint(0, 100, (3, 4))  print(scores)  score = scores[[-1, 0]]  print(score)    [[ 2 45 30 62]   [70 73 30 36]   [61 91 94 51]]  [[61 91 94 51]   [ 2 45 30 62]]

在机器学习中常通过使用花式索引来打乱数据集的样本顺序,避免机器学习模型学习到样本的位置噪声,对于监督学习的数据集如果打乱了样本还需要打乱相对应的标签值,样本与标签都是一一对应的关系,使用花式索引能够轻松的解决。

import numpy as np  from sklearn import datasets    digits = datasets.load_digits()  X = digits.data  y = digits.target    index = np.random.permutation(X.shape[0])  print(type(index)) # <class 'numpy.ndarray'>  # 乱序后的数据集  X_random, y_random = X[index], y[index]

c

花式索引的维度问题?

到目前为止我们只关注索引的值,而忽视了最终索引后的维度变化。首先来看下面的例子,依然是上面的形状为(3, 4)表示3名学生的4课成绩的二维数组。这里使用花式索引索引出第2名学生的4课全部成绩。

import numpy as np    np.random.seed(666) # 设置随机种子  scores = np.random.randint(0, 100, (3, 4))  print(scores)  score = scores[[1]] # 花式索引第2名学生的所有成绩  print(score.shape)  print(score)    [[ 2 45 30 62]   [70 73 30 36]   [61 91 94 51]]  (1, 4)  [[70 73 30 36]]

通过前面的学习知道可以将花式索引中的整数数组转换为数组下标的基本索引。通过前面的介绍将scores[[1]]转换为scores[1]的基本索引方式。

import numpy as np    np.random.seed(666) # 设置随机种子  scores = np.random.randint(0, 100, (3, 4))  print(scores)  score = scores[1] # 基本索引第2名学生的所有成绩  print(score.shape)  print(score)    [[ 2 45 30 62]   [70 73 30 36]   [61 91 94 51]]  (4,)  [70 73 30 36]

虽然scores[[1]]的花式索引和score[1]的普通索引最后元素值相同,但是它们的维度却有很大的差别。

如果一开始学习花式索引很容易被维度所搞乱。这里我总结了一个小技巧,每一个整数数组作用一个维度,假设原始数组中有n个维度,使用花式索引,有第一个整数数组的时候结果维度为n,第二个整数数组后的索引结果维度为(n – 1),第三个整数数组后的索引结果维度为(n – 2),依次类推。

下面就来举几个小例子来实验一下我们的小技巧。

import numpy as np    arr3d = np.arange(12).reshape(2, 2, 3)  print(arr3d.ndim) # 3  print(arr3d[[0]].ndim) # 3  print(arr3d[[0], [0]].ndim) # 2  print(arr3d[[0], [0], [0]].ndim) # 1

再来一个更加复杂一点的小例子。

import numpy as np    arr3d = np.arange(12).reshape(2, 2, 3)  print(arr3d.ndim) # 3  arr3d2 = arr3d[[0, -1]][:, [0, 1]]  print(arr3d2.ndim) # 3

arr3d[[0, -1]][:, [0, 1]]本身其实并不复杂,将arr3d[[0, -1]][:, [0, 1]]分成两个部分,先是arr3d[[0, -1]],将结果再进行索引[:, [0, 1]],这里只关注最后的维度,原始数组arr3d的ndim值为3,此时arr3d[[0, -1]]返回的是ndim = 3的索引结果,此时ndim = 3的数组进行[:, [0, 1]]的索引,其中只有一个整数数组,因此最终的维度还是3。