【演算法框架套路】最長公共子序列

需求

給定兩個字元串 text1 和 text2,返回這兩個字元串的最長 公共子序列 的長度。如果不存在 公共子序列 ,返回 0 。
輸入:text1 = “abcde”, text2 = “ace”
輸出:3
解釋:最長公共子序列是 “ace”,它的長度為 3。

題目這樣描述看著比較沒意思,因為比較抽象,一般人不知道用來幹嘛的,換個現實說法

尋找劉德化和梁朝偉的最長公共女粉絲

子序列和子串有什麼區別?
子串要連續,子序列可以不連續。比如

a=hellowolrd
b=loop

最長子串是lo,最長子序列是loo

思路

遇到這樣的題,我一般都是這樣的做法

  1. 先暴力破解:窮舉
  2. 更高效地窮舉
  3. 更高效地窮舉+備忘錄
  4. 動態規劃

下面演示一下這種層層推進的過程,以chenqionghexsfz和cqhxsfz為例,兩者的最長公共子序列是cqhxsfz,返回的結果是7.

1. dfs暴力收集所有解,再計算出最大解

用的是回溯套路,可以參考【演算法框架套路】回溯演算法(暴力窮舉的藝術)

這裡就是從頭到尾窮舉,遇到相同的字元串,就加入到公共子串的track數組,到頭了將子串收集到res_list中。

import copy


def long_common_subsequence_all(str1, str2):
    len1, len2 = len(str1), len(str2)
    res_list = []
    lcs = ""

    def dp(i, j, track1, track2):
        if i == len1 or j == len2:
            nonlocal lcs
            cs = "".join(track1)
            res_list.append(cs)  # 到頭了,收集一下公共子序列
            if len(cs) >= len(lcs):
                lcs = cs  # 更新最大子序列
            return

        c_track1 = copy.copy(track1)
        c_track2 = copy.copy(track2)

        if str1[i] == str2[j]:
            # 找到一個lcs中的元素,str1和str2分別選中,繼續往下找
            c_track1.append(str1[i])
            c_track2.append(str2[j])
            dp(i + 1, j + 1, c_track1, c_track2)
            return
        else:
            dp(i, j + 1, c_track1, c_track2)
            dp(i + 1, j, c_track1, c_track2)

    dp(0, 0, [], [])
    return lcs, res_list


s1 = "chenqionghexsfz"
s2 = "cqhxsfz"
lcs, res_list = long_common_subsequence_all(s1, s2)
print(res_list)
print(lcs)

res_list是窮舉所有的公共子串

結果如下

image

2. dfs暴力只收集最大解

這和上次不同,我們從末尾開始遞歸
s1[0:i]和s2[0:j]的最長公共子串
如果s1[i]和sj[j]相同,最長公共子串,肯定是等於s1[i-1]和s[j-1]的結果+1

這樣的方式,肯定比窮舉所有的要好一點,程式碼如下

# dp定義:返回text1[0:i]和text2[0:j]的lcs
def long_common_subsequence_all(text1, text2):
    def dp(i, j):
        if i == -1 or j == -1:
            return 0
        if text1[i] == text2[j]:
            return dp(i - 1, j - 1) + 1
        else:
            return max(dp(i - 1, j), dp(i, j - 1))  # i和j不相同,分別再對比s1[i-1],s2[j]和s[i],s2[j-1]

    return dp(len(text1) - 1, len(text2) - 1)


s1 = "chenqionghexsfz"
s2 = "cqhxsfz"
lcs_len = long_common_subsequence_all(s1, s2)
print(lcs_len)

運行輸出如下
image

這裡只給出了長度,程式碼較少。
如果想知道子串,也可以依照上面的track數組,這樣寫

def long_common_subsequence_all(str1, str2):
    lcs = ""
    # 定義dp:返回str1[0:i]和str2[0:j]的lcs
    def dp(i, j, track1, track2):
        nonlocal lcs
        if i == -1 or j == -1:
            # 到頭了,更新最大的結果
            cs = "".join(track1)
            if len(cs) > len(lcs):
                lcs = cs
            return 0

        c_track1 = copy.copy(track1)
        c_track2 = copy.copy(track2)
        if str1[i] == str2[j]:
            # 找到一個lcs中的元素,str1和str2分別選中,繼續往下找
            c_track1.insert(0, str1[i])
            c_track2.insert(0, str2[j])
            return dp(i - 1, j - 1, c_track1, c_track2) + 1

        else:
            # i和j不相同,分別再對比s1[i-1],s2[j]和s[i],s2[j-1]
            return max(dp(i - 1, j, c_track1, c_track2), dp(i, j - 1, c_track1, c_track2))

    lcs_len = dp(len(str1) - 1, len(str2) - 1, [], [])
    return lcs, lcs_len


s1 = "chenqionghexsfz"
s2 = "cqhxsfz"
lcs, lcs_len = long_common_subsequence_all(s1, s2)
print(lcs, lcs_len)

運行輸出
image

3. dfs暴力只收集最大解+備忘錄

上面會發生一些重複操作,
比如

dp(3,5) = dp(2,4)+1
dp(2,5) = dp(2,4)+1

那麼dp(2,4)會被重複計算,我們需要將已經計算出來的結果快取起來
程式碼如下

def long_common_subsequence(text1, text2):
    memo = {}

    def dp(i, j):
        if (i, j) in memo:
            return memo[(i, j)]
        if i == -1 or j == -1:
            return 0
        if text1[i] == text2[j]:
            return dp(i - 1, j - 1) + 1
        else:
            memo[(i, j)] = max(dp(i - 1, j), dp(i, j - 1))  # i和j不相同,分別再對比s1[i-1],s2[j]和s[i],s2[j-1]
            return memo[(i, j)]

    return dp(len(text1) - 1, len(text2) - 1)


s1 = "chenqionghexsfz"
s2 = "cqhxsfz"
lcs = long_common_subsequence(s1, s2)
print(lcs)

運行輸出
image

4. dp動態規劃

dp(i,j)是返回text1,text2的最大公共子串大小。

dp[i][j]也是返回text1,text2的最大公共子串大小,只是反著來

實現如下

# dp定義:返回text1[0:i]和text2[0:j]的lcs
def long_common_subsequence(text1, text2):
    len1, len2 = len(text1), len(text2)
    dp = [[0] * (len2 + 1) for _ in range(len1 + 1)]
    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            # 找到一個公共字元串
            if text1[i - 1] == text2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]


s1 = "chenqionghexsfz"
s2 = "cqhxsfz"
lcs = long_common_subsequence(s1, s2)
print(lcs)

運行輸出
image

Tags: