1967번:트리의 지름(RecursionError 해결)

처음 접근

트리의 지름 구하는 공식에 따라 처음은 root인 1 부터 가장 긴 노드 A를 찾고 그 노드 A에서 가장 멀리 있는 노드 B를 찾아, 노드 A와 노드 B 사이의 path 거리를 구했다.

RecursionError 발생 원인

가장 멀리 있는 노드를 찾을 때 dfs를 recursion으로 구현하였는데
조건을 보면 노드의 개수 n(1 ≤ n ≤ 10,000) 이므로
최악의 경우 recursion이 10000번 일어나게 된다.
파이썬에서 recursion 수는 최대 998번으로 제한하고 있으므로 maximum recurion에 의해 RecursionError가 발생한 것이다.

해결 방법

recursion으로 구현한 dfs()를 Iteration(stack)으로 구현해주면 된다.

RecurionError가 발생한 코드 - dfs()를 recursion으로 구현

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import collections

n = int(input())

tree = [collections.defaultdict(int) for i in range(n+1)]
path = [0 for i in range(n+1)]
visited = [False for i in range(n+1)]
for _ in range(n-1):
p, c, w = map(int, input().split(' '))
tree[p][c] = w


def dfs(start, before_node, path_length):
if visited[start] is True:
return
visited[start] = True

if tree[start][before_node] != 0:
path_length += tree[start][before_node]

elif tree[before_node][start] != 0:
path_length += tree[before_node][start]

path[start] = path_length

for adj_v in tree[start]:
dfs(adj_v, start, path_length)


dfs(1, 1, 0)

for i in range(len(visited)):
visited[i] = False

max_path_length = 0
max_path_node = 0
for idx, p in enumerate(path):
if max_path_length < p:
max_path_length = p
max_path_node = idx

dfs(max_path_node, max_path_node, 0)


print(max(path))

Iteration(stack)으로 변경한 dfs()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def dfs(start, path_length):
stack.append((start, start, path_length))
while stack:
curr, before_node, pl = stack.pop()
if visited[curr] is False:
visited[curr] = True
if tree[curr][before_node] != 0:
pl += tree[curr][before_node]

elif tree[before_node][curr] != 0:
pl += tree[before_node][curr]

path[curr] = pl

for adj_v in tree[curr]:
stack.append((adj_v, curr, pl))

RecursionError를 해결한 코드 - dfs()를 Iteration(stack)으로 구현

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import collections

n = int(input())

tree = [collections.defaultdict(int) for i in range(n+1)]
path = [0 for i in range(n+1)]
visited = [False for i in range(n+1)]
for _ in range(n-1):
p, c, w = map(int, input().split(' '))
tree[p][c] = w

stack = []


def dfs(start, path_length):
stack.append((start, start, path_length))
while stack:
curr, before_node, pl = stack.pop()
if visited[curr] is False:
visited[curr] = True
if tree[curr][before_node] != 0:
pl += tree[curr][before_node]

elif tree[before_node][curr] != 0:
pl += tree[before_node][curr]

path[curr] = pl

for adj_v in tree[curr]:
stack.append((adj_v, curr, pl))


dfs(1, 0)

for i in range(len(visited)):
visited[i] = False

max_path_length = 0
max_path_node = 0
for idx, p in enumerate(path):
if max_path_length < p:
max_path_length = p
max_path_node = idx

dfs(max_path_node, 0)

print(max(path))