寫一手漂亮的代碼,走向極致的編程 一、代碼運行時間分析

前言

寫一手漂亮的代碼,何謂漂亮的代碼?對我來說大概有這麼幾點:

  1. 寫法符合規範(如:該空格的地方打上空格,該換行的地方換行,名命方式符合規範等等)
  2. 簡潔且可讀性高(能十行代碼實現並且讓人容易看懂的絕不寫十一行,對經常重複出現的代碼段落進行封裝)
  3. 性能高(如:運行時間儘可能短,運行時所用內存儘可能少)

要實現以上目標,自然就要對代碼進行優化,說到代碼的優化,自然而然就會想到對算法時間複雜度進行優化,比如我要實現一個在有序數組中查找一個數,最容易想到的就是遍歷一遍 O(n) 的複雜度,優化一下自然是使用二分, O(logn) 的複雜度。如果這段代碼在我們的程序中會經常被調用,那麼,通過這算法上的優化,我們的程序性能自然而然的會有很高的提升。

但是,有時候會發現,已經對算法進行優化了,程序的性能(如運行時間、內存佔用等)仍然不能達到預期,那麼,這時候該如何對我們的代碼進行進一步的優化呢?

這篇文章將以 Python 為例進行介紹

先來段代碼

這裡,我將通過使用 Julia 分形的代碼來進行。

Julia 集合,由式 \(f_c(z) = z ^2 + c\) 進行反覆迭代到。

對於固定的複數 c ,取某一 z 值,可以得到序列

\(z_0, f_c(z_0), f_c(f_c(z_0)), …\)

這一序列可能發散於無窮大或處於某一範圍之內並收斂於某一值,我們將使其不擴散的 z 值的集合稱為朱利亞集合。

import time
import numpy as np
import imageio
import PIL
import matplotlib.pyplot as plt
import cv2 as cv

x1, x2, y1, y2 = -1.8, 1.8, -1.8, 1.8
c_real, c_imag = -0.62772, -0.42193

def calculate_z_serial_purepython(maxiter, zs, cs):
    output = [0] * len(zs)
    for i in range(len(zs)):
        n = 0
        z = zs[i]
        c = cs[i]
        while abs(z) < 2 and n < maxiter:
            z = z * z + c
            n += 1
        output[i] = n
    return output

def calc_pure_python(desired_width, max_itertions):
    x_step = (float(x2 - x1)) / float(desired_width)
    y_step = (float(y2 - y1)) / float(desired_width)
    x, y = [], []
    ycoord = y1
    while ycoord < y2:
        y.append(ycoord)
        ycoord += y_step
    xcoord = x1
    while xcoord < x2:
        x.append(xcoord)
        xcoord += x_step
    zs, cs = [], []
    for ycoord in y:
        for xcoord in x:
            zs.append(complex(xcoord, ycoord))
            cs.append(complex(c_real, c_imag))
    print(f"Length of x: {len(x)}")
    print(f"Total elements: {len(zs)}")
    start_time = time.time()
    output = calculate_z_serial_purepython(max_itertions, zs, cs)
    end_time = time.time()
    secs = end_time - start_time
    print("calculate_z_serial_purepython took", secs, "seconds")

    assert sum(output) == 33219980
    # # show img
    # output = np.array(output).reshape(desired_width, desired_width)
    # plt.imshow(output, cmap='gray')
    # plt.savefig("julia.png")
    

if __name__ == "__main__":
    calc_pure_python(desired_width=1000, max_itertions=300)

這段代碼運行完,可以得到圖片

運行結果

Length of x: 1000
Total elements: 1000000
calculate_z_serial_purepython took 25.053941249847412 seconds

開始分析

這裡,將通過各種方法來對這段代碼的運行時間來進行分析

直接打印運行時間

在前面的代碼中,我們可以看到有 start_time 和 end_time 兩個變量,通過 print 兩個變量的差值即可得到運行時間,但是,每次想要打印運行時間都得加那麼幾行代碼就會很麻煩,此時我們可以通過使用修飾器來進行

