题解
这题其实分成了两个部分:
- DP:把大问题分割成小问题
- 判断循环节:这样才能知道某段字符串可不可以缩
DP/问题切割
给定一段字符串s,如果存在某个中间位置k,把字符串分割成2部分s[:k]
和s[k:]
:
- 如果
s[:k]
由循环节组成,假设已知它可以缩成substr
,那么问题就变成了F(s)=substr+F(s[k:])F(s) = \text{substr} + F(s[k:])F(s)=substr+F(s[k:]) - 类似地,如果
s[k:]
由循环节组成,同样假设它可以被缩成substr
,那么问题就变成了F(s)=F(s[:k])+substrF(s) = F(s[:k])+\text{substr}F(s)=F(s[:k])+substr
两种拆分方式是等价的,选一个判断即可。
我们需要遍历所有的k (from 1 to len(s))
,找到使F(s)
最小的那个。
判断循环节
暴力判断
枚举循环节长度,判断有没有该长度的循环节
KMP
需要利用KMP算法里的next数组(计算方式请自行搜索),原理:
注意点
- 如果选择用KMP方法,当要判断循环节的字符串的起始位置相同时,可以利用同一个next数组,因此可以配合DP里的第一种拆分方式使用,当k不断变化时,
s[:k]
的起始位置总是相同的,因此可以复用同一个next数组。 - 有一个剪枝tip,长度<=4的字符串肯定缩了长度也不会变小,所以可以不用判断它是不是存在循环节。
参考解法
DP
令f[i][j]
表示s[i:j]
的最短表示方式,那么我们需要遍历所有的k from i+1 to j
,判断s[i:k]
可不可以缩。
next数组没啥好说的,需要注意next[x]=y
到底意味着哪段子串与哪段子串相等,容易搞混。
class Solution:
def encode(self, s: str) -> str:
n = len(s)
f = [['' for _ in range(n+1)] for _ in range(n+1)]
nxt = [-1] * (n+1)
for i in range(n-1, -1, -1):
# get next arr
nxt[i] = -1
for j in range(i+1, n+1):
# ptr = nxt[j]
# s[i,...,i+ptr-1] = s[j-ptr,...,j-1]
ptr = nxt[j-1]
while ptr >= 0 and s[i+ptr] != s[j-1]:
ptr = nxt[ptr+i]
nxt[j] = ptr+1
for j in range(i+1, n+1):
# s[i:j]'s minimized version, f[i][j]
# 由于k<=i+4的时候,s[i:k]肯定不需要缩,那结果就跟k=1的时候是一样的,所以下面的for循环进行了剪枝
# k=i+1时的结果
res = s[i] + f[i+1][j]
# 遍历k>i+4时的结果
for k in range(j, i+4, -1):
# split into s[i:k] and s[k:j]
ptr = nxt[k] # s[i...i+ptr-1] = s[k-ptr,..., k-1]
unit = k-ptr-i
L = k-i
if L % unit == 0:
cnt = L // unit
if cnt > 1:
substr1 = f'{cnt}[{f[i][k-ptr]}]'
substr = substr1 if len(substr1) < L else s[i:k]
if len(substr) + len(f[k][j]) < len(res):
res = substr + f[k][j]
f[i][j] = res
# print(f"f[{i}][{j}]={s[i:j]}={res}")
return f[0][n]
记忆化搜索
再次觉得,DP和记忆化搜索其实本质上是差不多的东西……
同样也采取第一种问题划分方式,遍历k,判断s[:k]
是否可以缩。
class Solution:
def encode(self, s: str) -> str:
@lru_cache()
def _encode(s):
n = len(s)
# 神奇,这里如果改成n < 5会慢几十毫秒,也许是判断等于比较快
if n == 0:
return s
nxt = [-1]
for i, c in enumerate(s):
# nxt[i+1]
# s[0...j-1] = s[i-j+1 ... i]
j = nxt[-1]
while j>=0 and s[j] != c:
j = nxt[j]
nxt.append(j+1)
# k<=4时, 都等价于k=1时
res = s[0] + _encode(s[1:])
# k>4时
for k in range(n, 4, -1):
# s[0:k]
# s[0...j-1] = s[k-j ... k-1]
j = nxt[k]
unit = k-j
L = k
if L % unit == 0:
cnt = L // unit
if cnt > 1:
substr1 = f"{cnt}[{_encode(s[0:k-j])}]"
substr = substr1 if len(substr1) < L else s[:k]
other = _encode(s[k:])
if len(res) > len(substr) + len(other):
res = substr + other
return res
return _encode(s)