Skip to content

Commit

Permalink
Add save state support
Browse files Browse the repository at this point in the history
  • Loading branch information
vxgmichel committed Sep 21, 2022
1 parent 353bc0c commit d999127
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 13 deletions.
58 changes: 57 additions & 1 deletion gambaterm/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,24 @@ class Input(IntEnum):
UP = 0x40
DOWN = 0x80

class Event(IntEnum):
SELECT_STATE_0 = 0
SELECT_STATE_1 = 1
SELECT_STATE_2 = 2
SELECT_STATE_3 = 3
SELECT_STATE_4 = 4
SELECT_STATE_5 = 5
SELECT_STATE_6 = 6
SELECT_STATE_7 = 7
SELECT_STATE_8 = 8
SELECT_STATE_9 = 9
INCREMENT_STATE = 10
DECREMENT_STATE = 11
LOAD_STATE = 12
SAVE_STATE = 13

romfile: str
last_video: npt.NDArray[np.uint32] | None

@classmethod
def add_console_arguments(cls, parser: argparse.ArgumentParser) -> None:
Expand All @@ -36,14 +53,40 @@ def add_console_arguments(cls, parser: argparse.ArgumentParser) -> None:
def __init__(self, args: argparse.Namespace):
self.romfile = args.romfile

def set_input(self, value: set[Console.Input]) -> None:
def set_input(self, value: set[Input]) -> None:
pass

def advance_one_frame(
self, video: npt.NDArray[np.uint32], audio: npt.NDArray[np.int16]
) -> tuple[int, int]:
raise NotImplementedError

def get_current_state(self) -> int:
raise NotImplementedError

def set_current_state(self, state: int) -> None:
raise NotImplementedError

def load_state(self) -> None:
raise NotImplementedError

def save_state(self) -> None:
raise NotImplementedError

def handle_event(self, event: Event) -> None:
if event.value < 10:
self.set_current_state(event.value)
elif event == event.INCREMENT_STATE:
self.set_current_state(self.get_current_state() + 1)
elif event == event.DECREMENT_STATE:
self.set_current_state(self.get_current_state() - 1)
elif event == event.LOAD_STATE:
self.load_state()
elif event == event.SAVE_STATE:
self.save_state()
else:
assert False


# Type Alias
InputGetter = Callable[[], Set[Console.Input]]
Expand Down Expand Up @@ -93,4 +136,17 @@ def set_input(self, input_set: set[Console.Input]) -> None:
def advance_one_frame(
self, video: npt.NDArray[np.uint32], audio: npt.NDArray[np.int16]
) -> tuple[int, int]:
self.last_video = video
return self.gb.run_for(video, self.WIDTH, audio, self.TICKS_IN_FRAME)

def get_current_state(self) -> int:
return self.gb.current_state() % 10

def set_current_state(self, value: int) -> None:
self.gb.select_state(value % 10)

def load_state(self) -> None:
self.gb.load_state()

def save_state(self) -> None:
self.gb.save_state(self.last_video, self.WIDTH)
30 changes: 24 additions & 6 deletions gambaterm/controller_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .console import Console, InputGetter


def get_controller_mapping(console: Console) -> dict[str, Console.Input]:
def get_controller_input_mapping(console: Console) -> dict[str, Console.Input]:
return {
# Directions
"A1-": console.Input.UP,
Expand All @@ -31,6 +31,16 @@ def get_controller_mapping(console: Console) -> dict[str, Console.Input]:
}


def get_controller_event_mapping(console: Console) -> dict[str, Console.Event]:
return {
# Directions
"B4": console.Event.INCREMENT_STATE,
"B5": console.Event.DECREMENT_STATE,
"B8": console.Event.LOAD_STATE,
"B9": console.Event.SAVE_STATE,
}


