해당 문제를 접근하기 전에, 누적합에 대해 먼저 서술해보자.
첫번째로, 문제의 유형이 "구간의 합" 혹은 "구간의 평균/최댓값" 등이 반복된다면 누적합을 검토한다.
위 문제를 예시로 보면, 정사각형의 배열에서 구간의 합을 구하라는 문제이므로, 누적합을 생각할 수 있다.
두번째로, 문제에서 주어진 구간의 합을 구하기 전에 전체 배열을 한 번 읽을 수 있는가를 확인해야 한다.
누적합을 적용하려면, 전체 배열의 합을 구할 수 있어야 유의미하기 때문에 전체 배열을 읽을 수 있는 지 확인해야 한다.
위 문제는 2차원 배열이지만, 우선 1차원 배열에서 접근하는 방식을 알고 있어야 한다.
다음과 같은 문제가 있다고 하자.
"길이가 N 인 배열 A 가 주어지고, M 개의 입력값을 받는다. 각 입력값 마다 구간 [L, R]의 합을 구해 출력한다."
이 문제를 직관적으로 풀어보면, 각 입력값마다 L ~ R 까지를 계산해서 출력하는 방법이 있다.
N, M = map(int, input().split())
A = list(map(int, input().split()))
for _ in range(M):
L, R = map(int, input().split())
cnt = 0
for i in range(L, R + 1):
cnt += A[i]
print(cnt)
이런 식으로 풀 수 있는데, 이렇게 되면 N 과 M 이 커질수록 반복 계산이 늘어나 시간이 오래 걸리는 문제가 생긴다.
그러면 이 반복 계산을 없애는 방법에 대해 고민해야 하는데, 이 때 누적합을 사용하면된다.
누적합의 핵심은 "미리 한번 계산한다" 이다.
배열의 첫 i 개 원소 합을 저장하는 보조 배열 P 를 하나 만든다면, 아래와 같고,
P[i] = A[1] + A[2] + ... + A[i]
이렇게 됬을 때, 구간 [L, R]의 합은
A[L] + A[L + 1] + ... + A[R] = (A[1] + ... + A[R]) - (A[1] + ... + A[L - 1]) = P[R] - P[L - 1]
이렇게 계산할 수 있다. 그러면 우리는 한 번 P 를 채우는 데에만 시간을 쓰고, 각 입력값에 대한 계산은 한번에 처리 가능한 구조가 된다.
이제 P 배열을 만드는 방법에 대해 생각해보자, 간단하게 앞에서부터 각각의 합을 구하는 점화식을 생각해보면, 아래와 같다.
P[i] = P[i - 1] + A[i]
작은 예제로 간단하게 계산해 보겠다.
- 배열 A = [3, 1, 4, 1, 5] (실제 코드에선 A[0] = 3 이지만, 이 문제에서는 예시이므로 A[1] = 3 이라 칭하겠다.)
- 누적합 P 계산
1. P[0] = 0
2. P[1] = P[0] + A[1] = 0 + 3 = 3
3. P[2] = P[1] + A[2] = 3 + 1 = 4
4. P[3] = P[2] + A[3] = 4 + 4 = 8
5. P[4] = P[3] + A[4] = 8 + 1 = 9
4. P[3] = P[4] + A[5] = 9 + 5 = 14
- 완성된 P = [0, 3, 4, 8, 9, 14]
- 입력값 예시
1. 구간 [2, 4] 합: A[2] + A[3] + A[4] = 1 + 4 + 1 = 6
-> P[4] - P[1] = 9 - 3 = 6
2. 구간 [1, 5] 합: P[5] - P[0] = 14 - 0 = 14
이제 이 1차원 누적합을 2차원으로 확장시켜 위 문제에 적용시켜보자.
우선 2차원 배열을 아래와 같이 표로 표현한다.
1 | 2 | 3 | 4 | |
1 | 1 | 2 | 3 | 4 |
2 | 2 | 3 | 4 | 5 |
3 | 3 | 4 | 5 | 6 |
4 | 4 | 5 | 6 | 7 |
그 후, 원본 배열과 같은 크기의 누적합 배열을 만든다. 이 때 누적합 배열에는 0번째 행과 0번째 열을 추가하여 모두 0으로 채운다.
0 | 1 | 2 | 3 | 4 | |
0 | 0 | 0 | 0 | 0 | 0 |
1 | 0 | ||||
2 | 0 | ||||
3 | 0 | ||||
4 | 0 |
이제 이 누적합 배열을 합쳐야 하는데, 그 원리는 다음과 같다.
(1, 1) ~ (i, j) 까지의 합은 다음 4 가지를 이용하여 계산할 수 있다.
A. (1, 1) ~ (i - 1, j) 까지 합
예를 들어, (1, 1) ~ (3, 3) 까지를 더한다면, (3, 3)의 위쪽까지의 합을 의미한다.
B. (1, 1) ~ (i, j - 1) 까지 합
마찬가지로, (1, 1) ~ (3, 3) 까지 더할 때, (3, 3)의 왼쪽까지의 합을 의미한다.
C. (1, 1) ~ (i - 1, j - 1) 중복된 영역의 합
A와 B 계산에서 중복적으로 계산된 합을 의미한다.
D. 현재 위치(i, j)의 원본 배열 값
예시에서의 (3, 3)을 의미한다.
(이 원리는 표를 함께 보면 더 이해하기 쉬울 것이다.)
최종적으로는 아래와 같은 식이 완성된다.
S[i][j] = A + B - C + D
= S[i - 1][j] + S[i][j - 1] - S[i - 1][j - 1] + A[i][j]
자, 이제 이 식을 근거로 위 누적합 배열을 원본 배열을 이용하여 채워보면, 아래 표와 같이 된다.
0 | 1 | 2 | 3 | 4 | |
0 | 0 | 0 | 0 | 0 | 0 |
1 | 0 | 1 | 3 | 6 | 10 |
2 | 0 | 3 | 8 | 15 | 24 |
3 | 0 | 6 | 15 | 27 | 42 |
4 | 0 | 10 | 24 | 42 | 64 |
예) S[2][3] = 15 는 (1, 1) ~ (2, 3) 영역의 합이다.
S[1][3] + S[2][2] - S[1][2] + A[2][3] = 6 + 8 - 3 + 4 = 15
이제 이를 이용해서 우리가 원하는 구간의 합을 구해보자. 이 원리를 이해했다면, 구간의 합을 구하는 공식도 쉽게 이해할 수 있다.
S[x2][y2] - S[x1 -1][y2] - S[x2][y1 - 1] + S[x1 - 1][y1 - 1] = 전체 - 위쪽 - 왼쪽 + 중복 제거된 것
예시로 (2, 2) ~ (3, 4) 까지의 합을 구한다고 하면, 42 - 10 - 6 + 1 = 27 이 나오게 된다.
(이 또한 표를 보고 왜 위쪽과 왼쪽을 빼줘야 하는 지 보면 이해할 수 있다.)
결론 코드
import sys
input = sys.stdin.readline
N, M = map(int, input().split())
A = [list(map(int, input().split())) for _ in range(N)]
S = [[0] * (N + 1) for _ in range(N + 1)]
for i in range(1, N + 1):
for j in range(1, N + 1):
S[i][j] = S[i - 1][j] + S[i][j - 1] - S[i - 1][j - 1] + A[i - 1][j - 1]
out = []
for _ in range(M):
x1, y1, x2, y2 = map(int, input().split())
ans = S[x2][y2] - S[x1 - 1][y2] - S[x2][y1 - 1] + S[x1 - 1][y1 - 1]
out.append(str(ans))
sys.stdout.write("\n".join(out))
'BaekJoon' 카테고리의 다른 글
[Python] 17103. 골드바흐 파티션 (0) | 2025.04.15 |
---|---|
[Python] 1929.소수 구하기 (0) | 2025.04.12 |
[Python] 4134. 다음 소수 (0) | 2025.04.12 |
[Python] 2485. 가로수 (0) | 2025.04.10 |
[Python] 1735. 분수 합 (0) | 2025.04.09 |