Commit

Update chat.py
added a few enhancements
Inserian authored Dec 21, 2024
1 parent 7e8dbe4 commit 1c7cb96
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions chat.py
@@ -95,15 +95,15 @@ def load_from_file(self, model_file_path: str, token_map_file_path: str = None):
         self.embeddings = np.array(model_dict["embeddings"], dtype=np.float32)

         # Load token-to-ID and ID-to-token mappings if present
-        if "token_to_id" in model_dict:
+        if "token_to_id" in model_dict and "id_to_token" in model_dict:
             self.token_to_id = model_dict["token_to_id"]
-            self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
-            print("token_to_id loaded from the model file.")
+            self.id_to_token = {int(k): v for k, v in model_dict["id_to_token"].items()}
+            print("token_to_id and id_to_token loaded from the model file.")
         else:
             if token_map_file_path:
                 self.load_token_map_from_file(token_map_file_path)
             else:
-                print("token_to_id not found in the model file and no token mapping file provided.")
+                print("token_to_id and id_to_token not found in the model file and no token mapping file provided.")
                 # Initialize empty mappings
                 self.token_to_id = {}
                 self.id_to_token = {}
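
The new branch requires the model file to carry both mappings, and converts id_to_token keys with int(k) because JSON object keys are always strings. For context, a minimal sketch of a writer that produces a compatible model file; save_model and its parameters are hypothetical, not part of this repository:

import json
import numpy as np

def save_model(path: str, embeddings: np.ndarray, token_to_id: dict) -> None:
    # Hypothetical writer: emits a model file that passes the new
    # "token_to_id" / "id_to_token" check in load_from_file.
    model_dict = {
        "embeddings": embeddings.tolist(),  # JSON cannot store ndarrays directly
        "token_to_id": token_to_id,
        # JSON object keys are strings, which is why the loader calls int(k)
        "id_to_token": {str(i): t for t, i in token_to_id.items()},
    }
    with open(path, "w") as f:
        json.dump(model_dict, f)
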
@@ -172,8 +172,12 @@ def load_token_map_from_file(self, token_map_file_path: str):
             if not isinstance(token_map, dict):
                 raise ValueError("Token mapping file is not a valid dictionary.")

-            self.token_to_id = token_map
-            self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
+            # Ensure that both 'token_to_id' and 'id_to_token' exist
+            if "token_to_id" not in token_map or "id_to_token" not in token_map:
+                raise ValueError("Token mapping file must contain both 'token_to_id' and 'id_to_token'.")
+
+            self.token_to_id = token_map["token_to_id"]
+            self.id_to_token = {int(k): v for k, v in token_map["id_to_token"].items()}
             print(f"Token mapping loaded from '{os.path.basename(token_map_file_path)}'.")
         except Exception as e:
             raise ValueError(f"Failed to load token mapping from '{token_map_file_path}': {e}")
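
With the stricter validation, a standalone token mapping file must now be a JSON object with both top-level keys rather than a bare token-to-ID dictionary. A sketch of producing a compatible file from a toy vocabulary (the tokens here are placeholders):

import json

token_to_id = {"<pad>": 0, "hello": 1, "world": 2}  # toy vocabulary
token_map = {
    "token_to_id": token_to_id,
    "id_to_token": {str(i): t for t, i in token_to_id.items()},
}
with open("token_map.json", "w") as f:
    json.dump(token_map, f, indent=2)
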
@@ -208,7 +212,11 @@ def run_inference(self, input_text: str, max_length: int = 10):

         for _ in range(max_length):
             # Use the embeddings to generate a response
-            input_vector = normalize_vector(np.sum(self.embeddings[current_input_ids], axis=0))
+            try:
+                input_vector = normalize_vector(np.sum(self.embeddings[current_input_ids], axis=0))
+            except IndexError as e:
+                raise ValueError(f"One of the input_ids {current_input_ids} is out of bounds for embeddings.") from e
+
             print(f"Input Vector Shape: {input_vector.shape}") # Debugging

             # If projection layer exists, apply it
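
Wrapping the lookup converts a raw IndexError into a ValueError that names the offending IDs. An alternative is to validate the IDs before indexing; a sketch under the assumption that embeddings is a (vocab_size, dim) array (safe_sum_embedding is a hypothetical helper):

import numpy as np

def safe_sum_embedding(embeddings: np.ndarray, input_ids: list) -> np.ndarray:
    # Reject out-of-range IDs up front instead of catching IndexError later.
    bad = [i for i in input_ids if not 0 <= i < embeddings.shape[0]]
    if bad:
        raise ValueError(f"Input IDs {bad} are out of bounds for vocab size {embeddings.shape[0]}.")
    return np.sum(embeddings[input_ids], axis=0)

Note that negative IDs wrap around silently in NumPy rather than raising IndexError, so the explicit range check is slightly stricter than the try/except in the diff.
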
@@ -232,7 +240,11 @@ def run_inference(self, input_text: str, max_length: int = 10):
             probabilities = softmax(logits)

             # Sample a token based on the probability distribution
-            sampled_id = np.random.choice(self.vocab_size, p=probabilities)
+            try:
+                sampled_id = np.random.choice(self.vocab_size, p=probabilities)
+            except ValueError as e:
+                raise ValueError("Probabilities do not sum to 1 or contain invalid values.") from e
+
             print(f"Sampled ID: {sampled_id}") # Debugging

             # Decode the sampled ID back into a token
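
np.random.choice raises ValueError when p contains NaN or negative entries or does not sum to 1 within tolerance, so the try/except surfaces those failures with a clearer message. An alternative is to repair the distribution before sampling; a sketch (sample_token is a hypothetical helper):

import numpy as np

def sample_token(probabilities: np.ndarray, vocab_size: int) -> int:
    # Clean up NaNs and negatives from a numerically unstable softmax,
    # then renormalize so np.random.choice accepts the distribution.
    p = np.nan_to_num(probabilities, nan=0.0)
    p = np.clip(p, 0.0, None)
    total = p.sum()
    if total == 0.0:
        raise ValueError("All probabilities are zero; cannot sample a token.")
    return int(np.random.choice(vocab_size, p=p / total))
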
@@ -432,11 +444,11 @@ def load_model(self):
             self.update_chat("System", f"Model loaded from '{os.path.basename(model_file_path)}'.", color=self.system_color)
             self.status_label.config(text=f"Model loaded: {os.path.basename(model_file_path)}")

-            if not self.model.token_to_id:
+            if not self.model.token_to_id or not self.model.id_to_token:
                 # Prompt user to load token mapping
                 response = messagebox.askyesno(
                     "Token Mapping Missing",
-                    "The model file does not contain 'token_to_id' mapping. Would you like to load it from a separate JSON file?"
+                    "The model file does not contain 'token_to_id' and 'id_to_token' mappings. Would you like to load it from a separate JSON file?"
                 )
                 if response:
                     self.prompt_token_map_loading()
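
prompt_token_map_loading is called here but not shown in this diff. A minimal sketch of what such a handler might look like, assuming the app uses tkinter (as the messagebox call suggests) and keeps the model on self.model:

from tkinter import filedialog, messagebox

def prompt_token_map_loading(self):
    # Ask the user for a token mapping JSON and hand it to the loader;
    # load_token_map_from_file raises ValueError on malformed files.
    path = filedialog.askopenfilename(
        title="Select Token Mapping File",
        filetypes=[("JSON files", "*.json"), ("All files", "*.*")],
    )
    if not path:
        return  # user cancelled the dialog
    try:
        self.model.load_token_map_from_file(path)
    except ValueError as e:
        messagebox.showerror("Token Mapping Error", str(e))
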
