深度剖析為什麼 Python 中整型不會溢出?

  • 2019 年 10 月 4 日
  • 筆記

? 「Python貓」 ,一個值得加星標的公眾號

花下貓語:前不久,我應讀者提問而寫了一篇《Python 的整數與 Numpy 的數據溢出》,簡要介紹過 Python 中的整數表示法與數據溢出問題。那篇文章的獵奇/科普成分更大些,文章簡短,乾貨量不足。為了彌補,今天特分享一篇深度的文章,大家一起來學習吧!

作者:weapon(本文獲授權轉載)

來源:https://zhuanlan.zhihu.com/p/37983326

劇照 | 《神鵰俠侶》

前言

本次分析基於 CPython 解釋器,python3.x 版本

在 python2 時代,整型有 int 類型和 long 長整型,長整型不存在溢出問題,即可以存放任意大小的整數。在 python3 後,統一使用了長整型。這也是吸引科研人員的一部分了,適合大數據運算,不會溢出,也不會有其他語言那樣還分短整型,整型,長整型… 因此 python 就降低其他行業的學習門檻了。

那麼,不溢出的整型實現上是否可行呢?

不溢出的整型的可行性

儘管在 C 語言中,整型所表示的大小是有範圍的,但是 python 程式碼是保存到文本文件中的,也就是說,python程式碼中並不是一下子就轉化成 C 語言的整型的,我們需要重新定義一種數據結構來表示和存儲我們新的「整型」。

怎麼來存儲呢,既然我們要表示任意大小,那就得用動態的可變長的結構,顯然,數組的形式能夠勝任:

[longintrepr.h]  struct _longobject {      PyObject_VAR_HEAD      int *ob_digit;  };

長整型的保存形式

長整型在python內部是用一個 int 數組( ob_digit[n] )保存值的. 待存儲的數值的低位資訊放於低位下標, 高位資訊放於高下標.比如要保存 123456789 較大的數字,但我們的int只能保存3位(假設):

ob_digit[0] = 789;  ob_digit[1] = 456;  ob_digit[2] = 123;

低索引保存的是地位,那麼每個 int 元素保存多大的數合適?有同學會認為數組中每個int存放它的上限(2^31 – 1),這樣表示大數時,數組長度更短,更省空間。但是,空間確實是更省了,但操作會程式碼麻煩,比方大數做乘積操作,由於元素之間存在乘法溢出問題,又得多考慮一種溢出的情況。

怎麼來改進呢?在長整型的 ob_digit 中元素理論上可以保存的int類型有 32 位,但是我們只保存 15位,這樣元素之間的乘積就可以只用 int 類型保存即可, 對乘積結果做位移操作就能得到尾部和進位 carry了,因此定義位移長度為 15

#define PyLong_SHIFT  15  #define PyLong_BASE ((digit)1 << PyLong_SHIFT)  #define PyLong_MASK ((digit)(PyLong_BASE - 1))

PyLong_MASK 也就是 0b111111111111111 ,通過與它做位運算 的操作就能得到低位數。

有了這種存放方式,在記憶體空間允許的情況下,我們就可以存放任意大小的數字了。

長整型的運算

加法與乘法運算都可以使用我們小學的豎式計算方法,例如對於加法運算:

為方便理解,表格展示的是數組中每個元素保存的是 3 位十進位數,計算結果保存在變數z中,那麼 z 的數組最多只要 size_a + 1 的空間(兩個加數中數組較大的元素個數 + 1),因此對於加法運算,處理過程就是各個對應位置的元素進行加法運算,計算過程就是豎式計算的方式:

