找出二值图像的所有连通域,面试官给了半个小时,其实还是很充分的,大概在规定时间内做出了这道题。发个blog稍微做个记录。分别用python和C++写出这道题。
详细描述一下这道题:
输入是一个二值化的矩阵,只含有0和255两种取值。求由255组成的连通域,连通规则和围棋一样(只有上下左右相邻才算连通,斜对角不算),这道题就像是在找围棋里白子的连通域。输出形式简单明了:对每个连通域进行编号,第一个连通域的像素全部取值为1,第二个连通域取值为2,以此类推。
比如,输入为:
[[0 , 255, 0 , 0, 0 , 255],
[255, 255, 0 , 0, 255, 0 ],
[0 , 255, 255, 0, 255, 0 ]]
则要求的输出为:
[[0 1 0 0 0 2]
[1 1 0 0 3 0]
[0 1 1 0 3 0]]
我把这道题当作搜索问题来做(严格来说它不是动态规划,而是类似flood fill的逐层搜索):首先要判断一个值为255的pixel与哪些值为255的pixel连通,再依次找到所有连通的pixel,类似于一种蔓延。当时为了求快,首选的语言还是python:
import numpy as np
def islink(a, b, arr):
    """
    Return the positions of the 4-connected neighbours of (a, b) that hold 255.

    :param a: row index of the pixel being inspected
    :param b: column index of the pixel being inspected
    :param arr: 2-D numpy array (only ``arr.shape`` and element reads are used)
    :return: list of ``[row, col]`` pairs, in the order left, right, top, bottom

    Positions that would fall outside the array are skipped rather than
    clamped back onto the border, so (a, b) itself is never reported as its
    own neighbour and no position appears twice.
    """
    m, n = arr.shape
    # Same visiting order as before: left, right, top, bottom.
    candidates = ((a, b - 1), (a, b + 1), (a - 1, b), (a + 1, b))
    return [
        [r, c]
        for r, c in candidates
        if 0 <= r < m and 0 <= c < n and arr[r][c] == 255
    ]
def sol(arr):
    """
    Label every 4-connected region of 255-valued pixels.

    Regions are numbered 1, 2, 3, ... in the order in which their first pixel
    is met during a row-major scan; every pixel that is not 255 gets 0.

    :param arr: 2-D array-like whose foreground pixels have the value 255
    :return: a new integer ndarray of the same shape holding the region labels

    The labels are written to a separate array instead of over the image, so
    a label can never be confused with the foreground value 255 (no sentinel
    such as 255.5 is needed, which an integer array would truncate back to
    255) and the caller's data is left untouched.
    """
    arr = np.asarray(arr)
    m, n = arr.shape
    labels = np.zeros((m, n), dtype=int)
    label = 0
    for i in range(m):
        for j in range(n):
            # Skip background pixels and pixels that already belong to a region.
            if arr[i, j] != 255 or labels[i, j] != 0:
                continue
            label += 1
            labels[i, j] = label
            pending = [(i, j)]
            while pending:
                r, c = pending.pop()
                for nr, nc in ((r, c - 1), (r, c + 1), (r - 1, c), (r + 1, c)):
                    if not (0 <= nr < m and 0 <= nc < n):
                        continue
                    if arr[nr, nc] == 255 and labels[nr, nc] == 0:
                        # Mark on discovery so each pixel is queued only once.
                        labels[nr, nc] = label
                        pending.append((nr, nc))
    return labels
def run():
    """Label the sample image from the write-up and print the result."""
    sample_rows = [
        [0, 255, 0, 0, 0, 255],
        [255, 255, 0, 0, 255, 0],
        [0, 255, 255, 0, 255, 0],
    ]
    labelled = sol(np.array(sample_rows))
    print(labelled)
