diff --git a/README.md b/README.md index 28ca11d..572c49b 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ loop.run_until_complete(main()) ```python from novelai_python.utils.random_prompt import RandomPromptGenerator -s = RandomPromptGenerator(nsfw_enabled=False).generate() +s = RandomPromptGenerator(nsfw_enabled=False).random_prompt() print(s) ``` diff --git a/playground/random_prompt.py b/playground/random_prompt.py index fe27cbc..40b124a 100644 --- a/playground/random_prompt.py +++ b/playground/random_prompt.py @@ -3,11 +3,16 @@ # @Author : sudoskys # @File : random_prompt.py # @Software: PyCharm -import random from novelai_python.utils.random_prompt import RandomPromptGenerator - -print(random.random()) -s = RandomPromptGenerator(nsfw_enabled=True).generate() -print(s) +gen = RandomPromptGenerator(nsfw_enabled=True) +print(gen.get_weighted_choice([[1, 35], [2, 20], [3, 7]], [])) +print("====") +print(gen.get_weighted_choice([['mss', 30], ['fdd', 50], ['oa', 10]], [])) +print("====") +print(gen.get_weighted_choice([['m', 30], ['f', 50], ['o', 10]], ['m'])) +print("====") +for i in range(200): + s = RandomPromptGenerator(nsfw_enabled=True).random_prompt() + print(s) diff --git a/pyproject.toml b/pyproject.toml index d3c14c5..ad62a5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "novelai-python" -version = "0.2.0" +version = "0.2.1" description = "Novelai Python Binding With Pydantic" authors = [ { name = "sudoskys", email = "coldlando@hotmail.com" }, diff --git a/src/novelai_python/utils/random_prompt/__init__.py b/src/novelai_python/utils/random_prompt/__init__.py index 5469481..5de3a7d 100644 --- a/src/novelai_python/utils/random_prompt/__init__.py +++ b/src/novelai_python/utils/random_prompt/__init__.py @@ -30,16 +30,25 @@ def get_weighted_choice(tags, existing_tags: List[str]): :param existing_tags: a list of existing tags :return: a tag """ - valid_tags = [tag for tag in tags if - len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])] + valid_tags = [tag + for tag in tags + if len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])] total_weight = sum(tagr[1] for tagr in valid_tags if len(tagr) > 1) if total_weight == 0: - return random.choice(tags) + if isinstance(tags, list): + rd = random.choice(tags) + elif isinstance(tags, str): + rd = tags + else: + raise ValueError('get_weighted_choice: should not reach here') + return rd random_number = random.randint(1, total_weight) cumulative_weight = 0 for tag in valid_tags: cumulative_weight += tag[1] if random_number <= cumulative_weight: + if isinstance(tag, str): + raise Exception("tag is string") return tag[0] raise ValueError('get_weighted_choice: should not reach here') @@ -61,7 +70,7 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters, has_unique_feature = any(feature in features for feature in unique_features) if random.random() < 0.1 and enable_skin_color: features.append(self.get_weighted_choice(sinkColor, features)) - if random.random() < 0.3: + if random.random() < 0.8: features.append(self.get_weighted_choice(eyeColors, features)) if random.random() < 0.1: features.append(self.get_weighted_choice(eyeCharacteristics, features)) @@ -71,12 +80,12 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters, features.append(self.get_weighted_choice(hairLength, features)) if random.random() < 0.2: features.append(self.get_weighted_choice(backHairStyle, features)) - if random.random() < 0.2: + if random.random() < 0.1: features.append(self.get_weighted_choice(hairColors, features)) if random.random() < 0.1: features.append(self.get_weighted_choice(hairColorExtra, features)) features.append(self.get_weighted_choice(hairColors, features)) - if random.random() < 0.1: + if random.random() < 0.12: features.append(self.get_weighted_choice(hairFeatures, features)) if gender.startswith('f') and random.random() < 0.8: features.append(self.get_weighted_choice(breastsSize, features)) @@ -136,9 +145,9 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters, # 单角色 + nsfw 为 1 possible_actions = action if nsfw_enabled: - if random.random() < 0.3: + if random.random() < 0.5: features.append(self.get_weighted_choice(nsfw["action"], features)) - if random.random() < 0.25: + if random.random() < 0.5: features.append(self.get_weighted_choice(nsfw["pussyForeplay"], features)) possible_actions += nsfw["action"] + nsfw["analForeplay"] + nsfw["pussyForeplay"] if random.random() < 0.5: @@ -190,8 +199,13 @@ def random_prompt(self, *, enable_moods: bool = True, enable_character: bool = True, enable_identity: bool = False, + must_appear=None, ): + if must_appear is None: + must_appear = [] tags = [] + # 必须出现的标签 + tags.extend(must_appear) if self.nsfw_enabled: tags.append('nsfw') if random.random() < 0.1: @@ -199,7 +213,7 @@ def random_prompt(self, *, tags.append('lewd') irs = self.get_weighted_choice([[1, 70], [2, 20], [3, 7], [0, 5]], tags) if self.nsfw_enabled: - irs = self.get_weighted_choice([[1, 35], [2, 20], [3, 7]], tags) + irs = self.get_weighted_choice([[1, 40], [2, 20], [3, 7]], tags) if irs == 0: tags.append('no humans') if random.random() < 0.3: @@ -216,11 +230,11 @@ def random_prompt(self, *, return ', '.join(tags) if random.random() < 0.3: tags.append(self.get_weighted_choice(artStyle, tags)) - if random.random() < 0.5: + if random.random() < 0.6: tags.append("{" + self.get_weighted_choice(rankArtist, tags) + "}") if random.random() < 0.5: tags.append("[" + self.get_weighted_choice(rankArtist, tags) + "]") - if random.random() < 0.5: + if random.random() < 0.4: tags.append("{{" + self.get_weighted_choice(rankArtist, tags) + "}}") c_count = 0 d_count = 0 @@ -260,17 +274,17 @@ def random_prompt(self, *, tags.append(nsfw["ya"]) if d_count > 0: if c_count > 0: - if random.random() < 0.6: + if random.random() < 0.3: tags.append(self.get_weighted_choice(nsfw["penis"], tags)) else: if random.random() < 0.3: tags.append(self.get_weighted_choice(nsfw["penis"], tags)) if d_count > 0 and g_count > 0: - if random.random() < 0.5: + if random.random() < 0.7: tags.append(self.get_weighted_choice(nsfw["analSex"], tags)) if g_count > 0: features = [] - if random.random() < 0.3: + if random.random() < 0.2: features = self.character_features(nsfw['fu'], None, True, irs) if random.random() < 0.6: tags.append(self.get_weighted_choice(nsfw["sex"], tags))