
Commit

add model.py, .pre-commit-config.yaml and .flake8
ethanabrooks committed Apr 24, 2022
1 parent 3011eff commit d8e9959
Showing 6 changed files with 239 additions and 162 deletions.
6 changes: 6 additions & 0 deletions .flake8
@@ -0,0 +1,6 @@
[flake8]
ignore = E203, E266, E501, E731, W503, C901, B008, E741
# 80 to use as a soft test
max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9
28 changes: 28 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,28 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-ast
- id: check-merge-conflict
- repo: https://github.com/psf/black
rev: 21.10b0
hooks:
- id: black
- repo: https://github.com/jumanjihouse/pre-commit-hooks
rev: 2.1.5
hooks:
- id: shellcheck
- id: shfmt
- repo: https://github.com/pycqa/flake8
rev: 4.0.1
hooks:
- id: flake8
args: [gql]
- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
- id: isort
name: isort (python)
args: ["--profile", "black"]
7 changes: 4 additions & 3 deletions gql/env.py
@@ -1,10 +1,11 @@
from dataclasses import dataclass
from typing import Generator, Iterator, Optional, Tuple, Union
from typing import Generator, Optional, Tuple

import gym
import gym.spaces
import numpy as np

actions = [
ACTIONS = [
"Go left.",
"Try reward",
"Go right.",
@@ -27,7 +28,7 @@ def success_str():

@staticmethod
def action_str(action: int) -> str:
return actions[action]
return ACTIONS[action]

def generator(self) -> Generator[Tuple[int, float, bool, dict], int, None]:
state = self.random.choice(self.n)
3 changes: 2 additions & 1 deletion gql/example.py
@@ -1,4 +1,5 @@
import os

import openai

openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -7,7 +8,7 @@
You are at state 3. Go left. You are at state 2. Receive a reward.
You are at state 2. Receive a reward.
You are at state 0. Go right. You are at state 1. Go right. You are at state 2. Receive a reward.
You are at state 4. Go left. You are at 3. Go left. You are at state 2. Receive a reward.
You are at state 4. Go left. You are at 3. Go left. You are at state 2. Receive a reward.
You are at state 1. Go right. You are at state 2. Receive a reward.\
"""
response = openai.Completion.create(
199 changes: 41 additions & 158 deletions gql/main.py
@@ -2,13 +2,11 @@
import shelve
from collections import deque
from dataclasses import dataclass
from typing import Deque, List, Optional, Tuple, cast
from typing import List, Optional

import numpy as np
import openai
from gym.spaces import Discrete

from env import Env
from model import GPT3, Prompt, Q, V


@dataclass
@@ -24,7 +22,8 @@ def main(
goal: int = 3,
max_trajectory: int = 5,
prompt_buffer_size: int = 20,
prompt_size: int = 10,
q_prompt_size: int = 10,
v_prompt_size: int = 5,
replay_buffer_size: int = 50,
seed: int = 0,
states: int = 5,
@@ -33,58 +32,8 @@
openai.api_key = os.getenv("OPENAI_API_KEY")
env = Env(states, goal, seed)
assert batch_size <= replay_buffer_size
rng = np.random.default_rng(seed)

@dataclass
class Prompt:
state: int
action: int
value: str

@staticmethod
def make(state: int, action: int, value: str):
return Prompt(state, action, value.lstrip())

def to_string(self):
return f"{env.state_str(self.state)} {env.action_str(self.action)} {self.value}"

class PromptBuffer:
def __init__(self):
self.success_buffer = deque(maxlen=prompt_buffer_size)
self.failure_buffer = deque(maxlen=prompt_buffer_size)

def add(self, prompt: Prompt):
buffer = (
self.failure_buffer
if env.quantify(prompt.value) == 0
else self.success_buffer
)
buffer.append(prompt)

def sample(self):
num_failure = prompt_size - len(self.success_buffer)
if num_failure > 0:
failure_idxs = rng.choice(
len(self.failure_buffer), size=num_failure, replace=False
)
failure_prompts = [
self.failure_buffer[k].to_string() for k in failure_idxs
]
else:
failure_prompts = []
success_prompts = [p.to_string() for p in self.success_buffer]
prompts = failure_prompts + success_prompts
rng.shuffle(prompts)
return prompts

def ready(self) -> bool:
return (
len(self.success_buffer) + len(self.failure_buffer)
>= prompt_buffer_size
)

replay_buffer: Deque[TimeStep] = deque(maxlen=replay_buffer_size)
prompt_buffer: PromptBuffer = PromptBuffer()
last10 = deque(maxlen=10)

def evaluate_trajectory(_trajectory: List[TimeStep]) -> str:
if not _trajectory:
@@ -99,125 +48,59 @@ def evaluate_trajectory(_trajectory: List[TimeStep]) -> str:
sep = " " if tail_trajectory and reward_str else ""
return Prompt.make(
head.state, head.action, f"{reward_str}{sep}{tail_trajectory}"
).to_string()

with shelve.open("completions/completions.db") as db:

def q(state: int, action: int) -> str:
prompt = "\n".join(
[
*prompt_buffer.sample(),
f"{env.state_str(state)} {env.action_str(action)}",
]
)
if prompt in db:
value = cast(str, db[prompt])
# print("Completion:")
# print(value)
return value

print("Prompt:")
print(prompt)

# print("Querying...", end=" ")
choice, *_ = openai.Completion.create(
engine="text-davinci-002",
prompt=prompt,
temperature=0.1,
max_tokens=200,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
).choices
# print("Received response.")
value, *_ = choice.text.split("\n")
print("Completion:")
print(value)
breakpoint()
# print("Completion:")
# print(value)
db[prompt] = value
return value

def choose_action(state: int) -> Tuple[int, str]:
assert isinstance(env.action_space, Discrete)
actions = range(env.action_space.n)
values = [q(state, a) for a in actions]
action, value = max(zip(actions, values), key=lambda x: env.quantify(x[1]))
print("state", state)
print("action", action)
print("value", value)
breakpoint()
return action, value
).to_string(env)

with shelve.open("completions/completions.pkl") as db:
gpt3 = GPT3(db)
v = V(
env=env,
gpt3=gpt3,
optimistic=True,
prompt_buffer_size=prompt_buffer_size,
prompt_size=v_prompt_size,
seed=seed,
)
q = Q(
env=env,
gpt3=gpt3,
optimistic=False,
prompt_buffer_size=prompt_buffer_size,
prompt_size=q_prompt_size,
seed=seed,
)

for i in range(training_steps):
done = False
state = env.reset()
trajectory: List[TimeStep] = []
# use_v = i % 2 == 0 and v.ready()
model = q
while not done:
if len(replay_buffer) >= batch_size and prompt_buffer.ready():
action, _ = choose_action(state)
else:
action = env.action_space.sample()

action = model.act(state)
next_state, reward, done, _ = env.step(action)
step = TimeStep(state, action, reward, None if done else next_state)
if done:
print("$$$$$$$$$$$$$$$$$$$$$$$$$$$ Reward:", reward)
if done and q.ready():
print("state", state)
print("action", action)
print("reward", reward)
last10.append(reward)
# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$ Reward:", reward)
trajectory.append(step)
state = next_state

replay_buffer.extend(trajectory)

def get_last_10():
count = 0
for ts in reversed(replay_buffer):
if count == 10:
return
if ts.next_state is None:
count += 1
yield ts

last_10 = list(get_last_10())
print(
"".join(
["#" for ts in last_10 if ts.reward == 1]
+ [
" "
for ts in last_10
if ts.next_state is None and ts.reward == 0
]
)
+ "|"
)
if v.ready():
_last10 = sorted(last10, reverse=True)
print("".join(["#" if r else " " for r in _last10]) + "|")

if len(trajectory) < max_trajectory:
head, *tail = trajectory
value = evaluate_trajectory(tail)
if not value:
value = env.reward_str(head.reward)
prompt_buffer.add(Prompt.make(head.state, head.action, value)) # TODO

if len(replay_buffer) >= batch_size:
sample = rng.choice(len(replay_buffer), size=batch_size, replace=False)
for i in sample:
transition = replay_buffer[i]
if prompt_buffer.ready():
next_action, next_value = choose_action(transition.next_state)
done = transition.next_state is None
value = (
env.reward_str(transition.reward)
if done
else Prompt.make(
transition.next_state, next_action, next_value
).to_string()
)
prompt = Prompt.make(
transition.state,
transition.action,
value,
)
prompt_buffer.add(prompt)
prompt = Prompt.make(head.state, head.action, value)
q.learn(prompt)
v.learn(prompt)
# print(len(v))


if __name__ == "__main__":
158 changes: 158 additions & 0 deletions gql/model.py (diff for this new file not loaded in this view)
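Note: since the diff for the new gql/model.py is not shown above, the following is a hypothetical sketch of the interfaces that gql/main.py relies on after this commit: a Prompt dataclass with make/to_string, a GPT3 wrapper around the shelve cache, and Q/V models exposing ready/act/learn. Every name, signature, and body below is inferred from the call sites in main.py, not taken from the actual file.

# Hypothetical interface sketch only -- inferred from the call sites in gql/main.py
# above, not from the actual gql/model.py added by this commit.
from dataclasses import dataclass


@dataclass
class Prompt:
    state: int
    action: int
    value: str

    @staticmethod
    def make(state: int, action: int, value: str) -> "Prompt":
        # the old inline Prompt.make stripped leading whitespace from completions
        return Prompt(state, action, value.lstrip())

    def to_string(self, env) -> str:
        # unlike the old closure-based version, to_string now takes env explicitly
        return f"{env.state_str(self.state)} {env.action_str(self.action)} {self.value}"


class GPT3:
    def __init__(self, db):
        # db is the shelve mapping that main.py opens to cache completions
        self.db = db


class Q:
    # V appears to expose the same surface; main.py constructs it with
    # optimistic=True and a smaller prompt_size.
    def __init__(self, env, gpt3, optimistic, prompt_buffer_size, prompt_size, seed):
        self.env = env
        self.gpt3 = gpt3

    def ready(self) -> bool:
        """True once enough prompts have been collected to query the model."""
        raise NotImplementedError

    def act(self, state: int) -> int:
        """Choose an action, e.g. by comparing completed value strings per action."""
        raise NotImplementedError

    def learn(self, prompt: Prompt) -> None:
        """Record a Prompt built from an observed transition."""
        raise NotImplementedError

main.py exercises this surface as gpt3 = GPT3(db); q = Q(env=env, gpt3=gpt3, ...); action = q.act(state); q.learn(prompt), per the diff above.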
