Skip to content

Commit

Permalink
🎨 chore(utils/random_prompt): refactor tag selection logic
Browse files Browse the repository at this point in the history
Refactored the tag selection logic in the `RandomPrompt` class to improve
readability and maintainability. Split long lines and adjusted indentation
for better code formatting. Removed unused imports and commented out code.

Co-authored-by: [Author Name] <[[email protected]]>
  • Loading branch information
sudoskys and [Author Name] committed Feb 14, 2024
1 parent 0b22334 commit dbc2d62
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 420 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.3.2"
version = "0.3.3"
description = "Novelai Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "[email protected]" },
Expand Down
37 changes: 16 additions & 21 deletions src/novelai_python/utils/random_prompt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
socks,
action, expression, footwears, bottoms, color as colors
)
from .tag_artist import rankArtist
from .tag_character import rankMoods, rankCharacter, rankIdentity
from .tag_nsfw import nsfw

Expand All @@ -30,9 +29,11 @@ def get_weighted_choice(tags, existing_tags: List[str]):
:param existing_tags: a list of existing tags
:return: a tag
"""
valid_tags = [tag
for tag in tags
if len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])]
valid_tags = [
tag
for tag in tags
if len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])
]
total_weight = sum(tagr[1] for tagr in valid_tags if len(tagr) > 1)
if total_weight == 0:
if isinstance(tags, list):
Expand Down Expand Up @@ -70,7 +71,7 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters,
has_unique_feature = any(feature in features for feature in unique_features)
if random.random() < 0.1 and enable_skin_color:
features.append(self.get_weighted_choice(sinkColor, features))
if random.random() < 0.8:
if random.random() < 0.7:
features.append(self.get_weighted_choice(eyeColors, features))
if random.random() < 0.1:
features.append(self.get_weighted_choice(eyeCharacteristics, features))
Expand Down Expand Up @@ -230,12 +231,6 @@ def random_prompt(self, *,
return ', '.join(tags)
if random.random() < 0.3:
tags.append(self.get_weighted_choice(artStyle, tags))
if random.random() < 0.6:
tags.append("{" + self.get_weighted_choice(rankArtist, tags) + "}")
if random.random() < 0.5:
tags.append("[" + self.get_weighted_choice(rankArtist, tags) + "]")
if random.random() < 0.4:
tags.append("{{" + self.get_weighted_choice(rankArtist, tags) + "}}")
c_count = 0
d_count = 0
u_count = 0
Expand Down Expand Up @@ -286,23 +281,23 @@ def random_prompt(self, *,
features = []
if random.random() < 0.2:
features = self.character_features(nsfw['fu'], None, True, irs)
if random.random() < 0.6:
if random.random() < 0.3:
tags.append(self.get_weighted_choice(nsfw["sex"], tags))
if random.random() < 0.6:
if random.random() < 0.5:
tags.append(self.get_weighted_choice(nsfw["pussy"], tags))
if random.random() < 0.6:
if random.random() < 0.3:
tags.append(self.get_weighted_choice(nsfw["sexMod"], tags))
if random.random() < 0.6:
if random.random() < 0.5:
tags.append(self.get_weighted_choice(nsfw["sexActionMode"], tags))
if random.random() < 0.2:
if random.random() < 0.1:
tags.append(self.get_weighted_choice(nsfw["bdsm"], tags))
if random.random() < 0.3:
if random.random() < 0.2:
tags.append(self.get_weighted_choice(nsfw["sexAccessories"], tags))
tags.extend(features)
if random.random() < 0.65 and enable_moods:
if random.random() < 0.4 and enable_moods:
# 心情
tags.append("[" + self.get_weighted_choice(rankMoods, tags) + "]")
if random.random() < 0.5 and enable_identity:
if random.random() < 0.4 and enable_identity:
# 身份
tags.append("[" + self.get_weighted_choice(rankIdentity, tags) + "]")
if random.random() < 0.2 and enable_character:
Expand All @@ -315,9 +310,9 @@ def random_prompt(self, *,
num_objects = random.randint(1, 3)
for _ in range(num_objects):
tags.append(self.get_weighted_choice(backgroundObjects, tags))
if random.random() < 0.3:
if random.random() < 0.1:
tags.append(self.get_weighted_choice(cameraPerspective, tags))
if random.random() < 0.7:
if random.random() < 0.5:
tags.append(self.get_weighted_choice(cameraAngle, tags))
for _ in range(c_count):
if random.random() < 0.2:
Expand Down
Loading

0 comments on commit dbc2d62

Please sign in to comment.