Skip to content

Commit

Permalink
Merge pull request #69 from LlmKira/dev
Browse files Browse the repository at this point in the history
Improvement: Added settings related to the `Variety Boost ` in `generate_image.Class.build` and optimized default settings.
  • Loading branch information
sudoskys authored Aug 31, 2024
2 parents 72b3734 + b4d46ca commit cae6c9d
Show file tree
Hide file tree
Showing 12 changed files with 332 additions and 187 deletions.
299 changes: 152 additions & 147 deletions pdm.lock

Large diffs are not rendered by default.

Binary file modified playground/augment-image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 5 additions & 2 deletions playground/generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from novelai_python import APIError, LoginCredential
from novelai_python import GenerateImageInfer, ImageGenerateResp, JwtCredential
from novelai_python.sdk.ai.generate_image import Action, Sampler, Model
from novelai_python.sdk.ai._enum import Sampler
from novelai_python.sdk.ai.generate_image import Action, Model
from novelai_python.utils.useful import enum_to_list


Expand All @@ -33,8 +34,10 @@ async def generate(prompt="1girl, year 2023, dynamic angle, best quality, amazin
prompt=prompt,
model=Model.NAI_DIFFUSION_3,
action=Action.GENERATE,
sampler=Sampler.DDIM,
sampler=Sampler.K_DPMPP_2M,
qualityToggle=True,
decrisp_mode=False,
variety_boost=True
)
print(f"charge: {agent.calculate_cost(is_opus=True)} if you are vip3")
print(f"charge: {agent.calculate_cost(is_opus=False)} if you are not vip3")
Expand Down
2 changes: 1 addition & 1 deletion playground/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ async def main():
logger.exception(e)


loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
loop.run_until_complete(main())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.4.13"
version = "0.4.14"
description = "NovelAI Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "[email protected]" },
Expand Down
6 changes: 6 additions & 0 deletions src/novelai_python/credential/ApiToken.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# @File : ApiToken.py
# @Software: PyCharm

import datetime

import shortuuid
from curl_cffi.requests import AsyncSession
from loguru import logger
from pydantic import SecretStr, Field, field_validator
Expand All @@ -16,6 +19,7 @@ class ApiCredential(CredentialBase):
ApiCredential is the base class for all credential.
"""
api_token: SecretStr = Field(None, description="api token")
_x_correlation_id: str = shortuuid.uuid()[0:6]

async def get_session(self, timeout: int = 180, update_headers: dict = None):
headers = {
Expand All @@ -26,6 +30,8 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None):
"Content-Type": "application/json",
"Origin": "https://novelai.net",
"Referer": "https://novelai.net/",
"x-correlation-id": self._x_correlation_id,
"x-initiated-at": f"{datetime.datetime.now(datetime.UTC).isoformat()}Z",
}

if update_headers:
Expand Down
6 changes: 6 additions & 0 deletions src/novelai_python/credential/JwtToken.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
# @Author : sudoskys
# @File : JwtToken.py
# @Software: PyCharm
import datetime

import shortuuid
from curl_cffi.requests import AsyncSession
from loguru import logger
from pydantic import SecretStr, Field, field_validator
Expand All @@ -15,6 +18,7 @@ class JwtCredential(CredentialBase):
JwtCredential is the base class for all credential.
"""
jwt_token: SecretStr = Field(None, description="jwt token")
_x_correlation_id: str = shortuuid.uuid()[0:6]

async def get_session(self, timeout: int = 180, update_headers: dict = None):
headers = {
Expand All @@ -25,6 +29,8 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None):
"Content-Type": "application/json",
"Origin": "https://novelai.net",
"Referer": "https://novelai.net/",
"x-correlation-id": self._x_correlation_id,
"x-initiated-at": f"{datetime.datetime.now(datetime.UTC).isoformat()}Z",
}

if update_headers:
Expand Down
5 changes: 5 additions & 0 deletions src/novelai_python/credential/UserAuth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# @Time : 2024/2/7 下午12:14
# @Author : sudoskys
# @File : UserAuth.py
import datetime
import time
from typing import Optional

import shortuuid
from curl_cffi.requests import AsyncSession
from pydantic import SecretStr, Field