from functools import wraps
def timefn(fn):
    @wraps(fn)
    def measure_time(*args, **kwargs):
        start_time = time.time()
        result = fn(*args, **kwargs)
        end_time = time.time()
        print("@timefn:" + fn.__name__ + " took " + str(end_time - start_time), " seconds")
        return result
    return measure_time

然後對 calculate_z_serial_purepython 函數進行測試

@timefn
def calculate_z_serial_purepython(maxiter, zs, cs):
	...

運行後輸出結果

Length of x: 1000
Total elements: 1000000
@timefn:calculate_z_serial_purepython took 26.64286208152771  seconds
calculate_z_serial_purepython took 26.64286208152771 seconds

另外,也可以在命令行中輸入

python -m timeit -n 5 -r 5 -s "import code" "code.calc_pure_python(desired_width=1000, max_itertions=300)"

其中 -n 5 表示循環次數, -r 5 表示重複次數,timeit 會對語句循環執行 n 次,並計算平均值作為一個結果,重複 r 次選出最好的結果。

5 loops, best of 5: 24.9 sec per loop

UNIX tine 命令

由於電腦上沒有 Linux 環境,於是使用 WSL 來進行

time -p python code.py
如果是 Linux 中進行,可能命令需改成
/usr/bin/time -p python code.py

輸出結果

Length of x: 1000
Total elements: 1000000
@timefn:calculate_z_serial_purepython took 14.34933090209961  seconds
calculate_z_serial_purepython took 14.350624322891235 seconds
real 15.57
user 15.06
sys 0.40

其中 real 記錄整體耗時, user 記錄了 CPU 花在任務上的時間,sys 記錄了內核函數耗費的時間

/usr/bin/time --verbose python code.py

輸出,WSL 的 time 命令裏面沒有 –verbose 這個參數,只能到服務器裏面試了,突然覺得我的筆記本跑的好慢。。。

Length of x: 1000
Total elements: 1000000
@timefn:calculate_z_serial_purepython took 7.899603605270386  seconds
calculate_z_serial_purepython took 7.899857997894287 seconds
        Command being timed: "python code.py"
        User time (seconds): 8.33
        System time (seconds): 0.08
        Percent of CPU this job got: 98%
        Elapsed (wall clock) time (h:mm:ss or m:ss): 0:08.54
        Average shared text size (kbytes): 0
        Average unshared data size (kbytes): 0
        Average stack size (kbytes): 0
        Average total size (kbytes): 0
        Maximum resident set size (kbytes): 98996
        Average resident set size (kbytes): 0
        Major (requiring I/O) page faults: 0
        Minor (reclaiming a frame) page faults: 25474
        Voluntary context switches: 0
        Involuntary context switches: 2534
        Swaps: 0
        File system inputs: 0
        File system outputs: 0
        Socket messages sent: 0
        Socket messages received: 0
        Signals delivered: 0
        Page size (bytes): 4096
        Exit status: 0

這裏面需要關心的參數是 Major (requiring I/O) page faults ,表示操作系統是否由於 RAM 中的數據不存在而需要從磁盤上讀取頁面。

cProfile 模塊

cProfile 模塊是標準庫內建三個的分析工具之一,另外兩個是 hotshot 和 profile。

python -m cProfile -s cumulative code.py

-s cumulative 表示對每個函數累計花費的時間進行排序

輸出

36222017 function calls in 30.381 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   30.381   30.381 {built-in method builtins.exec}
        1    0.064    0.064   30.381   30.381 code.py:1(<module>)
        1    1.365    1.365   30.317   30.317 code.py:35(calc_pure_python)
        1    0.000    0.000   28.599   28.599 code.py:13(measure_time)
        1   19.942   19.942   28.598   28.598 code.py:22(calculate_z_serial_purepython)
 34219980    8.655    0.000    8.655    0.000 {built-in method builtins.abs}
  2002000    0.339    0.000    0.339    0.000 {method 'append' of 'list' objects}
        1    0.012    0.012    0.012    0.012 {built-in method builtins.sum}
        4    0.003    0.001    0.003    0.001 {built-in method builtins.print}
        1    0.000    0.000    0.000    0.000 code.py:12(timefn)
        1    0.000    0.000    0.000    0.000 functools.py:44(update_wrapper)
        4    0.000    0.000    0.000    0.000 {built-in method time.time}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:989(_handle_fromlist)
        4    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        7    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        5    0.000    0.000    0.000    0.000 {built-in method builtins.setattr}
        1    0.000    0.000    0.000    0.000 functools.py:74(wraps)
        1    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

