본문 바로가기
Python/python 응용: 수학 및 과학

Matplotlib을 이용해서 다중 그래프 그리기: subplot

by 철이88 2023. 4. 27.
반응형

파이썬의 라이브러리 모듈인 Matplotlib에는 여러 쌍의 데이터 셋을 한 번에 그릴 수 있는 subplot 기능이 있습니다. 이 기능을 이용하면 여러 데이터를 하나의 그림 안에 시각화하여 비교할 수 있어 특히 공대생들에게 유용합니다. 이번 글에서는 subplot의 사용법을 알아보겠습니다.

 

1. Matplotlib의 subplot 함수

Matplotlib은 파이썬에서 그래프를 그리기 위한 라이브러리입니다. 여기에는 여러 종류의 그래프를 그리기 위한 다양한 함수와 클래스가 포함되어 있습니다. Matplotlib은 NumPy와 함께 사용되며, 데이터를 그래프로 시각화하는데 매우 유용합니다. 

 

그리고 오늘의 주제인 subplot()은 Matplotlib에서 제공하는 함수 중 하나입니다. subplot() 함수는 하나의 그림에 여러 개의 하위 그래프를 만들 수 있도록 합니다.

subplot() 함수는 그림을 그리기 전에 먼저 호출되고, 여러 개의 그래프를 한 번에 그릴 수 있도록 AxesSubplot라고 하는 객체를 반환합니다. AxesSubplot 객체는 그림 내의 하위 요소(subplot)를 나타내는 객체입니다.

 

조금은 어렵겠지만, 다음 예문을 보고 이해해 보시기 바랍니다.

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 4))

위 문장은 그림을 생성하는 코드입니다. 

좌변을 보면 fig와 axes 두 객체가 있는 것을 알 수 있습니다.

 

여기서 fig는 Figure 클래스의 인스턴스로, 새로운 그래프 창을 생성하는 데 사용됩니다. 그리고 axes는 Axes 클래스의 인스턴스로, 그림 내에서 각각의 subplot을 나타내는 데 사용됩니다. axes 객체는 그래프를 그리는 기능을 담당하며, 다양한 메서드를 제공하여 그래프를 그리는 작업을 지원합니다.

 

그리고 오른쪽에는 plt.subplots() 함수가 있는데, 보통 다음과 같이 Matplotlib의 pyplot을 가져와서 사용하겠다고 선언하여 쓰는 것입니다.

import matplotlib.pyplot as plt

 

또한 subplot() 함수의 첫 번째 인자는 하위 그림의 행(row) 개수, 두 번째 인자는 열(column) 개수, 세 번째는 생성할 하위 그림의 인덱스를 의미합니다. 예를 들어, subplot(2, 3, 1)은 그림 내에 2개의 행과 3개의 열을 가진 subplot 중 첫 번째(subplot의 인덱스는 1부터 시작) subplot을 선택하게 됩니다.

 

2. subplot의 예제

다음은 subplot을 사용한 간단한 예시 코드입니다.

import numpy as np
import matplotlib.pyplot as plt

# 데이터 생성
x = np.linspace(0, 10, 100)
y1 = x
y2 = x**2
y3 = x**3

# 그림 생성: 1개의 행과 3개의 열로 구성
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 4))

# 첫 번째 subplot 그리기
axes[0].plot(x, y1)
axes[0].set_title('$y = x$') #제목

# 두 번째 subplot 그리기
axes[1].plot(x, y2)
axes[1].set_title('$y = x^{2}$') #제목

# 세 번째 subplot 그리기
axes[2].plot(x, y3)
axes[2].set_title('$y = x^{3}$') #제목

# subplot 간격 조절
fig.subplots_adjust(wspace=0.3)

# 전체 그림의 제목 설정
fig.suptitle('Subplots test')

# 그래프 출력
plt.show()

위 코드의 의미는 각 부분마다 설명하였습니다.

실행을 하여 얻은 그림은 아래와 같습니다.

 

3. 반복문 사용하여 다수의 그래프 한번에 그리기

한 번에 그리고자 하는 데이터 셋이 많을 때 반복문을 이용하면 간단합니다.

다음은 반복문을 사용해서 30개의 2차원 그래프를 한 번에 그리는 예제입니다.

import numpy as np
import matplotlib.pyplot as plt

# 5행 6열 형대로 30개의 그래프를 그림
fig, axes = plt.subplots(nrows=5, ncols=6, figsize=(10,10))

for i in range(30):
    input_data = np.random.randn(20, 2) #임의로 20개의 x,y 데이터 생성
    
    # 각각의 하위 그래프
    ax = axes.flat[i]
    ax.scatter(input_data[:,0], input_data[:,1])
    ax.set_title("data #" + str(i+1))

plt.tight_layout()  # 그래프간 간격 넣기
plt.show()

이 예제에서는 30개의 데이터 세트를 랜덤하게 만들었습니다.

 

그리고 주의할 점은 axes는 행과 열의 2차원 배열의 객체라는 것입니다.

위 예에서 for문의 i는 0부터 29까지의 숫자를 갖으며 하위 그림의 번호를 의미하도록 설계되어 있습니다 (i에 1을 더하면).

즉, i는 1차원이기때문에 2차원 객체인 axes의 인덱스에 바로 사용할 수 없고, 이런 경우는 위의 예처럼 axes.flat을 이용하면 됩니다. 위 코드로 그림을 그려보면 이해가 쉬울 것입니다.

반응형

댓글