@contextmanager
def pygame_button_pressed_context(
deadzone: float = 0.4,
Expand Down Expand Up @@ -88,16 +98,24 @@ def get_pressed() -> set[str]:

@contextmanager
def console_input_from_controller_context(console: Console) -> Iterator[InputGetter]:
controller_mapping = get_controller_mapping(console)
input_mapping = get_controller_input_mapping(console)
event_mapping = get_controller_event_mapping(console)
current_pressed: set[str] = set()

def get_gb_input() -> set[Console.Input]:
nonlocal current_pressed
old_pressed, current_pressed = current_pressed, set(get_pressed())
for event in map(event_mapping.get, current_pressed - old_pressed):
if event is None:
continue
console.handle_event(event)
return {
controller_mapping[keysym]
for keysym in joystick_get_pressed()
if keysym in controller_mapping
input_mapping[keysym]
for keysym in current_pressed
if keysym in input_mapping
}

with pygame_button_pressed_context() as joystick_get_pressed:
with pygame_button_pressed_context() as get_pressed:
yield get_gb_input


Expand Down
61 changes: 55 additions & 6 deletions gambaterm/keyboard_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pynput # type: ignore


def get_xlib_mapping(console: Console) -> dict[int, Console.Input]:
def get_xlib_input_mapping(console: Console) -> dict[int, Console.Input]:
from Xlib import XK # type: ignore

return {
Expand All @@ -40,7 +40,26 @@ def get_xlib_mapping(console: Console) -> dict[int, Console.Input]:
}


def get_keyboard_mapping(console: Console) -> dict[str, Console.Input]:
def get_xlib_event_mapping(console: Console) -> dict[int, Console.Event]:
from Xlib import XK # type: ignore

return {
XK.XK_0: console.Event.SELECT_STATE_0,
XK.XK_1: console.Event.SELECT_STATE_1,
XK.XK_2: console.Event.SELECT_STATE_2,
XK.XK_3: console.Event.SELECT_STATE_3,
XK.XK_4: console.Event.SELECT_STATE_4,
XK.XK_5: console.Event.SELECT_STATE_5,
XK.XK_6: console.Event.SELECT_STATE_6,
XK.XK_7: console.Event.SELECT_STATE_7,
XK.XK_8: console.Event.SELECT_STATE_8,
XK.XK_9: console.Event.SELECT_STATE_9,
XK.XK_l: console.Event.LOAD_STATE,
XK.XK_k: console.Event.SAVE_STATE,
}


def get_keyboard_input_mapping(console: Console) -> dict[str, Console.Input]:
return {
# Directions
"up": console.Input.UP,
Expand All @@ -65,6 +84,23 @@ def get_keyboard_mapping(console: Console) -> dict[str, Console.Input]:
}


def get_keyboard_event_mapping(console: Console) -> dict[str, Console.Event]:
return {
"0": console.Event.SELECT_STATE_0,
"1": console.Event.SELECT_STATE_1,
"2": console.Event.SELECT_STATE_2,
"3": console.Event.SELECT_STATE_3,
"4": console.Event.SELECT_STATE_4,
"5": console.Event.SELECT_STATE_5,
"6": console.Event.SELECT_STATE_6,
"7": console.Event.SELECT_STATE_7,
"8": console.Event.SELECT_STATE_8,
"9": console.Event.SELECT_STATE_9,
"l": console.Event.LOAD_STATE,
"k": console.Event.SAVE_STATE,
}


@contextmanager
def xlib_key_pressed_context(
display: str | None = None,
Expand Down Expand Up @@ -185,14 +221,28 @@ def console_input_from_keyboard_context(
console: Console, display: str | None = None
) -> Iterator[InputGetter]:
if sys.platform == "linux":
mapping = get_xlib_mapping(console)
current_pressed: set[int] = set()
input_mapping = get_xlib_input_mapping(console)
event_mapping = get_xlib_event_mapping(console)
key_pressed_context = xlib_key_pressed_context
else:
mapping = get_keyboard_mapping(console)
current_pressed: set[str] = set()
input_mapping = get_keyboard_input_mapping(console)
event_mapping = get_keyboard_event_mapping(console)
key_pressed_context = pynput_key_pressed_context

def get_input() -> set[Console.Input]:
return {mapping[keysym] for keysym in get_pressed() if keysym in mapping}
nonlocal current_pressed
old_pressed, current_pressed = current_pressed, set(get_pressed())
for event in map(event_mapping.get, current_pressed - old_pressed):
if event is None:
continue
console.handle_event(event)
return {
input_mapping[keysym]
for keysym in current_pressed
if keysym in input_mapping
}

with key_pressed_context(display=display) as get_pressed:
yield get_input
Expand All @@ -208,7 +258,6 @@ def main() -> None:
}
mapping = reverse_lookup.get
else:
mapping = get_keyboard_mapping()
key_pressed_context = pynput_key_pressed_context
mapping = str

Expand Down
4 changes: 4 additions & 0 deletions gambaterm/libgambatte.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ class GB:
) -> tuple[int, int]: ...
def set_input(self, value: int) -> None: ...
def set_save_directory(self, path: str) -> None: ...
def current_state(self) -> int: ...
def select_state(self, state: int) -> None: ...
def load_state(self) -> bool: ...
def save_state(self, video: npt.NDArray[np.uint32] | None, pitch: int) -> bool: ...
6 changes: 6 additions & 0 deletions libgambatte_ext/_libgambatte.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ cdef extern from "gambatte.h" namespace "gambatte":
void setInputGetter(GetInput *getInput);
void setSaveDir(string& sdir);

# Save state
void selectState(int state);
int currentState();
int loadState();
int saveState(unsigned int *videoBuf, ptrdiff_t pitch);


cdef extern from "input.h" namespace "gambatte":
cdef cppclass GetInput:
Expand Down
19 changes: 19 additions & 0 deletions libgambatte_ext/libgambatte.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,22 @@ cdef class GB:

def set_save_directory(self, str path):
self.c_gb.setSaveDir(path.encode())

def select_state(self, int state):
self.c_gb.selectState(state)

def current_state(self):
return self.c_gb.currentState()

def load_state(self):
return self.c_gb.loadState()

def save_state(
self,
np.ndarray[np.uint32_t, ndim=2] video,
ptrdiff_t pitch,
):
if video is None:
return self.c_gb.saveState(NULL, 0)
cdef unsigned int* video_buffer = <unsigned int*> video.data
return self.c_gb.saveState(video_buffer, pitch)

0 comments on commit d999127

Please sign in to comment.