可以看到,在代碼的入口處總共花費了 30.381 秒,ncalls 為 1,表示只執行了 1 次,然後 calculate_z_serial_purepython 花費了 28.598 秒,可以推斷出調用該函數使用了近 2 秒。另外可以看到,abs 函數被調用了 34219980 次。對列表項的 append 操作進行了 2002000 次(1000 * 1000 * 2 +1000 * 2 )。

接下來,我們進行更深入的分析。

python -m cProfile -o profile.stats code.py

先生成一個統計文件,然後在 python 中進行分析

>>> import pstats
>>> p = pstats.Stats("profile.stats")
>>> p.sort_stats("cumulative")
<pstats.Stats object at 0x000002AA0A6A8908>
>>> p.print_stats()
Sat Apr 25 16:38:07 2020    profile.stats

         36222017 function calls in 30.461 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   30.461   30.461 {built-in method builtins.exec}
        1    0.060    0.060   30.461   30.461 code.py:1(<module>)
        1    1.509    1.509   30.400   30.400 code.py:35(calc_pure_python)
        1    0.000    0.000   28.516   28.516 code.py:13(measure_time)
        1   20.032   20.032   28.515   28.515 code.py:22(calculate_z_serial_purepython)
 34219980    8.483    0.000    8.483    0.000 {built-in method builtins.abs}
  2002000    0.360    0.000    0.360    0.000 {method 'append' of 'list' objects}
        1    0.012    0.012    0.012    0.012 {built-in method builtins.sum}
        4    0.004    0.001    0.004    0.001 {built-in method builtins.print}
        1    0.000    0.000    0.000    0.000 code.py:12(timefn)
        1    0.000    0.000    0.000    0.000 C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:44(update_wrapper)
        4    0.000    0.000    0.000    0.000 {built-in method time.time}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:989(_handle_fromlist)
        7    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        4    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.000    0.000 C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:74(wraps)
        1    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
        5    0.000    0.000    0.000    0.000 {built-in method builtins.setattr}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


<pstats.Stats object at 0x000002AA0A6A8908>

這裡,就生成了與上面一致的信息

>>> p.print_callers()
   Ordered by: cumulative time

Function                                                                                              was called by...
                                                                                                          ncalls  tottime  cumtime
{built-in method builtins.exec}                                                                       <-
code.py:1(<module>)                                                                                   <-       1    0.060   30.461  {built-in method builtins.exec}
code.py:35(calc_pure_python)                                                                          <-       1    1.509   30.400  code.py:1(<module>)
code.py:13(measure_time)                                                                              <-       1    0.000   28.516  code.py:35(calc_pure_python)
code.py:22(calculate_z_serial_purepython)                                                             <-       1   20.032   28.515  code.py:13(measure_time)
{built-in method builtins.abs}                                                                        <- 34219980    8.483    8.483  code.py:22(calculate_z_serial_purepython)
{method 'append' of 'list' objects}                                                                   <- 2002000    0.360    0.360  code.py:35(calc_pure_python)
{built-in method builtins.sum}                                                                        <-       1    0.012    0.012  code.py:35(calc_pure_python)
{built-in method builtins.print}                                                                      <-       1    0.000    0.000  code.py:13(measure_time)
                                                                                                               3    0.003    0.003  code.py:35(calc_pure_python)
