FastAPI를 이용한 머신러닝 모델 REST API 서빙

학습 목표 매핑

SKALA 3기 Module 7 — 데이터 분석 Mini-project (Learning Objective 7-2)

  • Objective: 훈련된 모델을 FastAPI로 감싸서 REST API 엔드포인트로 제공하고 / 모듈 4 FastAPI 기술 사용 / Swagger 자동 문서가 생성되고 curl/Postman으로 테스트 가능 (Bloom L6-Create)
  • Evaluation: API 동작 검증

FastAPI란?

정의: Python 기반 고성능 웹 프레임워크

  • 타입 힌트 기반 자동 검증
  • 자동 OpenAPI(Swagger) 문서 생성
  • 동기/비동기 동시 지원
  • 최소 보일러플레이트

업계 채택:

  • 2021년 21% → 2023년 29%로 성장
  • 데이터 사이언티스트 선호도 31%
  • Microsoft, Uber, Netflix 사용

ML 모델 서빙 아키텍처

[Trained Model (model.pkl)]
           ↓
[FastAPI Application]
           ↓
[REST API Endpoints]
           ↓
[Client (curl, Postman, Web App)]

단계 1: 모델 훈련 및 직렬화

# train_model.py
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
import joblib
 
# 데이터 로드 및 모델 훈련
iris = load_iris()
X = iris.data
y = iris.target
 
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X, y)
 
# 메타데이터와 함께 저장
joblib.dump({
    "model": clf,
    "target_names": iris.target_names,
    "feature_names": iris.feature_names,
}, "models/iris_model.joblib")
 
print("✅ 모델 저장 완료!")

단계 2: Pydantic으로 입력 스키마 정의

# schemas.py
from pydantic import BaseModel, Field
from typing import List
 
class IrisFeatures(BaseModel):
    """붓꽃 분류 입력 특성"""
    sepal_length: float = Field(..., gt=0, description="꽃받침 길이 (cm)")
    sepal_width: float = Field(..., gt=0, description="꽃받침 너비 (cm)")
    petal_length: float = Field(..., gt=0, description="꽃잎 길이 (cm)")
    petal_width: float = Field(..., gt=0, description="꽃잎 너비 (cm)")
 
class PredictionResponse(BaseModel):
    """모델 예측 응답"""
    prediction: str
    probability: List[float]
    confidence: float

단계 3: FastAPI 애플리케이션 생성

기본 구조

# main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import joblib
import numpy as np
from schemas import IrisFeatures, PredictionResponse
 
app = FastAPI(
    title="붓꽃 분류 API",
    description="붓꽃 종 분류 REST API",
    version="1.0.0"
)
 
# CORS 미들웨어 (웹 프론트엔드 통합)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
 
# 글로벌 모델 저장소
model_artifact = None
 
@app.on_event("startup")
def load_model():
    """애플리케이션 시작 시 모델 로드"""
    global model_artifact
    model_artifact = joblib.load("models/iris_model.joblib")
    print("✅ 모델 로드 완료!")
 
@app.get("/health")
def health_check():
    """헬스 체크 엔드포인트"""
    return {
        "status": "healthy",
        "model_loaded": model_artifact is not None
    }
 
@app.post("/predict")
def predict(features: IrisFeatures) -> dict:
    """
    붓꽃 종 예측
    
    Args:
        features: IrisFeatures 모델 인스턴스
    
    Returns:
        - prediction: 예측 종명 (setosa, versicolor, virginica)
        - probability: 각 종의 확률
        - confidence: 최고 확률
    """
    if model_artifact is None:
        raise HTTPException(
            status_code=503, 
            detail="모델이 로드되지 않았습니다"
        )
    
    # 특성 배열 구성 (올바른 순서 중요!)
    X = np.array([[
        features.sepal_length,
        features.sepal_width,
        features.petal_length,
        features.petal_width
    ]])
    
    # 예측 수행
    model = model_artifact["model"]
    prediction_idx = model.predict(X)[0]
    probabilities = model.predict_proba(X)[0]
    
    return {
        "prediction": model_artifact["target_names"][prediction_idx],
        "probabilities": {
            model_artifact["target_names"][i]: float(p) 
            for i, p in enumerate(probabilities)
        },
        "confidence": float(probabilities[prediction_idx])
    }
 
@app.post("/predict_batch")
def predict_batch(features_list: list[IrisFeatures]) -> dict:
    """배치 예측 (여러 샘플 동시 처리)"""
    X = np.array([[
        f.sepal_length, f.sepal_width, 
        f.petal_length, f.petal_width
    ] for f in features_list])
    
    model = model_artifact["model"]
    predictions = model.predict(X)
    
    return {
        "predictions": [
            model_artifact["target_names"][idx] 
            for idx in predictions
        ]
    }
 
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

단계 4: 로컬 테스트

서버 실행

# 방법 1: 직접 실행
python main.py
 
