Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix:random prompt #7

Merged
merged 5 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ loop.run_until_complete(main())
```python
from novelai_python.utils.random_prompt import RandomPromptGenerator

s = RandomPromptGenerator(nsfw_enabled=False).generate()
s = RandomPromptGenerator(nsfw_enabled=False).random_prompt()
print(s)
```

Expand Down
15 changes: 10 additions & 5 deletions playground/random_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
# @Author : sudoskys
# @File : random_prompt.py
# @Software: PyCharm
import random

from novelai_python.utils.random_prompt import RandomPromptGenerator


print(random.random())
s = RandomPromptGenerator(nsfw_enabled=True).generate()
print(s)
gen = RandomPromptGenerator(nsfw_enabled=True)
print(gen.get_weighted_choice([[1, 35], [2, 20], [3, 7]], []))
print("====")
print(gen.get_weighted_choice([['mss', 30], ['fdd', 50], ['oa', 10]], []))
print("====")
print(gen.get_weighted_choice([['m', 30], ['f', 50], ['o', 10]], ['m']))
print("====")
for i in range(200):
s = RandomPromptGenerator(nsfw_enabled=True).random_prompt()
print(s)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.2.0"
version = "0.2.1"
description = "Novelai Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "[email protected]" },
Expand Down
42 changes: 28 additions & 14 deletions src/novelai_python/utils/random_prompt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@ def get_weighted_choice(tags, existing_tags: List[str]):
:param existing_tags: a list of existing tags
:return: a tag
"""
valid_tags = [tag for tag in tags if
len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])]
valid_tags = [tag
for tag in tags
if len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])]
total_weight = sum(tagr[1] for tagr in valid_tags if len(tagr) > 1)
if total_weight == 0:
return random.choice(tags)
if isinstance(tags, list):
rd = random.choice(tags)
elif isinstance(tags, str):
rd = tags
else:
raise ValueError('get_weighted_choice: should not reach here')
return rd
random_number = random.randint(1, total_weight)
cumulative_weight = 0
for tag in valid_tags:
cumulative_weight += tag[1]
if random_number <= cumulative_weight:
if isinstance(tag, str):
raise Exception("tag is string")
return tag[0]
raise ValueError('get_weighted_choice: should not reach here')

Expand All @@ -61,7 +70,7 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters,
has_unique_feature = any(feature in features for feature in unique_features)
if random.random() < 0.1 and enable_skin_color:
features.append(self.get_weighted_choice(sinkColor, features))
if random.random() < 0.3:
if random.random() < 0.8:
features.append(self.get_weighted_choice(eyeColors, features))
if random.random() < 0.1:
features.append(self.get_weighted_choice(eyeCharacteristics, features))
Expand All @@ -71,12 +80,12 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters,
features.append(self.get_weighted_choice(hairLength, features))
if random.random() < 0.2:
features.append(self.get_weighted_choice(backHairStyle, features))
if random.random() < 0.2:
if random.random() < 0.1:
features.append(self.get_weighted_choice(hairColors, features))
if random.random() < 0.1:
features.append(self.get_weighted_choice(hairColorExtra, features))
features.append(self.get_weighted_choice(hairColors, features))
if random.random() < 0.1:
if random.random() < 0.12:
features.append(self.get_weighted_choice(hairFeatures, features))
if gender.startswith('f') and random.random() < 0.8:
features.append(self.get_weighted_choice(breastsSize, features))
Expand Down Expand Up @@ -136,9 +145,9 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters,
# 单角色 + nsfw 为 1
possible_actions = action
if nsfw_enabled:
if random.random() < 0.3:
if random.random() < 0.5:
features.append(self.get_weighted_choice(nsfw["action"], features))
if random.random() < 0.25:
if random.random() < 0.5:
features.append(self.get_weighted_choice(nsfw["pussyForeplay"], features))
possible_actions += nsfw["action"] + nsfw["analForeplay"] + nsfw["pussyForeplay"]
if random.random() < 0.5:
Expand Down Expand Up @@ -190,16 +199,21 @@ def random_prompt(self, *,
enable_moods: bool = True,
enable_character: bool = True,
enable_identity: bool = False,
must_appear=None,
):
if must_appear is None:
must_appear = []
tags = []
# 必须出现的标签
tags.extend(must_appear)
if self.nsfw_enabled:
tags.append('nsfw')
if random.random() < 0.1:
tags.append('explicit')
tags.append('lewd')
irs = self.get_weighted_choice([[1, 70], [2, 20], [3, 7], [0, 5]], tags)
if self.nsfw_enabled:
irs = self.get_weighted_choice([[1, 35], [2, 20], [3, 7]], tags)
irs = self.get_weighted_choice([[1, 40], [2, 20], [3, 7]], tags)
if irs == 0:
tags.append('no humans')
if random.random() < 0.3:
Expand All @@ -216,11 +230,11 @@ def random_prompt(self, *,
return ', '.join(tags)
if random.random() < 0.3:
tags.append(self.get_weighted_choice(artStyle, tags))
if random.random() < 0.5:
if random.random() < 0.6:
tags.append("{" + self.get_weighted_choice(rankArtist, tags) + "}")
if random.random() < 0.5:
tags.append("[" + self.get_weighted_choice(rankArtist, tags) + "]")
if random.random() < 0.5:
if random.random() < 0.4:
tags.append("{{" + self.get_weighted_choice(rankArtist, tags) + "}}")
c_count = 0
d_count = 0
Expand Down Expand Up @@ -260,17 +274,17 @@ def random_prompt(self, *,
tags.append(nsfw["ya"])
if d_count > 0:
if c_count > 0:
if random.random() < 0.6:
if random.random() < 0.3:
tags.append(self.get_weighted_choice(nsfw["penis"], tags))
else:
if random.random() < 0.3:
tags.append(self.get_weighted_choice(nsfw["penis"], tags))
if d_count > 0 and g_count > 0:
if random.random() < 0.5:
if random.random() < 0.7:
tags.append(self.get_weighted_choice(nsfw["analSex"], tags))
if g_count > 0:
features = []
if random.random() < 0.3:
if random.random() < 0.2:
features = self.character_features(nsfw['fu'], None, True, irs)
if random.random() < 0.6:
tags.append(self.get_weighted_choice(nsfw["sex"], tags))
Expand Down
Loading