code.py:12(timefn)                                                                                    <-       1    0.000    0.000  code.py:1(<module>)
C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:44(update_wrapper)  <-       1    0.000    0.000  code.py:12(timefn)
{built-in method time.time}                                                                           <-       2    0.000    0.000  code.py:13(measure_time)
                                                                                                               2    0.000    0.000  code.py:35(calc_pure_python)
<frozen importlib._bootstrap>:989(_handle_fromlist)                                                   <-       1    0.000    0.000  code.py:1(<module>)
{built-in method builtins.getattr}                                                                    <-       7    0.000    0.000  C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:44(update_wrapper)
{built-in method builtins.hasattr}                                                                    <-       1    0.000    0.000  <frozen importlib._bootstrap>:989(_handle_fromlist)
{built-in method builtins.len}                                                                        <-       2    0.000    0.000  code.py:22(calculate_z_serial_purepython)
                                                                                                               2    0.000    0.000  code.py:35(calc_pure_python)
C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:74(wraps)           <-       1    0.000    0.000  code.py:12(timefn)
{method 'update' of 'dict' objects}                                                                   <-       1    0.000    0.000  C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:44(update_wrapper)
{built-in method builtins.setattr}                                                                    <-       5    0.000    0.000  C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:44(update_wrapper)
{method 'disable' of '_lsprof.Profiler' objects}                                                      <-


<pstats.Stats object at 0x000002AA0A6A8908>

這裡,我們可以看到,在每一行最後會有調用這部分的父函數名稱,這樣我們就可以定位到對某一操作最費時的那個函數。

我們還可以顯示那個函數調用了其它函數

>>> p.print_callees()
   Ordered by: cumulative time

Function                                                                                              called...
                                                                                                          ncalls  tottime  cumtime
{built-in method builtins.exec}                                                                       ->       1    0.060   30.461  code.py:1(<module>)
code.py:1(<module>)                                                                                   ->       1    0.000    0.000  <frozen importlib._bootstrap>:989(_handle_fromlist)
                                                                                                               1    0.000    0.000  code.py:12(timefn)
                                                                                                               1    1.509   30.400  code.py:35(calc_pure_python)
code.py:35(calc_pure_python)                                                                          ->       1    0.000   28.516  code.py:13(measure_time)
                                                                                                               2    0.000    0.000  {built-in method builtins.len}
                                                                                                               3    0.003    0.003  {built-in method builtins.print}
                                                                                                               1    0.012    0.012  {built-in method builtins.sum}
                                                                                                               2    0.000    0.000  {built-in method time.time}
                                                                                                         2002000    0.360    0.360  {method 'append' of 'list' objects}
code.py:13(measure_time)                                                                              ->       1   20.032   28.515  code.py:22(calculate_z_serial_purepython)
                                                                                                               1    0.000    0.000  {built-in method builtins.print}
                                                                                                               2    0.000    0.000  {built-in method time.time}
code.py:22(calculate_z_serial_purepython)                                                             -> 34219980    8.483    8.483  {built-in method builtins.abs}
                                                                                                               2    0.000    0.000  {built-in method builtins.len}
{built-in method builtins.abs}                                                                        ->
{method 'append' of 'list' objects}                                                                   ->
{built-in method builtins.sum}                                                                        ->
{built-in method builtins.print}                                                                      ->
code.py:12(timefn)                                                                                    ->       1    0.000    0.000  C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:44(update_wrapper)
                                                                                                               1    0.000    0.000  C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:74(wraps)
C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:44(update_wrapper)  ->       7    0.000    0.000  {built-in method builtins.getattr}
                                                                                                               5    0.000    0.000  {built-in method builtins.setattr}
                                                                                                               1    0.000    0.000  {method 'update' of 'dict' objects}
{built-in method time.time}                                                                           ->
<frozen importlib._bootstrap>:989(_handle_fromlist)                                                   ->       1    0.000    0.000  {built-in method builtins.hasattr}
{built-in method builtins.getattr}                                                                    ->
{built-in method builtins.hasattr}                                                                    ->
{built-in method builtins.len}                                                                        ->
C:\Users\ITryagain\AppData\Local\conda\conda\envs\tensorflow-gpu\lib\functools.py:74(wraps)           ->
{method 'update' of 'dict' objects}                                                                   ->
{built-in method builtins.setattr}                                                                    ->
{method 'disable' of '_lsprof.Profiler' objects}                                                      ->


