diff --git a/main.py b/main.py index 01d5288..04f3e1f 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,8 @@ # Backend/main.py import os from dotenv import load_dotenv +# 환경 변수를 최대한 빨리 로드하여 GPU 설정(CUDA_VISIBLE_DEVICES)이 라우터 임포트 전에 적용되도록 함 +load_dotenv() from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles @@ -12,8 +14,12 @@ from routers.checklist import router as checklist_router from routers.file import router as file_router + +# 1) 환경변수 로드 (상단에서 선 로드됨) + import uvicorn + load_dotenv() app = FastAPI() diff --git a/models/base.py b/models/base.py index 59be703..bf02b73 100644 --- a/models/base.py +++ b/models/base.py @@ -1,3 +1,5 @@ +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base + Base = declarative_base() diff --git a/models/file.py b/models/file.py index bef1d99..55eb21b 100644 --- a/models/file.py +++ b/models/file.py @@ -1,3 +1,6 @@ + +from sqlalchemy.orm import relationship + from sqlalchemy import Column, Integer, String, ForeignKey, TIMESTAMP, text from sqlalchemy.orm import relationship from .base import Base @@ -6,6 +9,20 @@ class File(Base): __tablename__ = "file" id = Column(Integer, primary_key=True, autoincrement=True) + + user_id = Column(Integer, ForeignKey('user.u_id', ondelete='CASCADE'), nullable=False) + folder_id = Column(Integer, ForeignKey('folder.id', ondelete='SET NULL'), nullable=True) + note_id = Column(Integer, ForeignKey('note.id', ondelete='CASCADE'), nullable=True) + original_name = Column(String(255), nullable=False) + saved_path = Column(String(512), nullable=False) + content_type = Column(String(100), nullable=False) + created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) + + # ✅ 관계 + user = relationship("User", back_populates="files") + folder = relationship("Folder", back_populates="files") + note = relationship("Note", back_populates="files") + user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) note_id = Column(Integer, ForeignKey("note.id", ondelete="SET NULL"), nullable=True) @@ -17,3 +34,4 @@ class File(Base): # relations user = relationship("User", back_populates="files") note = relationship("Note", back_populates="files") + diff --git a/models/folder.py b/models/folder.py index 3105616..92d6bf0 100644 --- a/models/folder.py +++ b/models/folder.py @@ -11,6 +11,16 @@ class Folder(Base): parent_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) updated_at = Column(TIMESTAMP, nullable=False, + + server_default=text('CURRENT_TIMESTAMP'), + onupdate=text('CURRENT_TIMESTAMP')) + + # ✅ 관계 + user = relationship("User", back_populates="folders") + parent = relationship("Folder", remote_side=[id], backref="children") + notes = relationship("Note", back_populates="folder", cascade="all, delete") + files = relationship("File", back_populates="folder", cascade="all, delete") + server_default=text("CURRENT_TIMESTAMP"), onupdate=text("CURRENT_TIMESTAMP")) @@ -18,3 +28,4 @@ class Folder(Base): user = relationship("User") parent = relationship("Folder", remote_side=[id], backref="children") notes = relationship("Note", back_populates="folder", cascade="all, delete") + diff --git a/models/note.py b/models/note.py index f52ecca..680a666 100644 --- a/models/note.py +++ b/models/note.py @@ -14,10 +14,17 @@ class Note(Base): last_accessed = Column(TIMESTAMP, nullable=True) created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) updated_at = Column(TIMESTAMP, nullable=False, + + server_default=text('CURRENT_TIMESTAMP'), + onupdate=text('CURRENT_TIMESTAMP')) + + # ✅ 관계 + server_default=text("CURRENT_TIMESTAMP"), onupdate=text("CURRENT_TIMESTAMP")) # relations + user = relationship("User", back_populates="notes") folder = relationship("Folder", back_populates="notes") files = relationship("File", back_populates="note", cascade="all, delete") diff --git a/models/user.py b/models/user.py index 96f341f..1bf291f 100644 --- a/models/user.py +++ b/models/user.py @@ -9,6 +9,22 @@ class User(Base): id = Column(String(50), nullable=False, unique=True) # 로그인 ID 또는 소셜 ID email = Column(String(150), nullable=False, unique=True) password = Column(String(255), nullable=False) + + provider = Column( + Enum('local','google','kakao','naver', name='provider_enum'), + nullable=False, + server_default=text("'local'") + ) + created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) + updated_at = Column(TIMESTAMP, nullable=False, + server_default=text('CURRENT_TIMESTAMP'), + onupdate=text('CURRENT_TIMESTAMP')) + + # ✅ 관계 + folders = relationship("Folder", back_populates="user", cascade="all, delete") + notes = relationship("Note", back_populates="user", cascade="all, delete") + files = relationship("File", back_populates="user", cascade="all, delete") + provider = Column(Enum("local", "google", "kakao", "naver", name="provider_enum"), nullable=False, server_default=text("'local'")) created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) diff --git a/routers/file.py b/routers/file.py index 6f37d6c..2b9d7ef 100644 --- a/routers/file.py +++ b/routers/file.py @@ -12,6 +12,52 @@ from models.note import Note as NoteModel from utils.jwt_utils import get_current_user +cuda_gpu +# 추가: 파일명 인코딩용 +import urllib.parse + +# ------------------------------- +# 1) EasyOCR 라이브러리 임포트 (GPU 모드 활성화) +# ------------------------------- +import easyocr +reader = easyocr.Reader(["ko", "en"], gpu=True) + +# ------------------------------- +# 2) Hugging Face TrOCR 모델용 파이프라인 (GPU 사용) +# ------------------------------- +from transformers import pipeline + +hf_trocr_printed = pipeline( + "image-to-text", + model="microsoft/trocr-base-printed", + device=0, + trust_remote_code=True +) +hf_trocr_handwritten = pipeline( + "image-to-text", + model="microsoft/trocr-base-handwritten", + device=0, + trust_remote_code=True +) +hf_trocr_small_printed = pipeline( + "image-to-text", + model="microsoft/trocr-small-printed", + device=0, + trust_remote_code=True +) +hf_trocr_large_printed = pipeline( + "image-to-text", + model="microsoft/trocr-large-printed", + device=0, + trust_remote_code=True +) + +BASE_UPLOAD_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "uploads" +) + # 공통 OCR 파이프라인 from utils.ocr import run_pipeline, detect_type from schemas.file import OCRResponse @@ -79,11 +125,10 @@ async def upload_file( orig_filename: str = upload_file.filename or "unnamed" content_type: str = upload_file.content_type or "application/octet-stream" - # 사용자별 디렉토리 생성 user_dir = os.path.join(BASE_UPLOAD_DIR, str(current_user.u_id)) os.makedirs(user_dir, exist_ok=True) - # 원본 파일명 유지 (중복 방지) + saved_filename = orig_filename saved_path = os.path.join(user_dir, saved_filename) if os.path.exists(saved_path): @@ -98,7 +143,9 @@ async def upload_file( break counter += 1 + # 저장 + try: with open(saved_path, "wb") as buffer: content = await upload_file.read() @@ -106,6 +153,7 @@ async def upload_file( except Exception as e: raise HTTPException(status_code=500, detail=f"파일 저장 실패: {e}") + # note_id가 있으면 해당 노트 확인 note_obj = None if note_id is not None: @@ -118,6 +166,7 @@ async def upload_file( raise HTTPException(status_code=404, detail="해당 노트를 찾을 수 없습니다.") # DB 메타 기록 + new_file = FileModel( user_id=current_user.u_id, folder_id=None if note_id else folder_id, @@ -202,6 +251,10 @@ def download_file( if not os.path.exists(file_path): raise HTTPException(status_code=404, detail="서버에 파일이 존재하지 않습니다.") + # 원본 파일명 UTF-8 URL 인코딩 처리 + quoted_name = urllib.parse.quote(file_obj.original_name, safe='') + content_disposition = f"inline; filename*=UTF-8''{quoted_name}" + return FileResponse( path=file_path, media_type=file_obj.content_type, @@ -254,6 +307,67 @@ async def ocr_and_create_note( db: Session = Depends(get_db), current_user = Depends(get_current_user) ): + + """ + • ocr_file: 이미지 파일(UploadFile) + • 1) EasyOCR로 기본 텍스트 추출 (GPU 모드) + • 2) TrOCR 4개 모델로 OCR 수행 (모두 GPU) + • 3) 가장 긴 결과를 최종 OCR 결과로 선택 + • 4) Note로 저장 및 결과 반환 + """ + + # 1) 이미지 로드 (PIL) + contents = await ocr_file.read() + try: + image = Image.open(io.BytesIO(contents)).convert("RGB") + except Exception as e: + raise HTTPException(status_code=400, detail=f"이미지 처리 실패: {e}") + + # 2) EasyOCR로 텍스트 추출 + try: + image_np = np.array(image) + easy_results = reader.readtext(image_np) # GPU 모드 사용 + easy_text = " ".join([res[1] for res in easy_results]) + except Exception: + easy_text = "" + + # 3) TrOCR 모델 4개로 OCR 수행 (모두 GPU input) + hf_texts: List[str] = [] + try: + out1 = hf_trocr_printed(image) + if isinstance(out1, list) and "generated_text" in out1[0]: + hf_texts.append(out1[0]["generated_text"].strip()) + + out2 = hf_trocr_handwritten(image) + if isinstance(out2, list) and "generated_text" in out2[0]: + hf_texts.append(out2[0]["generated_text"].strip()) + + out3 = hf_trocr_small_printed(image) + if isinstance(out3, list) and "generated_text" in out3[0]: + hf_texts.append(out3[0]["generated_text"].strip()) + + out4 = hf_trocr_large_printed(image) + if isinstance(out4, list) and "generated_text" in out4[0]: + hf_texts.append(out4[0]["generated_text"].strip()) + except Exception: + # TrOCR 중 오류 발생 시 무시하고 계속 진행 + pass + + # 4) 여러 OCR 결과 병합: 가장 긴 문자열을 최종 ocr_text로 선택 + candidates = [t for t in [easy_text] + hf_texts if t and t.strip()] + if not candidates: + raise HTTPException(status_code=500, detail="텍스트를 인식할 수 없습니다.") + + ocr_text = max(candidates, key=lambda s: len(s)) + + # 5) 새 노트 생성 및 DB에 저장 + try: + new_note = NoteModel( + user_id=current_user.u_id, + folder_id=folder_id, + title="OCR 결과", + content=ocr_text # **원본 OCR 텍스트만 저장** + # 422 방지: 파일 필드명 유연 처리 upload = file or ocr_file if upload is None: @@ -277,6 +391,7 @@ async def ocr_and_create_note( warnings=[f"허용되지 않는 확장자({ext}). 허용: {sorted(ALLOWED_ALL_EXTS)}"], note_id=None, text=None, + ) # 타입 판별