Skip to content

Commit

Permalink
backup
Browse files Browse the repository at this point in the history
  • Loading branch information
flimdejong committed Oct 14, 2024
1 parent 917cf7e commit 1bebf37
Show file tree
Hide file tree
Showing 14 changed files with 362 additions and 0 deletions.
72 changes: 72 additions & 0 deletions roboteam_ai/src/RL/getRefereeState.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import sys
import socket
import struct
import binascii
from google.protobuf.json_format import MessageToJson

'''
getRefereeState.py is a script to get state of the referee.
This includes the current command, designed position for ball placement, and the score for both teams.
'''

# Make sure to go back to the main roboteam directory
current_dir = os.path.dirname(os.path.abspath(__file__))
roboteam_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))

# Add to sys.path
sys.path.append(roboteam_path)

# Now import the generated protobuf classes
from roboteam_networking.proto.ssl_gc_referee_message_pb2 import Referee
from roboteam_networking.proto.ssl_gc_game_event_pb2 import GameEvent

MULTICAST_GROUP = '224.5.23.1'
MULTICAST_PORT = 10003

def get_referee_state():
# Create the socket
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# Bind to the server address
sock.bind(('', MULTICAST_PORT))

# Tell the operating system to add the socket to the multicast group
group = socket.inet_aton(MULTICAST_GROUP)
mreq = struct.pack('4sL', group, socket.INADDR_ANY)
sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)

print(f"Listening for Referee messages on {MULTICAST_GROUP}:{MULTICAST_PORT}")

command = None
pos_x = None
pos_y = None
yellow_score = None
blue_score = None

try:
data, _ = sock.recvfrom(4096) # Increased buffer size to 4096 bytes
referee = Referee()
referee.ParseFromString(data)
command = Referee.Command.Name(referee.command)

if referee.HasField('designated_position'):
pos_x = referee.designated_position.x
pos_y = referee.designated_position.y

yellow_score = referee.yellow.score
blue_score = referee.blue.score
except Exception as e:
print(f"Error parsing message: {e}")
finally:
sock.close()

return command, pos_x, pos_y, yellow_score, blue_score

if __name__ == "__main__":
command, pos_x, pos_y, yellow_score, blue_score = get_referee_state()
print(f"Command: {command}")
print(f"Designated Position: ({pos_x}, {pos_y})")
print(f"Yellow Team Score: {yellow_score}")
print(f"Blue Team Score: {blue_score}")
40 changes: 40 additions & 0 deletions roboteam_ai/src/RL/receiveActionCommand.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include <zmq.hpp>
#include <string>
#include <iostream>
#include <vector>
#include "ActionCommand.pb.h"