<pstats.Stats object at 0x000002AA0A6A8908>

line_profiler 逐行分析

前面我們通過 cProfile 來對代碼進行了整體的分析,當我們確定了耗時多的函數後,想對該函數進行進一步分析時,就可以使用 line_profiler 了。

先安裝

pip install line_profiler
或
conda install line_profiler

在需要測試的函數前面加上修飾器 @profile,然後命令函輸入

kernprof -l -v code.py

輸出

Wrote profile results to code.py.lprof
Timer unit: 1e-07 s

Total time: 137.019 s
File: code.py
Function: calculate_z_serial_purepython at line 23

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    23                                           @profile
    24                                           def calculate_z_serial_purepython(maxiter, zs, cs):
    25         1      89776.0  89776.0      0.0      output = [0] * len(zs)
    26   1000001    9990393.0     10.0      0.7      for i in range(len(zs)):
    27   1000000    9244029.0      9.2      0.7          n = 0
    28   1000000   10851654.0     10.9      0.8          z = zs[i]
    29   1000000   10242762.0     10.2      0.7          c = cs[i]
    30  34219980  558122806.0     16.3     40.7          while abs(z) < 2 and n < maxiter:
    31  33219980  403539388.0     12.1     29.5              z = z * z + c
    32  33219980  356918574.0     10.7     26.0              n += 1
    33   1000000   11186107.0     11.2      0.8          output[i] = n
    34         1         12.0     12.0      0.0      return output

運行時間比較長。。不過,這裡可以發現,耗時的操作主要都在 while 循環中,做判斷的耗時最長,但是這裡我們並不知道是 abs(z) < 2 還是 n < maxiter 更花時間。z 與 n 的更新也比較花時間,這是因為在每次循環時, Python 的動態查詢機制都在工作。

那麼,這裡可以通過 timeit 來進行測試

In [1]: z = 0 + 0j

In [2]: %timeit abs(z) < 2
357 ns ± 21.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [3]: n = 1

In [4]: maxiter = 300

In [5]: %timeit n < maxiter
119 ns ± 6.91 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

可以看到,n < maxiter 所需時間更短,並且每301次會有一次 False,而 abs(z) < 2 為 False 的次數我們並不好估計,佔比約為前面圖片中白色部分所佔比例。因此,我們可以假設交換兩條語句的順序可以使得程序運行速度更快。

Total time: 132.816 s
File: code.py
Function: calculate_z_serial_purepython at line 23

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    23                                           @profile
    24                                           def calculate_z_serial_purepython(maxiter, zs, cs):
    25         1      83002.0  83002.0      0.0      output = [0] * len(zs)
    26   1000001    9833163.0      9.8      0.7      for i in range(len(zs)):
    27   1000000    9241272.0      9.2      0.7          n = 0
    28   1000000   10667576.0     10.7      0.8          z = zs[i]
    29   1000000   10091308.0     10.1      0.8          c = cs[i]
    30  34219980  531157092.0     15.5     40.0          while n < maxiter and abs(z) < 2:
    31  33219980  393275303.0     11.8     29.6              z = z * z + c
    32  33219980  352964180.0     10.6     26.6              n += 1
    33   1000000   10851379.0     10.9      0.8          output[i] = n
    34         1         11.0     11.0      0.0      return output

可以看到,確實是有所優化。

小節

從開始學習編程到現在差不多快 3 年了,之前可以說是從來沒有利用這些工具來對代碼性能進行過分析,最多也只是通過算法複雜度的分析來進行優化,接觸了這些之後就感覺,需要學習的東西還有很多。在近期進行的華為軟挑中,隊友也曾對代碼(C++)的運行時間進行過分析,如下圖。

下篇將介紹對運行時內存的分析。

參考

  1. 《Python 高性能編程》