# Module-level call: the demo runs whenever this file is executed or imported.
run()
今天手写的代码复现到IDE上,居然跑通了。不过之前没考虑到link-area == 255的情况,一个corner case被面试官指出来,所以现在加了个判断条件,遇到255的时候自动加0.5变成255.5,然后在返回矩阵之前修改回255。这样就解决了corner的情况。面试官建议的方法是新建一个矩阵,把像素设为负值,就不会存在link-area value等于pixel value的情况。其实解决办法还是很多啦
输出结果为:
[[0 1 0 0 0 2]
[1 1 0 0 3 0]
[0 1 1 0 3 0]]
C++版本:
#include<iostream>
#include<string>
#include<vector>
using namespace std;
// Collect the 4-connected neighbours of pixel (a, b) whose value is 255.
//
// a, b : row and column of the pixel being inspected.
// arr  : the image, taken by const reference so it is not copied on every
//        call; all rows are assumed to be as long as arr[0].
// Returns {row, col} pairs in the order left, right, top, bottom.
//
// Positions outside the image are skipped instead of being clamped onto the
// border, so (a, b) itself is never reported as its own neighbour.
vector<vector<int>> islink(int a, int b, const vector<vector<int>>& arr) {
    vector<vector<int>> ls;
    if (arr.empty()) {
        return ls;
    }
    const int m = static_cast<int>(arr.size());
    const int n = static_cast<int>(arr[0].size());
    const int rows[4] = {a, a, a - 1, a + 1};
    const int cols[4] = {b - 1, b + 1, b, b};
    for (int k = 0; k < 4; k++) {
        const int r = rows[k];
        const int c = cols[k];
        if (r < 0 || r >= m || c < 0 || c >= n) {
            continue;
        }
        if (arr[r][c] == 255) {
            ls.push_back(vector<int>{r, c});
        }
    }
    return ls;
}
// Label every 4-connected region of 255-valued pixels.
//
// arr : the image; all rows are assumed to be as long as arr[0].
// Returns a matrix of the same size in which the pixels of the first region
// met in a row-major scan are 1, those of the second are 2, and so on; every
// pixel that is not 255 is 0. An empty image yields an empty result.
//
// Labels live in their own matrix, so a label (including 255 itself) can
// never be mistaken for an unvisited foreground pixel.
vector<vector<int>> sol(const vector<vector<int>>& arr) {
    struct Cell {
        int r;
        int c;
    };
    const int m = static_cast<int>(arr.size());
    const int n = (m > 0) ? static_cast<int>(arr[0].size()) : 0;
    vector<vector<int>> labels(m, vector<int>(n, 0));
    const int dr[4] = {0, 0, -1, 1};
    const int dc[4] = {-1, 1, 0, 0};
    int pointer = 0;
    vector<Cell> pending;
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            // Skip background pixels and pixels already assigned to a region.
            if (arr[i][j] != 255 || labels[i][j] != 0) {
                continue;
            }
            pointer++;
            labels[i][j] = pointer;
            pending.push_back(Cell{i, j});
            while (!pending.empty()) {
                const Cell cur = pending.back();
                pending.pop_back();
                for (int k = 0; k < 4; k++) {
                    const int r = cur.r + dr[k];
                    const int c = cur.c + dc[k];
                    if (r < 0 || r >= m || c < 0 || c >= n) {
                        continue;
                    }
                    if (arr[r][c] == 255 && labels[r][c] == 0) {
                        // Mark on discovery so each pixel is queued only once.
                        labels[r][c] = pointer;
                        pending.push_back(Cell{r, c});
                    }
                }
            }
        }
    }
    return labels;
}
// Write the matrix to stdout: one row per line, each value followed by a tab.
//
// arr : the matrix to print, taken by const reference so it is not copied.
// An empty matrix prints nothing (the row length is no longer read from
// arr[0], which does not exist in that case).
void print_vector(const vector<vector<int>>& arr) {
    for (const vector<int>& row : arr) {
        for (int value : row) {
            cout << value << "\t";
        }
        cout << "\n";
    }
}
// Demo entry point: label a small sample image and print the labelled matrix.
int main() {
    vector<vector<int>> image = {
        {  0, 255,   0,   0, 255},
        {255, 255,   0, 255,   0},
        {255,   0, 255, 255,   0}
    };
    image = sol(image);
    print_vector(image);
    // Asks the command interpreter to run `pause`; presumably meant to keep
    // a console window open until a key is pressed.
    system("pause");
    return 0;
}
即使逻辑完全相同,C++的代码量也明显比Python多。