FastAPI를 이용한 머신러닝 모델 REST API 서빙
학습 목표 매핑
SKALA 3기 Module 7 — 데이터 분석 Mini-project (Learning Objective 7-2)
- Objective: 훈련된 모델을 FastAPI로 감싸서 REST API 엔드포인트로 제공하고 / 모듈 4 FastAPI 기술 사용 / Swagger 자동 문서가 생성되고 curl/Postman으로 테스트 가능 (Bloom L6-Create)
- Evaluation: API 동작 검증
FastAPI란?
정의: Python 기반 고성능 웹 프레임워크
- 타입 힌트 기반 자동 검증
- 자동 OpenAPI(Swagger) 문서 생성
- 동기/비동기 동시 지원
- 최소 보일러플레이트
업계 채택:
- 2021년 21% → 2023년 29%로 성장
- 데이터 사이언티스트 선호도 31%
- Microsoft, Uber, Netflix 사용
ML 모델 서빙 아키텍처
[Trained Model (model.pkl)]
↓
[FastAPI Application]
↓
[REST API Endpoints]
↓
[Client (curl, Postman, Web App)]
단계 1: 모델 훈련 및 직렬화
# train_model.py
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
import joblib
# 데이터 로드 및 모델 훈련
iris = load_iris()
X = iris.data
y = iris.target
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X, y)
# 메타데이터와 함께 저장
joblib.dump({
"model": clf,
"target_names": iris.target_names,
"feature_names": iris.feature_names,
}, "models/iris_model.joblib")
print("✅ 모델 저장 완료!")단계 2: Pydantic으로 입력 스키마 정의
# schemas.py
from pydantic import BaseModel, Field
from typing import List
class IrisFeatures(BaseModel):
"""붓꽃 분류 입력 특성"""
sepal_length: float = Field(..., gt=0, description="꽃받침 길이 (cm)")
sepal_width: float = Field(..., gt=0, description="꽃받침 너비 (cm)")
petal_length: float = Field(..., gt=0, description="꽃잎 길이 (cm)")
petal_width: float = Field(..., gt=0, description="꽃잎 너비 (cm)")
class PredictionResponse(BaseModel):
"""모델 예측 응답"""
prediction: str
probability: List[float]
confidence: float단계 3: FastAPI 애플리케이션 생성
기본 구조
# main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import joblib
import numpy as np
from schemas import IrisFeatures, PredictionResponse
app = FastAPI(
title="붓꽃 분류 API",
description="붓꽃 종 분류 REST API",
version="1.0.0"
)
# CORS 미들웨어 (웹 프론트엔드 통합)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 글로벌 모델 저장소
model_artifact = None
@app.on_event("startup")
def load_model():
"""애플리케이션 시작 시 모델 로드"""
global model_artifact
model_artifact = joblib.load("models/iris_model.joblib")
print("✅ 모델 로드 완료!")
@app.get("/health")
def health_check():
"""헬스 체크 엔드포인트"""
return {
"status": "healthy",
"model_loaded": model_artifact is not None
}
@app.post("/predict")
def predict(features: IrisFeatures) -> dict:
"""
붓꽃 종 예측
Args:
features: IrisFeatures 모델 인스턴스
Returns:
- prediction: 예측 종명 (setosa, versicolor, virginica)
- probability: 각 종의 확률
- confidence: 최고 확률
"""
if model_artifact is None:
raise HTTPException(
status_code=503,
detail="모델이 로드되지 않았습니다"
)
# 특성 배열 구성 (올바른 순서 중요!)
X = np.array([[
features.sepal_length,
features.sepal_width,
features.petal_length,
features.petal_width
]])
# 예측 수행
model = model_artifact["model"]
prediction_idx = model.predict(X)[0]
probabilities = model.predict_proba(X)[0]
return {
"prediction": model_artifact["target_names"][prediction_idx],
"probabilities": {
model_artifact["target_names"][i]: float(p)
for i, p in enumerate(probabilities)
},
"confidence": float(probabilities[prediction_idx])
}
@app.post("/predict_batch")
def predict_batch(features_list: list[IrisFeatures]) -> dict:
"""배치 예측 (여러 샘플 동시 처리)"""
X = np.array([[
f.sepal_length, f.sepal_width,
f.petal_length, f.petal_width
] for f in features_list])
model = model_artifact["model"]
predictions = model.predict(X)
return {
"predictions": [
model_artifact["target_names"][idx]
for idx in predictions
]
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)단계 4: 로컬 테스트
서버 실행
# 방법 1: 직접 실행
python main.py
# 방법 2: Uvicorn 사용
uvicorn main:app --reload
# 결과
# Uvicorn running on http://127.0.0.1:8000엔드포인트 테스트
cURL로 테스트
# 단일 예측
curl -X POST "http://localhost:8000/predict" \
-H "Content-Type: application/json" \
-d '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'
# 응답
{
"prediction": "setosa",
"probabilities": {
"setosa": 0.98,
"versicolor": 0.02,
"virginica": 0.0
},
"confidence": 0.98
}Swagger UI에서 테스트
Browser: http://localhost:8000/docs
Features:
- 🎯 Interactive endpoint testing
- 📋 Parameter descriptions
- 📊 Request/response schemas
- ✅ Real-time validation
Python 클라이언트
# client.py
import requests
url = "http://localhost:8000/predict"
data = {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
response = requests.post(url, json=data)
print(response.json())프로덕션 배포 고려사항
1. 서버 구성 (Gunicorn + Uvicorn)
개발 환경:
uvicorn main:app --reload프로덕션 환경:
# Gunicorn 설치
pip install gunicorn uvicorn
# 4개 워커 프로세스 실행
gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app| 옵션 | 설명 |
|---|---|
-w 4 | 4개의 워커 프로세스 (CPU 코어 수에 맞추기) |
-k | Uvicorn 워커 클래스 지정 |
--bind 0.0.0.0:8000 | 포트 지정 |
2. 환경 변수 관리
from dotenv import load_dotenv
import os
load_dotenv()
MODEL_PATH = os.getenv("MODEL_PATH", "models/iris_model.joblib")
ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").env 파일:
MODEL_PATH=models/iris_model.joblib
ENVIRONMENT=production
LOG_LEVEL=INFO
API_KEY=your-secret-key
3. 에러 처리
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return JSONResponse(
status_code=400,
content={"detail": "Invalid input", "errors": str(exc)},
)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
return JSONResponse(
status_code=500,
content={"detail": "Internal server error"},
)4. 로깅
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.middleware("http")
async def log_requests(request, call_next):
logger.info(f"Request: {request.method} {request.url}")
response = await call_next(request)
logger.info(f"Response: {response.status_code}")
return response5. 모델 버전 관리
from enum import Enum
class ModelVersion(str, Enum):
v1 = "iris_model_v1.joblib"
v2 = "iris_model_v2.joblib"
@app.get("/model/version")
def get_model_version():
return {"version": "1.0", "model_file": "iris_model_v1"}
@app.post("/predict/{model_version}")
def predict_with_version(
model_version: ModelVersion,
features: IrisFeatures
):
artifact = joblib.load(f"models/{model_version.value}")
# ... 예측 로직API 자동 문서화
Swagger UI (/docs)
http://localhost:8000/docs
기능:
- Interactive parameter testing
- Schema visualization
- Try it out 기능
ReDoc API Documentation (/redoc)
http://localhost:8000/redoc
기능:
- 깨끗한 API 문서 스타일
- 복잡한 스키마에 최적화
성능 메트릭
| 메트릭 | 값 |
|---|---|
| Throughput | 초당 1000+ 요청 |
| Latency | ~10-50ms (모델 복잡도에 따라) |
| Scalability | Gunicorn 워커로 수평 확장 |
배포 플랫폼
| 플랫폼 | 방식 | 난이도 |
|---|---|---|
| Heroku | git push로 자동 배포 | ⭐⭐ |
| AWS EC2 | Systemd service로 운영 | ⭐⭐⭐ |
| Google Cloud Run | Serverless 컨테이너 | ⭐⭐ |
| Docker | 컨테이너 이미지 (Module 7-3) | ⭐⭐⭐ |
Module 7 실전 체크리스트
- 모델 훈련 및 joblib로 직렬화
- Pydantic BaseModel로 입력 스키마 정의
- FastAPI 애플리케이션 생성
-
@app.on_event("startup")에서 모델 로드 -
/predict엔드포인트 구현 -
/health헬스 체크 엔드포인트 추가 - CORS 미들웨어 설정
- Swagger UI (
/docs)에서 테스트 - cURL 또는 Postman으로 테스트
- 에러 처리 및 로깅 추가
- requirements.txt에 fastapi, uvicorn 추가
Best Practices 요약
- ✅ Pydantic으로 입력 검증
- ✅
@app.on_event("startup")에서 모델 로드 - ✅
/health헬스 체크 엔드포인트 - ✅ 상세한 docstring 작성
- ✅ 프로덕션은 Gunicorn 사용
- ✅ CORS 미들웨어 추가 (웹 통합)
- ✅ 적절한 에러 처리
- ✅ 모든 요청·응답 로깅
- ✅ API 버전 관리
- ✅ 자동
/docs문서 활용
참고 자료
- FastAPI 공식: https://fastapi.tiangolo.com/
- Pydantic: https://docs.pydantic.dev/
- Gunicorn: https://gunicorn.org/
- Uvicorn: https://www.uvicorn.org/
타 소스와의 연계
end-to-end-data-science-project (전체 프로젝트 구조 - 7-1) docker-ml-containerization (Docker 배포 - 7-3) github-documentation-standards (GitHub 문서화 - 7-4) python-asyncio-daleseo (비동기 프로그래밍) pydantic-validation-velog (Pydantic 데이터 검증)