-
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
917cf7e
commit 1bebf37
Showing
14 changed files
with
362 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import os | ||
import sys | ||
import socket | ||
import struct | ||
import binascii | ||
from google.protobuf.json_format import MessageToJson | ||
|
||
''' | ||
getRefereeState.py is a script to get state of the referee. | ||
This includes the current command, designed position for ball placement, and the score for both teams. | ||
''' | ||
|
||
# Make sure to go back to the main roboteam directory | ||
current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
roboteam_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) | ||
|
||
# Add to sys.path | ||
sys.path.append(roboteam_path) | ||
|
||
# Now import the generated protobuf classes | ||
from roboteam_networking.proto.ssl_gc_referee_message_pb2 import Referee | ||
from roboteam_networking.proto.ssl_gc_game_event_pb2 import GameEvent | ||
|
||
MULTICAST_GROUP = '224.5.23.1' | ||
MULTICAST_PORT = 10003 | ||
|
||
def get_referee_state(): | ||
# Create the socket | ||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) | ||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | ||
|
||
# Bind to the server address | ||
sock.bind(('', MULTICAST_PORT)) | ||
|
||
# Tell the operating system to add the socket to the multicast group | ||
group = socket.inet_aton(MULTICAST_GROUP) | ||
mreq = struct.pack('4sL', group, socket.INADDR_ANY) | ||
sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) | ||
|
||
print(f"Listening for Referee messages on {MULTICAST_GROUP}:{MULTICAST_PORT}") | ||
|
||
command = None | ||
pos_x = None | ||
pos_y = None | ||
yellow_score = None | ||
blue_score = None | ||
|
||
try: | ||
data, _ = sock.recvfrom(4096) # Increased buffer size to 4096 bytes | ||
referee = Referee() | ||
referee.ParseFromString(data) | ||
command = Referee.Command.Name(referee.command) | ||
|
||
if referee.HasField('designated_position'): | ||
pos_x = referee.designated_position.x | ||
pos_y = referee.designated_position.y | ||
|
||
yellow_score = referee.yellow.score | ||
blue_score = referee.blue.score | ||
except Exception as e: | ||
print(f"Error parsing message: {e}") | ||
finally: | ||
sock.close() | ||
|
||
return command, pos_x, pos_y, yellow_score, blue_score | ||
|
||
if __name__ == "__main__": | ||
command, pos_x, pos_y, yellow_score, blue_score = get_referee_state() | ||
print(f"Command: {command}") | ||
print(f"Designated Position: ({pos_x}, {pos_y})") | ||
print(f"Yellow Team Score: {yellow_score}") | ||
print(f"Blue Team Score: {blue_score}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#include <zmq.hpp> | ||
#include <string> | ||
#include <iostream> | ||
#include <vector> | ||
#include "ActionCommand.pb.h" | ||
|
||
int main() { | ||
zmq::context_t context(1); | ||
zmq::socket_t socket(context, ZMQ_SUB); | ||
|
||
std::cout << "Connecting to ActionCommand sender..." << std::endl; | ||
socket.connect("tcp://localhost:5555"); | ||
socket.setsockopt(ZMQ_SUBSCRIBE, "", 0); | ||
|
||
std::cout << "ActionCommand receiver started. Ctrl+C to exit." << std::endl; | ||
|
||
while (true) { | ||
zmq::message_t message; | ||
socket.recv(&message); | ||
|
||
ActionCommand action_command; | ||
if (!action_command.ParseFromArray(message.data(), message.size())) { | ||
std::cerr << "Failed to parse ActionCommand." << std::endl; | ||
continue; | ||
} | ||
|
||
if (action_command.numrobots_size() != 3) { | ||
std::cerr << "Received incorrect number of values. Expected 3, got " | ||
<< action_command.numrobots_size() << std::endl; | ||
continue; | ||
} | ||
|
||
std::cout << "Received: [" | ||
<< action_command.numrobots(0) << ", " | ||
<< action_command.numrobots(1) << ", " | ||
<< action_command.numrobots(2) << "]" << std::endl; | ||
} | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
""" | ||
resetRefereeState is a script to reset the referee state of the game (it basically resets the match (clock)) | ||
""" | ||
|
||
import asyncio | ||
import websockets | ||
from google.protobuf.json_format import MessageToJson, Parse | ||
import os | ||
import sys | ||
|
||
# Make sure to go back to the main roboteam directory | ||
current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
roboteam_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) | ||
|
||
# Add to sys.path | ||
sys.path.append(roboteam_path) | ||
|
||
# Now import the generated protobuf classes | ||
from roboteam_networking.proto.ssl_gc_api_pb2 import Input | ||
from roboteam_networking.proto.ssl_gc_change_pb2 import Change | ||
from roboteam_networking.proto.ssl_gc_common_pb2 import Team | ||
from roboteam_networking.proto.ssl_gc_state_pb2 import Command | ||
|
||
|
||
async def reset_and_stop_match(uri='ws://localhost:8081/api/control'): | ||
async with websockets.connect(uri) as websocket: | ||
|
||
# Step 1: Reset the match | ||
reset_message = Input(reset_match=True) | ||
await websocket.send(MessageToJson(reset_message)) | ||
response = await websocket.recv() | ||
|
||
# Step 2: Send STOP command | ||
stop_message = Input( | ||
change=Change( | ||
new_command_change=Change.NewCommand( | ||
command=Command( | ||
type=Command.Type.STOP, | ||
for_team=Team.UNKNOWN | ||
) | ||
) | ||
) | ||
) | ||
await websocket.send(MessageToJson(stop_message)) | ||
print(f"Sent STOP command: {MessageToJson(stop_message)}") | ||
response = await websocket.recv() | ||
|
||
print("Reset and STOP commands sent to SSL Game Controller") | ||
|
||
if __name__ == "__main__": | ||
asyncio.run(reset_and_stop_match()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import gymnasium as gym | ||
from stable_baselines3 import PPO | ||
|
||
# Import your custom environment | ||
from env import RoboTeamEnv | ||
|
||
# Create the environment | ||
env = RoboTeamEnv() | ||
|
||
# Create and train the PPO model | ||
model = PPO("MultiInputPolicy", env, verbose=1) | ||
model.learn(total_timesteps=1000) | ||
|
||
# Test the trained model | ||
obs, _ = env.reset() | ||
for i in range(1000): | ||
action, _states = model.predict(obs, deterministic=True) | ||
obs, reward, terminated, truncated, info = env.step(action) | ||
env.render() | ||
if terminated or truncated: | ||
obs, _ = env.reset() | ||
|
||
env.close() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
### Explanation | ||
roboteam_mpi is meant to house all the communication for MPI (message processing interface) to work on HPC clusters. | ||
|
||
|
||
This will be a framework where other teams can attach their AI to. | ||
|
||
|
||
## Core components | ||
|
||
# MPIManager | ||
Handles initialization, finalization and standard operations | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# mpi_combined.py | ||
from mpi4py import MPI | ||
import sys | ||
|
||
def main(): | ||
comm = MPI.COMM_WORLD | ||
rank = comm.Get_rank() | ||
size = comm.Get_size() | ||
|
||
print(f"Process {rank}: I am rank {rank} out of {size} processes") | ||
sys.stdout.flush() | ||
|
||
if rank == 0: | ||
number = 42 | ||
print(f"Process {rank}: Sending number {number} to rank 1") | ||
sys.stdout.flush() | ||
comm.send(number, dest=1, tag=11) | ||
print(f"Process {rank}: Number {number} sent to rank 1") | ||
sys.stdout.flush() | ||
elif rank == 1: | ||
print(f"Process {rank}: Waiting to receive number from rank 0") | ||
sys.stdout.flush() | ||
number = comm.recv(source=0, tag=11) | ||
print(f"Process {rank}: Received number {number} from rank 0") | ||
sys.stdout.flush() | ||
|
||
if __name__ == "__main__": | ||
main() |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#include <mpi.h> | ||
#include <iostream> | ||
|
||
int main(int argc, char *argv[]){ | ||
|
||
MPI_Init(&argc, &argv); | ||
|
||
int world_size; //World size is total amount of processes | ||
MPI_Comm_size(MPI_COMM_WORLD, &world_size); | ||
|
||
int world_rank; | ||
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); | ||
|
||
if (world_rank == 0) | ||
{ | ||
// Sending a message | ||
const int message = 42; | ||
MPI_Send(&message, // | ||
1, | ||
MPI_INT, | ||
1, | ||
0, | ||
MPI_COMM_WORLD); | ||
|
||
std::cout << "Process 0 sends number " << message << " to process 1\n"; | ||
} | ||
else if (world_rank == 1) | ||
{ | ||
// Receiving a message | ||
int received_message; | ||
MPI_Recv(&received_message, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); | ||
std::cout << "Process 1 received number " << received_message << " from process 0\n"; | ||
} | ||
|
||
MPI_Finalize(); | ||
return 0; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from mpi4py import MPI | ||
|
||
def main(): | ||
comm = MPI.COMM_WORLD | ||
rank = comm.Get_rank() | ||
|
||
if rank == 1: | ||
while True: | ||
message = comm.recv(source=0, tag=11) | ||
print(f"Receiver: {message}") | ||
|
||
if __name__ == "__main__": | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from mpi4py import MPI | ||
|
||
def main(): | ||
comm = MPI.COMM_WORLD | ||
rank = comm.Get_rank() | ||
|
||
if rank == 0: | ||
message = "Test" | ||
comm.send(message, dest=1, tag=11) | ||
print(f"Sender: {message}") | ||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from mpi4py import MPI | ||
|
||
comm = MPI.COMM_WORLD | ||
rank = comm.Get_rank() | ||
size = comm.Get_size() | ||
|
||
print(f"Hello from process {rank} out of {size} processes") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#include "MPIManager.h" | ||
#include <stdexcept> | ||
|
||
bool MPIManager::initialized = false; | ||
int MPIManager::world_rank; | ||
int MPIManager::world_size; | ||
|
||
void MPIManager::init(int& argc, char**& argv) { | ||
if (!initialized) { | ||
MPI_Init(&argc, &argv); | ||
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); // Unique ID for every process | ||
MPI_Comm_size(MPI_COMM_WORLD, &world_size); // Total number of processes | ||
initialized = true; | ||
} | ||
} | ||
|
||
void MPIManager::finalize() { | ||
if (initialized) { | ||
MPI_Finalize(); | ||
initialized = false; | ||
} | ||
} | ||
|
||
int MPIManager::getRank() { | ||
if (!initialized) throw std::runtime_error("MPI not initialized"); | ||
return rank; | ||
} | ||
|
||
int MPIManager::getSize() { | ||
if (!initialized) throw std::runtime_error("MPI not initialized"); | ||
return size; | ||
} | ||
|
||
void MPIManager::send(const void* data, int count, MPI_Datatype datatype, int dest, int tag) { | ||
if (!initialized) throw std::runtime_error("MPI not initialized"); | ||
MPI_Send(data, count, datatype, dest, tag, MPI_COMM_WORLD); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef MPI_MANAGER_H | ||
#define MPI_MANAGER_H | ||
|
||
#include <mpi.h> | ||
#include <vector> | ||
#include <string> | ||
|
||
class MPIManager { | ||
public: | ||
static void init(int& argc, char**& argv); | ||
static void finalize(); | ||
static int getRank(); | ||
static int getSize(); | ||
|
||
static void send(const void* data, int count, MPI_Datatype datatype, int dest, int tag); | ||
static void recv(void* data, int count, MPI_Datatype datatype, int source, int tag); | ||
static void bcast(void* data, int count, MPI_Datatype datatype, int root); | ||
|
||
private: | ||
static bool initialized; | ||
static int rank; | ||
static int size; | ||
|
||
MPIManager() = delete; // Prevent instantiation | ||
}; |
Submodule ssl-game-controller
added at
feb9b7