Streamlit cache_resource 사용하기

Streamlit cache_resource 사용하기

Streamlit
Streamlit cache_resource 사용하기
Author

gabriel yang

Published

September 17, 2024


Streamlit은 데이터 사이언스와 머신러닝 애플리케이션을 빠르게 시각화할 수 있는 강력한 도구입니다. 특히 Streamlit은 자원을 효율적으로 관리하고, 동일한 계산을 반복하지 않도록 캐싱 기능을 제공합니다. 이 중 st.cache_resource는 주로 리소스를 효율적으로 재사용하는 데 사용됩니다. 이 글에서는 st.cache_resource의 개념과 사용법을 간단한 예시를 통해 설명하겠습니다.

st.cache_resource란?

st.cache_resource는 복잡한 계산 또는 시간이 오래 걸리는 리소스를 생성하고 이를 캐싱해, 동일한 리소스를 여러 번 생성하지 않고 재사용할 수 있도록 돕는 기능입니다. 예를 들어, 데이터베이스 연결이나 모델 로드와 같은 작업은 초기화에 시간이 걸리므로, 이를 캐싱해 동일한 작업이 반복적으로 수행되지 않도록 최적화할 수 있습니다.

이 기능은 다음과 같은 경우에 유용합니다:

  • 데이터베이스 연결: 매번 새로운 연결을 생성하는 대신 동일한 연결을 재사용.
  • 머신러닝 모델 로드: 모델을 여러 번 로드하지 않고 한 번 로드된 모델을 사용.
  • API 클라이언트 초기화: 같은 API 클라이언트를 계속 재사용.

st.cache_resource 기본 문법

@st.cache_resource
def expensive_initialization():
    # 리소스 초기화 작업
    return resource

이 데코레이터를 사용하면 함수가 최초로 호출될 때만 리소스를 초기화하고, 이후에는 캐시된 리소스를 사용하게 됩니다.

예시 1: 데이터베이스 연결 캐싱

먼저 간단한 예시로, 데이터베이스 연결을 설정해 보겠습니다. 보통 데이터베이스 연결은 한 번 설정하면 프로그램이 종료될 때까지 재사용하는 것이 좋습니다.

import streamlit as st
import sqlite3

@st.cache_resource
def create_connection():
    # 데이터베이스에 연결하는 작업
    conn = sqlite3.connect('my_database.db')
    return conn

def get_data():
    conn = create_connection()
    cur = conn.cursor()
    cur.execute("SELECT * FROM my_table")
    data = cur.fetchall()
    return data

st.write("데이터베이스에서 데이터를 불러옵니다...")
data = get_data()
st.write(data)

설명:

  • create_connection 함수는 SQLite 데이터베이스와 연결을 설정하는 함수입니다.
  • @st.cache_resource를 사용하여 이 연결을 캐싱합니다. 즉, 앱이 실행되는 동안 데이터베이스 연결을 한 번만 수행하고, 이후부터는 캐시된 연결을 재사용합니다.

이렇게 하면, 데이터베이스에 반복적으로 연결하지 않고 효율적으로 리소스를 사용할 수 있습니다.

예시 2: 머신러닝 모델 로드 캐싱

머신러닝 모델을 로드하는 것도 시간이 걸리는 작업입니다. st.cache_resource를 사용하면 모델을 한 번만 로드하고, 나머지 호출에서는 캐시된 모델을 사용합니다.

import streamlit as st
from sklearn.ensemble import RandomForestClassifier
import joblib

@st.cache_resource
def load_model():
    # 미리 학습된 모델을 로드하는 작업
    model = joblib.load('my_model.pkl')
    return model

def predict(input_data):
    model = load_model()
    return model.predict(input_data)

# 예시 입력 데이터
input_data = [[5.1, 3.5, 1.4, 0.2]]
prediction = predict(input_data)

st.write("예측 결과:", prediction)

설명:

  • load_model 함수는 사전 학습된 모델을 파일에서 불러오는 작업을 담당합니다.
  • @st.cache_resource를 사용하여 모델을 한 번만 로드하고 이후에는 캐시된 모델을 재사용하여 성능을 향상시킵니다.

예시 3: API 클라이언트 캐싱

여러 번 호출될 수 있는 API 클라이언트를 초기화할 때도 st.cache_resource가 유용합니다.

import streamlit as st
import requests

@st.cache_resource
def get_api_client():
    # API 클라이언트를 생성하는 작업
    session = requests.Session()
    return session

def fetch_data():
    client = get_api_client()
    response = client.get('https://api.example.com/data')
    return response.json()

st.write("API에서 데이터를 불러옵니다...")
api_data = fetch_data()
st.write(api_data)

설명:

  • get_api_client 함수는 requests.Session 객체를 생성하여 API 요청을 관리합니다.
  • @st.cache_resource를 사용하여 클라이언트를 한 번만 생성하고 재사용함으로써, API 요청을 효율적으로 처리합니다.

st.cache_resource vs st.cache_data

Streamlit에는 st.cache_data라는 또 다른 캐싱 데코레이터가 있습니다. 둘의 차이는 다음과 같습니다:

  • st.cache_data: 데이터를 캐싱하고, 주로 계산된 값(데이터)을 다시 계산하지 않도록 저장합니다. (ex 데이터프레임 변환 결과, 대용량 데이터 다운로드, 복잡한 계산 결과)
  • st.cache_resource: 리소스를 캐싱하고, 연결 객체나 머신러닝 모델과 같은 시스템 자원을 여러 번 생성하지 않도록 관리합니다. 주로 연결 객체, 머신러닝 모델, API 클라이언트와 같은 상태가 있는 리소스를 캐싱합니다. 이러한 리소스는 주로 초기화에 시간이 오래 걸리거나, 한 번 생성되면 재사용되는 경우가 많습니다.

따라서, 데이터 처리에는 st.cache_data를, 리소스(연결, 클라이언트, 모델) 관리에는 st.cache_resource를 사용하는 것이 좋습니다.

마무리

Streamlit의 st.cache_resource를 사용하면 데이터베이스 연결, 모델 로드, API 클라이언트와 같은 시스템 리소스를 효율적으로 관리할 수 있습니다. 이 기능을 통해 불필요한 리소스 재생성을 방지하고 애플리케이션의 성능을 크게 향상시킬 수 있습니다.