Expand All @@ -19,6 +21,7 @@ class LoginCredential(CredentialBase):
password: SecretStr = Field(None, description="password")
_session_headers: dict = {}
_update_at: Optional[int] = None
_x_correlation_id: str = shortuuid.uuid()[0:6]

async def get_session(self, timeout: int = 180, update_headers: dict = None):
headers = {
Expand All @@ -29,6 +32,8 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None):
"Content-Type": "application/json",
"Origin": "https://novelai.net",
"Referer": "https://novelai.net/",
"x-correlation-id": self._x_correlation_id,
"x-initiated-at": f"{datetime.datetime.now(datetime.UTC).isoformat()}Z",
}

# 30 天有效期
Expand Down
64 changes: 57 additions & 7 deletions src/novelai_python/sdk/ai/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,25 @@
# @File : _enum.py
# @Software: PyCharm
from enum import Enum, IntEnum
from typing import List


class Sampler(Enum):
PLMS = "plms"
DDIM = "ddim"
K_EULER = "k_euler"
K_EULER_ANCESTRAL = "k_euler_ancestral"
K_DPMPP_2S_ANCESTRAL = "k_dpmpp_2s_ancestral"
K_DPMPP_2M = "k_dpmpp_2m"
K_DPMPP_SDE = "k_dpmpp_sde"
DDIM_V3 = "ddim_v3"

DDIM = "ddim"
PLMS = "plms"
K_DPM_2 = "k_dpm_2"
K_DPM_2_ANCESTRAL = "k_dpm_2_ancestral"
K_LMS = "k_lms"
K_DPMPP_2S_ANCESTRAL = "k_dpmpp_2s_ancestral"
K_DPMPP_SDE = "k_dpmpp_sde"
K_DPMPP_2M = "k_dpmpp_2m"
K_DPM_ADAPTIVE = "k_dpm_adaptive"
K_DPM_FAST = "k_dpm_fast"
K_DPMPP_2M_SDE = "k_dpmpp_2m_sde"
K_DPMPP_3M_SDE = "k_dpmpp_3m_sde"
DDIM_V3 = "ddim_v3"
NAI_SMEA = "nai_smea"
NAI_SMEA_DYN = "nai_smea_dyn"

Expand All @@ -34,6 +34,56 @@ class NoiseSchedule(Enum):
POLYEXPONENTIAL = "polyexponential"


def get_supported_noise_schedule(sample_type: Sampler) -> List[NoiseSchedule]:
"""
Get supported noise schedule for a given sample type
:param sample_type: Sampler
:return: List[NoiseSchedule]
"""
if sample_type in [
Sampler.K_EULER_ANCESTRAL,
Sampler.K_DPMPP_2S_ANCESTRAL,
Sampler.K_DPMPP_2M,
Sampler.K_DPMPP_2M_SDE,
Sampler.K_DPMPP_SDE,
Sampler.K_EULER
]:
return [
NoiseSchedule.NATIVE,
NoiseSchedule.KARRAS,
NoiseSchedule.EXPONENTIAL,
NoiseSchedule.POLYEXPONENTIAL
]
elif sample_type in [Sampler.K_DPM_2]:
return [
NoiseSchedule.EXPONENTIAL,
NoiseSchedule.POLYEXPONENTIAL
]
else:
return []


def get_default_noise_schedule(sample_type: Sampler) -> NoiseSchedule:
"""
Get default noise schedule for a given sample type
:param sample_type: Sampler
:return: NoiseSchedule
"""
if sample_type in [
Sampler.K_EULER_ANCESTRAL,
Sampler.K_DPMPP_2S_ANCESTRAL,
Sampler.K_DPMPP_2M,
Sampler.K_DPMPP_2M_SDE,
Sampler.K_DPMPP_SDE,
Sampler.K_EULER
]:
return NoiseSchedule.KARRAS
elif sample_type in [Sampler.K_DPM_2]:
return NoiseSchedule.EXPONENTIAL
else:
return NoiseSchedule.NATIVE


class UCPreset(IntEnum):
TYPE0 = 0
TYPE1 = 1
Expand Down
Loading

0 comments on commit cae6c9d

Please sign in to comment.