# 방법 2: Uvicorn 사용
uvicorn main:app --reload
 
# 결과
# Uvicorn running on http://127.0.0.1:8000

엔드포인트 테스트

cURL로 테스트

# 단일 예측
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'
 
# 응답
{
  "prediction": "setosa",
  "probabilities": {
    "setosa": 0.98,
    "versicolor": 0.02,
    "virginica": 0.0
  },
  "confidence": 0.98
}

Swagger UI에서 테스트

Browser: http://localhost:8000/docs

Features:

  • 🎯 Interactive endpoint testing
  • 📋 Parameter descriptions
  • 📊 Request/response schemas
  • ✅ Real-time validation

Python 클라이언트

# client.py
import requests
 
url = "http://localhost:8000/predict"
 
data = {
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2
}
 
response = requests.post(url, json=data)
print(response.json())

프로덕션 배포 고려사항

1. 서버 구성 (Gunicorn + Uvicorn)

개발 환경:

uvicorn main:app --reload

프로덕션 환경:

# Gunicorn 설치
pip install gunicorn uvicorn
 
# 4개 워커 프로세스 실행
gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app
옵션설명
-w 44개의 워커 프로세스 (CPU 코어 수에 맞추기)
-kUvicorn 워커 클래스 지정
--bind 0.0.0.0:8000포트 지정

2. 환경 변수 관리

from dotenv import load_dotenv
import os
 
load_dotenv()
 
MODEL_PATH = os.getenv("MODEL_PATH", "models/iris_model.joblib")
ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

.env 파일:

MODEL_PATH=models/iris_model.joblib
ENVIRONMENT=production
LOG_LEVEL=INFO
API_KEY=your-secret-key

3. 에러 처리

from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
 
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return JSONResponse(
        status_code=400,
        content={"detail": "Invalid input", "errors": str(exc)},
    )
 
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error"},
    )

4. 로깅

import logging
 
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
 
@app.middleware("http")
async def log_requests(request, call_next):
    logger.info(f"Request: {request.method} {request.url}")
    response = await call_next(request)
    logger.info(f"Response: {response.status_code}")
    return response

5. 모델 버전 관리

from enum import Enum
 
class ModelVersion(str, Enum):
    v1 = "iris_model_v1.joblib"
    v2 = "iris_model_v2.joblib"
 
@app.get("/model/version")
def get_model_version():
    return {"version": "1.0", "model_file": "iris_model_v1"}
 
@app.post("/predict/{model_version}")
def predict_with_version(
    model_version: ModelVersion, 
    features: IrisFeatures
):
    artifact = joblib.load(f"models/{model_version.value}")
    # ... 예측 로직

API 자동 문서화

Swagger UI (/docs)

http://localhost:8000/docs

기능:

  • Interactive parameter testing
  • Schema visualization
  • Try it out 기능

ReDoc API Documentation (/redoc)

http://localhost:8000/redoc

기능:

  • 깨끗한 API 문서 스타일
  • 복잡한 스키마에 최적화

성능 메트릭

메트릭
Throughput초당 1000+ 요청
Latency~10-50ms (모델 복잡도에 따라)
ScalabilityGunicorn 워커로 수평 확장

배포 플랫폼

플랫폼방식난이도
Herokugit push로 자동 배포⭐⭐
AWS EC2Systemd service로 운영⭐⭐⭐
Google Cloud RunServerless 컨테이너⭐⭐
Docker컨테이너 이미지 (Module 7-3)⭐⭐⭐

Module 7 실전 체크리스트

  • 모델 훈련 및 joblib로 직렬화
  • Pydantic BaseModel로 입력 스키마 정의
  • FastAPI 애플리케이션 생성
  • @app.on_event("startup")에서 모델 로드
  • /predict 엔드포인트 구현
  • /health 헬스 체크 엔드포인트 추가
  • CORS 미들웨어 설정
  • Swagger UI (/docs)에서 테스트
  • cURL 또는 Postman으로 테스트
  • 에러 처리 및 로깅 추가
  • requirements.txt에 fastapi, uvicorn 추가

Best Practices 요약

  1. ✅ Pydantic으로 입력 검증
  2. @app.on_event("startup")에서 모델 로드
  3. /health 헬스 체크 엔드포인트
  4. ✅ 상세한 docstring 작성
  5. ✅ 프로덕션은 Gunicorn 사용
  6. ✅ CORS 미들웨어 추가 (웹 통합)
  7. ✅ 적절한 에러 처리
  8. ✅ 모든 요청·응답 로깅
  9. ✅ API 버전 관리
  10. ✅ 자동 /docs 문서 활용

참고 자료

타 소스와의 연계

end-to-end-data-science-project (전체 프로젝트 구조 - 7-1) docker-ml-containerization (Docker 배포 - 7-3) github-documentation-standards (GitHub 문서화 - 7-4) python-asyncio-daleseo (비동기 프로그래밍) pydantic-validation-velog (Pydantic 데이터 검증)