int main() {
zmq::context_t context(1);
zmq::socket_t socket(context, ZMQ_SUB);

std::cout << "Connecting to ActionCommand sender..." << std::endl;
socket.connect("tcp://localhost:5555");
socket.setsockopt(ZMQ_SUBSCRIBE, "", 0);

std::cout << "ActionCommand receiver started. Ctrl+C to exit." << std::endl;

while (true) {
zmq::message_t message;
socket.recv(&message);

ActionCommand action_command;
if (!action_command.ParseFromArray(message.data(), message.size())) {
std::cerr << "Failed to parse ActionCommand." << std::endl;
continue;
}

if (action_command.numrobots_size() != 3) {
std::cerr << "Received incorrect number of values. Expected 3, got "
<< action_command.numrobots_size() << std::endl;
continue;
}

std::cout << "Received: ["
<< action_command.numrobots(0) << ", "
<< action_command.numrobots(1) << ", "
<< action_command.numrobots(2) << "]" << std::endl;
}

return 0;
}
51 changes: 51 additions & 0 deletions roboteam_ai/src/RL/resetRefereeState.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
resetRefereeState is a script to reset the referee state of the game (it basically resets the match (clock))
"""

import asyncio
import websockets
from google.protobuf.json_format import MessageToJson, Parse
import os
import sys

# Make sure to go back to the main roboteam directory
current_dir = os.path.dirname(os.path.abspath(__file__))
roboteam_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))

# Add to sys.path
sys.path.append(roboteam_path)

# Now import the generated protobuf classes
from roboteam_networking.proto.ssl_gc_api_pb2 import Input
from roboteam_networking.proto.ssl_gc_change_pb2 import Change
from roboteam_networking.proto.ssl_gc_common_pb2 import Team
from roboteam_networking.proto.ssl_gc_state_pb2 import Command


async def reset_and_stop_match(uri='ws://localhost:8081/api/control'):
async with websockets.connect(uri) as websocket:

# Step 1: Reset the match
reset_message = Input(reset_match=True)
await websocket.send(MessageToJson(reset_message))
response = await websocket.recv()

# Step 2: Send STOP command
stop_message = Input(
change=Change(
new_command_change=Change.NewCommand(
command=Command(
type=Command.Type.STOP,
for_team=Team.UNKNOWN
)
)
)
)
await websocket.send(MessageToJson(stop_message))
print(f"Sent STOP command: {MessageToJson(stop_message)}")
response = await websocket.recv()

print("Reset and STOP commands sent to SSL Game Controller")

if __name__ == "__main__":
asyncio.run(reset_and_stop_match())
24 changes: 24 additions & 0 deletions roboteam_ai/src/RL/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import gymnasium as gym
from stable_baselines3 import PPO

# Import your custom environment
from env import RoboTeamEnv

# Create the environment
env = RoboTeamEnv()

# Create and train the PPO model
model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=1000)

# Test the trained model
obs, _ = env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
env.render()
if terminated or truncated:
obs, _ = env.reset()

env.close()

13 changes: 13 additions & 0 deletions roboteam_mpi/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
### Explanation
roboteam_mpi is meant to house all the communication for MPI (message processing interface) to work on HPC clusters.


This will be a framework where other teams can attach their AI to.


## Core components

# MPIManager
Handles initialization, finalization and standard operations


28 changes: 28 additions & 0 deletions roboteam_mpi/mpi_combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# mpi_combined.py
from mpi4py import MPI
import sys

def main():
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

print(f"Process {rank}: I am rank {rank} out of {size} processes")
sys.stdout.flush()

if rank == 0:
number = 42
print(f"Process {rank}: Sending number {number} to rank 1")
sys.stdout.flush()
comm.send(number, dest=1, tag=11)
print(f"Process {rank}: Number {number} sent to rank 1")
sys.stdout.flush()
elif rank == 1:
print(f"Process {rank}: Waiting to receive number from rank 0")
sys.stdout.flush()
number = comm.recv(source=0, tag=11)
print(f"Process {rank}: Received number {number} from rank 0")
sys.stdout.flush()

if __name__ == "__main__":
main()
Binary file added roboteam_mpi/mpi_message
Binary file not shown.
38 changes: 38 additions & 0 deletions roboteam_mpi/mpi_message.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <mpi.h>
#include <iostream>

int main(int argc, char *argv[]){

MPI_Init(&argc, &argv);

int world_size; //World size is total amount of processes
MPI_Comm_size(MPI_COMM_WORLD, &world_size);

int world_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);

if (world_rank == 0)
{
// Sending a message
const int message = 42;
MPI_Send(&message, //
1,
MPI_INT,
1,
0,
MPI_COMM_WORLD);

std::cout << "Process 0 sends number " << message << " to process 1\n";
}
else if (world_rank == 1)
{
// Receiving a message
int received_message;
MPI_Recv(&received_message, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
std::cout << "Process 1 received number " << received_message << " from process 0\n";
}

MPI_Finalize();
return 0;

}
14 changes: 14 additions & 0 deletions roboteam_mpi/mpi_receiver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from mpi4py import MPI

def main():
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 1:
while True:
message = comm.recv(source=0, tag=11)
print(f"Receiver: {message}")

if __name__ == "__main__":
main()

12 changes: 12 additions & 0 deletions roboteam_mpi/mpi_sender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from mpi4py import MPI

def main():
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
message = "Test"
comm.send(message, dest=1, tag=11)
print(f"Sender: {message}")
if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions roboteam_mpi/mpi_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

print(f"Hello from process {rank} out of {size} processes")
37 changes: 37 additions & 0 deletions roboteam_mpi/src/MPIManager.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "MPIManager.h"
#include <stdexcept>

bool MPIManager::initialized = false;
int MPIManager::world_rank;
int MPIManager::world_size;

void MPIManager::init(int& argc, char**& argv) {
if (!initialized) {
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); // Unique ID for every process
MPI_Comm_size(MPI_COMM_WORLD, &world_size); // Total number of processes
initialized = true;
}
}

void MPIManager::finalize() {
if (initialized) {
MPI_Finalize();
initialized = false;
}
}

int MPIManager::getRank() {
if (!initialized) throw std::runtime_error("MPI not initialized");
return rank;
}

int MPIManager::getSize() {
if (!initialized) throw std::runtime_error("MPI not initialized");
return size;
}

void MPIManager::send(const void* data, int count, MPI_Datatype datatype, int dest, int tag) {
if (!initialized) throw std::runtime_error("MPI not initialized");
MPI_Send(data, count, datatype, dest, tag, MPI_COMM_WORLD);
}
25 changes: 25 additions & 0 deletions roboteam_mpi/src/MPIManager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef MPI_MANAGER_H
#define MPI_MANAGER_H

#include <mpi.h>
#include <vector>
#include <string>

class MPIManager {
public:
static void init(int& argc, char**& argv);
static void finalize();
static int getRank();
static int getSize();

static void send(const void* data, int count, MPI_Datatype datatype, int dest, int tag);
static void recv(void* data, int count, MPI_Datatype datatype, int source, int tag);
static void bcast(void* data, int count, MPI_Datatype datatype, int root);

private:
static bool initialized;
static int rank;
static int size;

MPIManager() = delete; // Prevent instantiation
};
1 change: 1 addition & 0 deletions ssl-game-controller
Submodule ssl-game-controller added at feb9b7

0 comments on commit 1bebf37

Please sign in to comment.