Skip to content

Commit

Permalink
整理: verify_mutability() で mutability チェックを置き換え (#1363)
Browse files Browse the repository at this point in the history
* refactor: `verify_mutability()` で mutability チェックを置き換え

* refactor: `verify_mutability` → `verify_mutability_allowed` へリネーム

* Update voicevox_engine/app/dependencies.py

* 更新

---------

Co-authored-by: Hiroshiba <[email protected]>
Co-authored-by: Hiroshiba Kazuyuki <[email protected]>
  • Loading branch information
3 people authored Jun 18, 2024
1 parent 95f218a commit c15dd34
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 42 deletions.
23 changes: 15 additions & 8 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import FastAPI

from voicevox_engine import __version__
from voicevox_engine.app.dependencies import deprecated_mutable_api
from voicevox_engine.app.dependencies import generate_mutability_allowed_verifier
from voicevox_engine.app.global_exceptions import configure_global_exception_handlers
from voicevox_engine.app.middlewares import configure_middlewares
from voicevox_engine.app.openapi_schema import configure_openapi_schema
Expand Down Expand Up @@ -49,6 +49,10 @@ def generate_app(
if speaker_info_dir is None:
speaker_info_dir = engine_root() / "resources" / "character_info"

verify_mutability_allowed = generate_mutability_allowed_verifier(
disable_mutable_api
)

app = FastAPI(
title=engine_manifest.name,
description=f"{engine_manifest.brand_name} の音声合成エンジンです。",
Expand All @@ -58,9 +62,6 @@ def generate_app(
app = configure_middlewares(app, cors_policy_mode, allow_origin)
app = configure_global_exception_handlers(app)

if disable_mutable_api:
deprecated_mutable_api.enable = False

metas_store = MetasStore(speaker_info_dir)

app.include_router(
Expand All @@ -69,16 +70,22 @@ def generate_app(
)
)
app.include_router(generate_morphing_router(tts_engines, core_manager, metas_store))
app.include_router(generate_preset_router(preset_manager))
app.include_router(
generate_preset_router(preset_manager, verify_mutability_allowed)
)
app.include_router(
generate_speaker_router(core_manager, metas_store, speaker_info_dir)
)
if engine_manifest.supported_features.manage_library:
app.include_router(generate_library_router(library_manager))
app.include_router(generate_user_dict_router(user_dict))
app.include_router(
generate_library_router(library_manager, verify_mutability_allowed)
)
app.include_router(generate_user_dict_router(user_dict, verify_mutability_allowed))
app.include_router(generate_engine_info_router(core_manager, engine_manifest))
app.include_router(
generate_setting_router(setting_loader, engine_manifest.brand_name)
generate_setting_router(
setting_loader, engine_manifest.brand_name, verify_mutability_allowed
)
)
app.include_router(generate_portal_page_router(engine_manifest.name))

Expand Down
27 changes: 13 additions & 14 deletions voicevox_engine/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""FastAPI dependencies"""

from dataclasses import dataclass
from typing import Any, Callable, Coroutine, TypeAlias

from fastapi import HTTPException

VerifyMutabilityAllowed: TypeAlias = Callable[[], Coroutine[Any, Any, None]]

# 許可されていないAPIを無効化する
@dataclass
class MutableAPI:
enable: bool = True

def generate_mutability_allowed_verifier(
disable_mutable_api: bool,
) -> VerifyMutabilityAllowed:
"""verify_mutability_allowed 関数(データ変更の許可を確認する関数)を生成する。"""

# FIXME: グローバル変数が複数ファイルに分散しているため、DI 等で局所化する
deprecated_mutable_api = MutableAPI()
async def verify_mutability_allowed() -> None:
if disable_mutable_api:
msg = "エンジンの静的なデータを変更するAPIは無効化されています"
raise HTTPException(status_code=403, detail=msg)
else:
pass


async def check_disabled_mutable_api() -> None:
if not deprecated_mutable_api.enable:
raise HTTPException(
status_code=403,
detail="エンジンの静的なデータを変更するAPIは無効化されています",
)
return verify_mutability_allowed
10 changes: 6 additions & 4 deletions voicevox_engine/app/routers/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
)
from voicevox_engine.library.model import DownloadableLibraryInfo, InstalledLibraryInfo

from ..dependencies import check_disabled_mutable_api
from ..dependencies import VerifyMutabilityAllowed


def generate_library_router(library_manager: LibraryManager) -> APIRouter:
def generate_library_router(
library_manager: LibraryManager, verify_mutability: VerifyMutabilityAllowed
) -> APIRouter:
"""音声ライブラリ API Router を生成する"""
router = APIRouter(tags=["音声ライブラリ管理"])

Expand All @@ -46,7 +48,7 @@ def installed_libraries() -> dict[str, InstalledLibraryInfo]:
@router.post(
"/install_library/{library_uuid}",
status_code=204,
dependencies=[Depends(check_disabled_mutable_api)],
dependencies=[Depends(verify_mutability)],
)
async def install_library(
library_uuid: Annotated[str, Path(description="音声ライブラリのID")],
Expand Down Expand Up @@ -76,7 +78,7 @@ async def install_library(
@router.post(
"/uninstall_library/{library_uuid}",
status_code=204,
dependencies=[Depends(check_disabled_mutable_api)],
dependencies=[Depends(verify_mutability)],
)
def uninstall_library(
library_uuid: Annotated[str, Path(description="音声ライブラリのID")]
Expand Down
12 changes: 7 additions & 5 deletions voicevox_engine/app/routers/preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
PresetManager,
)

from ..dependencies import check_disabled_mutable_api
from ..dependencies import VerifyMutabilityAllowed


def generate_preset_router(preset_manager: PresetManager) -> APIRouter:
def generate_preset_router(
preset_manager: PresetManager, verify_mutability: VerifyMutabilityAllowed
) -> APIRouter:
"""プリセット API Router を生成する"""
router = APIRouter(tags=["その他"])

Expand All @@ -37,7 +39,7 @@ def get_presets() -> list[Preset]:
@router.post(
"/add_preset",
response_description="追加したプリセットのプリセットID",
dependencies=[Depends(check_disabled_mutable_api)],
dependencies=[Depends(verify_mutability)],
)
def add_preset(
preset: Annotated[
Expand All @@ -61,7 +63,7 @@ def add_preset(
@router.post(
"/update_preset",
response_description="更新したプリセットのプリセットID",
dependencies=[Depends(check_disabled_mutable_api)],
dependencies=[Depends(verify_mutability)],
)
def update_preset(
preset: Annotated[
Expand All @@ -85,7 +87,7 @@ def update_preset(
@router.post(
"/delete_preset",
status_code=204,
dependencies=[Depends(check_disabled_mutable_api)],
dependencies=[Depends(verify_mutability)],
)
def delete_preset(
id: Annotated[int, Query(description="削除するプリセットのプリセットID")]
Expand Down
10 changes: 5 additions & 5 deletions voicevox_engine/app/routers/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from voicevox_engine.setting.setting_manager import Setting, SettingHandler
from voicevox_engine.utility.path_utility import resource_root

from ..dependencies import check_disabled_mutable_api
from ..dependencies import VerifyMutabilityAllowed

_setting_ui_template = Jinja2Templates(
env=Environment(
Expand All @@ -24,7 +24,9 @@


def generate_setting_router(
setting_loader: SettingHandler, brand_name: BrandName
setting_loader: SettingHandler,
brand_name: BrandName,
verify_mutability: VerifyMutabilityAllowed,
) -> APIRouter:
"""設定 API Router を生成する"""
router = APIRouter(tags=["設定"])
Expand Down Expand Up @@ -52,9 +54,7 @@ def setting_get(request: Request) -> Response:
},
)

@router.post(
"/setting", status_code=204, dependencies=[Depends(check_disabled_mutable_api)]
)
@router.post("/setting", status_code=204, dependencies=[Depends(verify_mutability)])
def setting_post(
cors_policy_mode: Annotated[CorsPolicyMode, Form()],
allow_origin: Annotated[str | SkipJsonSchema[None], Form()] = None,
Expand Down
14 changes: 8 additions & 6 deletions voicevox_engine/app/routers/user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
WordProperty,
)

from ..dependencies import check_disabled_mutable_api
from ..dependencies import VerifyMutabilityAllowed


def generate_user_dict_router(user_dict: UserDictionary) -> APIRouter:
def generate_user_dict_router(
user_dict: UserDictionary, verify_mutability: VerifyMutabilityAllowed
) -> APIRouter:
"""ユーザー辞書 API Router を生成する"""
router = APIRouter(tags=["ユーザー辞書"])

Expand All @@ -40,7 +42,7 @@ def get_user_dict_words() -> dict[str, UserDictWord]:
status_code=500, detail="辞書の読み込みに失敗しました。"
)

@router.post("/user_dict_word", dependencies=[Depends(check_disabled_mutable_api)])
@router.post("/user_dict_word", dependencies=[Depends(verify_mutability)])
def add_user_dict_word(
surface: Annotated[str, Query(description="言葉の表層形")],
pronunciation: Annotated[str, Query(description="言葉の発音(カタカナ)")],
Expand Down Expand Up @@ -92,7 +94,7 @@ def add_user_dict_word(
@router.put(
"/user_dict_word/{word_uuid}",
status_code=204,
dependencies=[Depends(check_disabled_mutable_api)],
dependencies=[Depends(verify_mutability)],
)
def rewrite_user_dict_word(
surface: Annotated[str, Query(description="言葉の表層形")],
Expand Down Expand Up @@ -146,7 +148,7 @@ def rewrite_user_dict_word(
@router.delete(
"/user_dict_word/{word_uuid}",
status_code=204,
dependencies=[Depends(check_disabled_mutable_api)],
dependencies=[Depends(verify_mutability)],
)
def delete_user_dict_word(
word_uuid: Annotated[str, Path(description="削除する言葉のUUID")]
Expand All @@ -166,7 +168,7 @@ def delete_user_dict_word(
@router.post(
"/import_user_dict",
status_code=204,
dependencies=[Depends(check_disabled_mutable_api)],
dependencies=[Depends(verify_mutability)],
)
def import_user_dict_words(
import_dict_data: Annotated[
Expand Down

0 comments on commit c15dd34

Please sign in to comment.