我只想说。能在现场编出满分代码的是真的厉害。即便这只是CCF的第三道题。但是明显和前两道难度翻了指数幂倍。。
搞了几天终于搞出来了(还是参考别人的),由于模拟练习只有1s时间,所以超时了(90分),现场给的时间是10s
from operator import itemgetter
# 获取原始输入,[['1.0.0/8'],[...]]
def getInput():
n = int(input())
ip_list = []
for _ in range(n):
ip_list.append(input())
return n, ip_list
#标准化 [['1'],[..]] --> ['1.0.0.0/8',[..]] -->再根据‘.’和‘/’切割
def norm(ip_list: list): # [['1.0.0.0/8'],['1.6.6/23'],['101.6']]
standard_ip_list = []
for ip in ip_list:
if '/' in ip:
ip_str, ip_len = ip.split('/')
ip_len = int(ip_len)
else:
ip_str = ip
ip_len = (ip.count('.') + 1) * 8
tmp_list = [0] * 4
ip_str_list = ip_str.split('.')
tmp_list[0:len(ip_str_list)] = map(int, ip_str_list)
standard_ip_list.append([*tmp_list, ip_len])
return standard_ip_list
# 检查ip2 是否是 ip1的子集
def issubset(standard_ip1, standard_ip2):
n1 = 256 ** 3 * standard_ip1[0] + 256 ** 2 * standard_ip1[1] + 256 ** 1 * standard_ip1[2] + standard_ip1[3]
mask1 = standard_ip1[4]
n2 = 256 ** 3 * standard_ip2[0] + 256 ** 2 * standard_ip2[1] + 256 ** 1 * standard_ip2[2] + standard_ip2[3]
# mask2 = standard_ip2[4]
# 检查高位是否相同
if (n1 >> 32 - mask1) == (n2 >> 32 - mask1):
return True
else:
return False
# 第一次压缩
def refineLv1(standard_ip_list):
ip_list = standard_ip_list
idx = 0
while idx < len(ip_list) - 1:
# compress
if issubset(ip_list[idx], ip_list[idx + 1]):
ip_list.pop(idx + 1)
# print('out...')
else:
idx += 1
# 第二次压缩(同级合并)
def refineLv2(standerd_ip_list):
ip_list = standard_ip_list
idx = 0
while idx < len(ip_list) - 1:
# 相邻元素
if ip_list[idx][4] == ip_list[idx + 1][4] and ip_list[idx][4] - 1 >= 0:
t = ip_list[idx][:]
t[4] -= 1
if issubset(t, ip_list[idx]) and issubset(t, ip_list[idx + 1]):
ip_list.pop(idx)
ip_list.pop(idx) # 经过前面的pop,原来的idx+1位置的元素挪到了idx
ip_list.insert(idx, t)
if idx > 0: # 说明前面还有元素
idx -= 1
else:
idx += 1
else:
idx += 1
if __name__ == '__main__':
# 原始输入
n, ipList = getInput()
# 标准化
normed = norm(ipList)
# 排序
standard_ip_list = sorted(normed, key=itemgetter(0, 1, 2, 3, 4))
# 第一次合并
refineLv1(standard_ip_list)
# 第二次合并
refineLv2(standard_ip_list)
# 打印处理结果
for ele in standard_ip_list:
print(str(ele[0]) + '.' + str(ele[1]) + '.' + str(ele[2]) + '.' + str(ele[3]) + '/' + str(ele[4]))
# 3
# 1.0.0.0/8
# 1.6.6/23
# 101.6