[longobject.c]  static PyLongObject * x_add(PyLongObject *a, PyLongObject *b) {      int size_a = len(a), size_b = len(b);      PyLongObject *z;      int i;      int carry = 0; // 進位        // 確保a是兩個加數中較大的一個      if (size_a < size_b) {          // 交換兩個加數          swap(a, b);          swap(&size_a, &size_b);      }        z = _PyLong_New(size_a + 1);  // 申請一個能容納size_a+1個元素的長整型對象      for (i = 0; i < size_b; ++i) {          carry += a->ob_digit[i] + b->ob_digit[i];          z->ob_digit[i] = carry & PyLong_MASK;   // 掩碼          carry >>= PyLong_SHIFT;                 // 移除低15位, 得到進位      }      for (; i < size_a; ++i) {                   // 單獨處理a中高位數字          carry += a->ob_digit[i];          z->ob_digit[i] = carry & PyLong_MASK;          carry >>= PyLong_SHIFT;      }      z->ob_digit[i] = carry;      return long_normalize(z);                   // 整理元素個數    }

這部分的過程就是,先將兩個加數中長度較長的作為第一個加數,再為用於保存結果的 z 申請空間,兩個加數從數組從低位向高位計算,處理結果的進位,將結果的低 15 位賦值給 z 相應的位置。最後的 long_normalize(z) 是一個整理函數,因為我們 z 申請了 a_size + 1 的空間,但不意味著 z 會全部用到,因此這個函數會做一些調整,去掉多餘的空間,數組長度調整至正確的數量。

若不方便理解,附錄將給出更利於理解的 python 程式碼。

豎式計算不是按個位十位來計算的嗎,為什麼這邊用整個元素?

豎式計算方法適用與任何進位的數字,我們可以這樣來理解,這是一個 32768 (2的15次方) 進位的,那麼就可以把數組索引為 0 的元素當做是 「個位」,索引 1 的元素當做是 「十位」。

乘法運算

乘法運算一樣可以用豎式的計算方式,兩個乘數相乘,存放結果的 z 的元素個數為 size_a + size_b 即可:

img

這裡需要主意的是,當乘數 b 用索引 i 的元素進行計算時,結果 z 也是從 i 索引開始保存。先創建 z 並初始化為 0,這 z 進行累加,加法運算則可以利用前面的 x_add 函數:

// 為方便理解,會與cpython中源碼部分稍有不同  static PyLongObject * x_mul(PyLongObject *a, PyLongObject *b)  {      int size_a = len(a), size_b = len(b);      PyLongObject *z = _PyLong_New(size_a + size_b);      memset(z->ob_digit, 0, len(z) * sizeof(int)); // z 的數組清 0        for (i = 0; i < size_b; ++i) {          int carry = 0;          // 用一個int保存元素之間的乘法結果          int f = b->ob_digit[i]; // 當前乘數b的元素            // 創建一個臨時變數,保存當前元素的計算結果,用於累加          PyLongObject *temp = _PyLong_New(size_a + size_b);          memset(temp->ob_digit, 0, len(temp) * sizeof(int)); // temp 的數組清 0            int pz = i; // 存放到臨時變數的低位            for (j = 0; j < size_a; ++j) {              carry = f * a[j] + carry;              temp[pz] = carry & PyLong_MASK;  // 取低15位              carry = carry >> PyLong_SHIFT;  // 保留進位              pz ++;          }          if (carry){     //  處理進位              carry += temp[pz];              temp[pz] = carry & PyLong_MASK;              carry = carry >> PyLong_SHIFT;          }          if (carry){              temp[pz] += carry & PyLong_MASK;          }          temp = long_normalize(temp);          z = x_add(z, temp);      }        return z    }

這大致就是乘法的處理過程,豎式乘法的複雜度是n^2,當數字非常大的時候(數組元素個數超過 70 個)時,python會選擇性能更好,更高效的 Karatsuba multiplication 乘法運算方式,這種的演算法複雜度是 3nlog3≈3n1.585,當然這種計算方法已經不是今天討論的內容了。有興趣的小夥伴可以去了解下。

總結

要想支援任意大小的整數運算,首先要找到適合存放整數的方式,本篇介紹了用 int 數組來存放,當然也可以用字元串來存儲。找到合適的數據結構後,要重新定義整型的所有運算操作,本篇雖然只介紹了加法和乘法的處理過程,但其實還需要做很多的工作諸如減法,除法,位運算,取模,取余等。

python程式碼以文本形式存放,因此最後,還需要一個將字元串形式的數字轉換成這種整型結構:

[longobject.c]  PyObject * PyLong_FromString(const char *str, char **pend, int base)  {  }

這部分不是本篇的重點,有興趣的同學可以看看這個轉換的過程,這個過程還是比較繁瑣的,因為它還要處理進位問題,能夠處理 0xfff3 或者 0b1011 等情況。

附錄

參考:longobject.cgithub.com

# 例子中的表格中,數組元素最多存放3位整數,因此這邊設置1000  # 對應的取低位與取高位也就變成對 1000 取模和取余操作  PyLong_SHIFT = 1000  PyLong_MASK = 999    # 以15位長度的二進位  # PyLong_SHIFT = 15  # PyLong_MASK = (1 << 15) - 1    def long_normalize(num):      """      去掉多餘的空間,調整數組的到正確的長度      eg: [176, 631, 0, 0]  ==>  [176, 631]      :param num:      :return:      """      end = len(num)      while end >= 1:          if num[end - 1] != 0:              break          end -= 1        num = num[:end]      return num    def x_add(a, b):      size_a = len(a)      size_b = len(b)      carry = 0        # 確保 a 是兩個加數較大的,較大指的是元素的個數      if size_a < size_b:          size_a, size_b = size_b, size_a          a, b = b, a        z = [0] * (size_a + 1)      i = 0      while i < size_b:          carry += a[i] + b[i]          z[i] = carry % PyLong_SHIFT          carry //= PyLong_SHIFT          i += 1        while i < size_a:          carry += a[i]          z[i] = carry % PyLong_SHIFT          carry //= PyLong_SHIFT          i += 1      z[i] = carry        # 去掉多餘的空間,數組長度調整至正確的數量      z = long_normalize(z)        return z    def x_mul(a, b):      size_a = len(a)      size_b = len(b)      z = [0] * (size_a + size_b)        for i in range(size_b):          carry = 0          f = b[i]            # 創建一個臨時變數          temp = [0] * (size_a + size_b)          pz = i  # 元素計算結果從 i 索引開始保存          for j in range(size_a):              carry += f * a[j]              temp[pz] = carry % PyLong_SHIFT              carry //= PyLong_SHIFT              pz += 1            if carry:              carry += temp[pz]              temp[pz] = carry % PyLong_SHIFT              carry //= PyLong_SHIFT              pz += 1            if carry:              temp[pz] += carry % PyLong_SHIFT          temp = long_normalize(temp)          z = x_add(z, temp)        return z    a = [543, 934, 23]  b = [632, 454]  print(x_add(a, b))  print(x_mul(a, b))