From dfd219404c1f7c093d8cd2ca844d335b2ce3d896 Mon Sep 17 00:00:00 2001 From: nesitor Date: Wed, 24 Jan 2024 12:29:24 +0100 Subject: [PATCH] Allow send custom VRF `request_id` (#29) * Problem: Is not possible to send a custom request_id on the API, because it creates it automatically. Solution: Allow to pass a custom request_id field. If the request_id is already used, we will return it. If not, we will use it on the VRF. If it's not set, we create a new one using uuid, like before. * Fix: Added method to check the message integrity checking also the previous send messages of all executors. * Fix: Solved some PR issues. * Fix: Solved last PR comment issue. * Fix: Accept request_id parameter also on the coordinator API. * Fix: Solved error code to return in case of failure and 2 more issues related with SDK updates. --------- Co-authored-by: Andres D. Molins --- .../coordinator/executor_selection.py | 8 +- src/aleph_vrf/coordinator/main.py | 21 ++- src/aleph_vrf/coordinator/vrf.py | 147 +++++++++++++++--- src/aleph_vrf/exceptions.py | 44 +++++- src/aleph_vrf/executor/main.py | 14 +- src/aleph_vrf/models.py | 57 ++++++- src/aleph_vrf/utils.py | 1 - 7 files changed, 248 insertions(+), 44 deletions(-) diff --git a/src/aleph_vrf/coordinator/executor_selection.py b/src/aleph_vrf/coordinator/executor_selection.py index e3aef64..1ec9c1b 100644 --- a/src/aleph_vrf/coordinator/executor_selection.py +++ b/src/aleph_vrf/coordinator/executor_selection.py @@ -1,14 +1,14 @@ import abc import json -from pathlib import Path -from typing import List, Dict, Any, AsyncIterator import random +from pathlib import Path +from typing import Any, AsyncIterator, Dict, List import aiohttp from aleph_message.models import ItemHash -from aleph_vrf.exceptions import NotEnoughExecutors, AlephNetworkError -from aleph_vrf.models import Executor, Node, AlephExecutor, ComputeResourceNode +from aleph_vrf.exceptions import AlephNetworkError, NotEnoughExecutors +from aleph_vrf.models import AlephExecutor, ComputeResourceNode, Executor, Node from aleph_vrf.settings import settings diff --git a/src/aleph_vrf/coordinator/main.py b/src/aleph_vrf/coordinator/main.py index 8d85560..43ed916 100644 --- a/src/aleph_vrf/coordinator/main.py +++ b/src/aleph_vrf/coordinator/main.py @@ -1,5 +1,7 @@ import logging -from typing import Union +from typing import Optional, Union + +from pydantic import BaseModel from aleph_vrf.settings import settings @@ -10,11 +12,11 @@ from aleph.sdk.vm.cache import VmCache logger.debug("import fastapi") -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException logger.debug("local imports") from aleph_vrf.coordinator.vrf import generate_vrf -from aleph_vrf.models import APIResponse, PublishedVRFResponse, APIError +from aleph_vrf.models import APIError, APIResponse, PublishedVRFResponse logger.debug("imports done") @@ -23,6 +25,10 @@ cache = VmCache() +class VRFRequest(BaseModel): + request_id: Optional[str] + + @app.get("/") async def index(): return { @@ -34,7 +40,9 @@ async def index(): @app.post("/vrf") -async def receive_vrf() -> APIResponse[Union[PublishedVRFResponse, APIError]]: +async def receive_vrf( + request: Optional[VRFRequest] = None, +) -> APIResponse[Union[PublishedVRFResponse, APIError]]: """ Goes through the VRF random number generation process and returns a random number along with details on how the number was generated. @@ -44,9 +52,10 @@ async def receive_vrf() -> APIResponse[Union[PublishedVRFResponse, APIError]]: response: Union[PublishedVRFResponse, APIError] + request_id = request.request_id if request and request.request_id else None try: - response = await generate_vrf(account) + response = await generate_vrf(account=account, request_id=request_id) except Exception as err: - response = APIError(error=str(err)) + raise HTTPException(status_code=500, detail=str(err)) return APIResponse(data=response) diff --git a/src/aleph_vrf/coordinator/vrf.py b/src/aleph_vrf/coordinator/vrf.py index 56b663c..a9da23d 100644 --- a/src/aleph_vrf/coordinator/vrf.py +++ b/src/aleph_vrf/coordinator/vrf.py @@ -2,13 +2,13 @@ import json import logging from hashlib import sha3_256 -from typing import Dict, List, Type, TypeVar, Union, Optional +from typing import Dict, List, Optional, Type, TypeVar, Union from uuid import uuid4 import aiohttp from aleph.sdk.chains.ethereum import ETHAccount from aleph.sdk.client import AuthenticatedAlephClient -from aleph_message.models import ItemHash +from aleph_message.models import ItemHash, MessageType, PostMessage from aleph_message.status import MessageStatus from hexbytes import HexBytes from pydantic import BaseModel @@ -19,37 +19,33 @@ ExecutorSelectionPolicy, ) from aleph_vrf.exceptions import ( - HashValidationFailed, AlephNetworkError, ExecutorHttpError, - RandomNumberPublicationFailed, - RandomNumberGenerationFailed, HashesDoNotMatch, + HashValidationFailed, + PublishedHashesDoNotMatch, + PublishedHashValidationFailed, + RandomNumberGenerationFailed, + RandomNumberPublicationFailed, ) from aleph_vrf.models import ( + Executor, ExecutorVRFResponse, - VRFRequest, - VRFResponse, - PublishedVRFRandomNumberHash, PublishedVRFRandomNumber, - Executor, + PublishedVRFRandomNumberHash, PublishedVRFResponse, + VRFRequest, + VRFResponse, ) from aleph_vrf.settings import settings -from aleph_vrf.types import RequestId, Nonce -from aleph_vrf.utils import ( - generate_nonce, - verify, - xor_all, -) +from aleph_vrf.types import Nonce, RequestId +from aleph_vrf.utils import generate_nonce, verify, xor_all VRF_FUNCTION_GENERATE_PATH = "generate" VRF_FUNCTION_PUBLISH_PATH = "publish" - logger = logging.getLogger(__name__) - M = TypeVar("M", bound=BaseModel) @@ -66,12 +62,28 @@ async def post_executor_api_request(url: str, model: Type[M]) -> M: return model.parse_obj(response["data"]) +async def prepare_executor_api_request(url: str) -> bool: + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=120) as resp: + try: + resp.raise_for_status() + except aiohttp.ClientResponseError as error: + raise ExecutorHttpError( + url=url, status_code=resp.status, response_text=await resp.text() + ) from error + + response = await resp.json() + + return response["name"] == "vrf_generate_api" + + async def _generate_vrf( aleph_client: AuthenticatedAlephClient, nb_executors: int, nb_bytes: int, vrf_function: ItemHash, executor_selection_policy: ExecutorSelectionPolicy, + request_id: Optional[str] = None, ) -> PublishedVRFResponse: executors = await executor_selection_policy.select_executors(nb_executors) selected_nodes_json = json.dumps( @@ -80,12 +92,20 @@ async def _generate_vrf( nonce = generate_nonce() + if request_id: + existing_message = await get_existing_vrf_message(aleph_client, request_id) + if existing_message: + message = PublishedVRFResponse.from_vrf_post_message(existing_message) + await check_message_integrity(aleph_client, message) + + return message + vrf_request = VRFRequest( nb_bytes=nb_bytes, nb_executors=nb_executors, nonce=nonce, vrf_function=vrf_function, - request_id=RequestId(str(uuid4())), + request_id=RequestId(request_id or str(uuid4())), node_list_hash=sha3_256(selected_nodes_json).hexdigest(), ) @@ -136,6 +156,7 @@ async def _generate_vrf( async def generate_vrf( account: ETHAccount, + request_id: Optional[str] = None, nb_executors: Optional[int] = None, nb_bytes: Optional[int] = None, vrf_function: Optional[ItemHash] = None, @@ -152,6 +173,7 @@ async def generate_vrf( ) as aleph_client: return await _generate_vrf( aleph_client=aleph_client, + request_id=request_id, nb_executors=nb_executors or settings.NB_EXECUTORS, nb_bytes=nb_bytes or settings.NB_BYTES, vrf_function=vrf_function or settings.FUNCTION, @@ -292,3 +314,92 @@ async def publish_data( ) return message.item_hash + + +async def get_existing_vrf_message( + aleph_client: AuthenticatedAlephClient, + request_id: str, +) -> Optional[PostMessage]: + channel = f"vrf_{request_id}" + ref = f"vrf_{request_id}" + + logger.debug( + f"Getting VRF messages on {aleph_client.api_server} from request id {request_id}" + ) + + messages = await aleph_client.get_messages( + message_type=MessageType.post, + channels=[channel], + refs=[ref], + ) + + if messages.messages: + if len(messages.messages) > 1: + logger.warning(f"Multiple VRF messages found for request id {request_id}") + return messages.messages[0] + else: + logger.debug(f"Existing VRF message for request id {request_id} not found") + return None + + +async def get_existing_message( + aleph_client: AuthenticatedAlephClient, + item_hash: ItemHash, +) -> Optional[PostMessage]: + logger.debug( + f"Getting VRF message on {aleph_client.api_server} for item_hash {item_hash}" + ) + + message = await aleph_client.get_message( + item_hash=item_hash, + ) + + if not message: + raise AlephNetworkError( + f"Message could not be read for item_hash {message.item_hash}" + ) + + return message + + +async def check_message_integrity( + aleph_client: AuthenticatedAlephClient, vrf_response: PublishedVRFResponse +): + logger.debug( + f"Checking VRF response message on {aleph_client.api_server} for item_hash {vrf_response.message_hash}" + ) + + for executor in vrf_response.executors: + generation_message = await get_existing_message( + aleph_client, executor.generation_message_hash + ) + loaded_generation_message = PublishedVRFRandomNumberHash.from_published_message( + generation_message + ) + publish_message = await get_existing_message( + aleph_client, executor.publication_message_hash + ) + loaded_publish_message = PublishedVRFRandomNumber.from_published_message( + publish_message + ) + + if ( + loaded_generation_message.random_number_hash + != loaded_publish_message.random_number_hash + ): + raise PublishedHashesDoNotMatch( + executor=executor, + generation_hash=loaded_generation_message.random_number_hash, + publication_hash=loaded_publish_message.random_number_hash, + ) + + if not verify( + HexBytes(loaded_publish_message.random_number), + loaded_generation_message.nonce, + loaded_generation_message.random_number_hash, + ): + raise PublishedHashValidationFailed( + executor=executor, + random_number=loaded_publish_message, + random_number_hash=loaded_generation_message.random_number_hash, + ) diff --git a/src/aleph_vrf/exceptions.py b/src/aleph_vrf/exceptions.py index 604b159..2c237f5 100644 --- a/src/aleph_vrf/exceptions.py +++ b/src/aleph_vrf/exceptions.py @@ -1,4 +1,4 @@ -from aleph_vrf.models import Executor, PublishedVRFRandomNumber +from aleph_vrf.models import Executor, ExecutorVRFResponse, PublishedVRFRandomNumber class VrfException(Exception): @@ -72,6 +72,25 @@ def __str__(self): ) +class PublishedHashesDoNotMatch(VrfException): + """ + The random number hash received from /publish is different from the one received from /generate. + """ + + def __init__( + self, executor: ExecutorVRFResponse, generation_hash: str, publication_hash: str + ): + self.executor = executor + self.generation_hash = generation_hash + self.publication_hash = publication_hash + + def __str__(self): + return ( + f"Published random number hash ({self.publication_hash})" + f"does not match the generated one ({self.generation_hash})." + ) + + class HashValidationFailed(VrfException): """ A random number does not match the SHA3 hash sent by the executor. @@ -95,6 +114,29 @@ def __str__(self): ) +class PublishedHashValidationFailed(VrfException): + """ + A random number does not match the SHA3 hash sent by the executor. + """ + + def __init__( + self, + random_number: PublishedVRFRandomNumber, + random_number_hash: str, + executor: ExecutorVRFResponse, + ): + self.random_number = random_number + self.random_number_hash = random_number_hash + self.executor = executor + + def __str__(self): + return ( + f"The random number published by {self.executor.url} " + f"(execution ID: {self.random_number.execution_id}) " + "does not match the hash." + ) + + class NotEnoughExecutors(VrfException): """ There are not enough executors available to satisfy the user requirements. diff --git a/src/aleph_vrf/executor/main.py b/src/aleph_vrf/executor/main.py index b7c8646..827cd71 100644 --- a/src/aleph_vrf/executor/main.py +++ b/src/aleph_vrf/executor/main.py @@ -1,6 +1,6 @@ import logging import sys -from typing import Dict, Union, Set +from typing import Dict, Set, Union from uuid import uuid4 from aleph_vrf.exceptions import AlephNetworkError @@ -26,17 +26,17 @@ from aleph_message.status import MessageStatus logger.debug("import fastapi") -from fastapi import FastAPI, Depends +from fastapi import Depends, FastAPI logger.debug("local imports") from aleph_vrf.models import ( APIResponse, + PublishedVRFRandomNumber, + PublishedVRFRandomNumberHash, VRFRandomNumber, VRFRandomNumberHash, - get_vrf_request_from_message, get_random_number_hash_from_message, - PublishedVRFRandomNumberHash, - PublishedVRFRandomNumber, + get_vrf_request_from_message, ) from aleph_vrf.utils import generate @@ -115,7 +115,9 @@ async def receive_generate( detail=f"A random number has already been generated for request {vrf_request_hash}", ) - random_number, random_number_hash = generate(vrf_request.nb_bytes, vrf_request.nonce) + random_number, random_number_hash = generate( + vrf_request.nb_bytes, vrf_request.nonce + ) GENERATED_NUMBERS[execution_id] = random_number ANSWERED_REQUESTS.add(vrf_request.request_id) diff --git a/src/aleph_vrf/models.py b/src/aleph_vrf/models.py index 8ee1fc6..ad9b24a 100644 --- a/src/aleph_vrf/models.py +++ b/src/aleph_vrf/models.py @@ -1,14 +1,14 @@ -from typing import List -from typing import TypeVar, Generic +from __future__ import annotations + +from typing import Generic, List, TypeVar import fastapi from aleph_message.models import ItemHash, PostMessage from aleph_message.models.abstract import HashableModel -from pydantic import BaseModel -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from pydantic.generics import GenericModel -from aleph_vrf.types import Nonce, RequestId, ExecutionId +from aleph_vrf.types import ExecutionId, Nonce, RequestId class Node(HashableModel): @@ -77,7 +77,7 @@ class PublishedVRFRandomNumberHash(VRFRandomNumberHash): @classmethod def from_vrf_response_hash( cls, vrf_response_hash: VRFRandomNumberHash, message_hash: ItemHash - ) -> "PublishedVRFRandomNumberHash": + ) -> PublishedVRFRandomNumberHash: return cls( nb_bytes=vrf_response_hash.nb_bytes, nonce=vrf_response_hash.nonce, @@ -88,6 +88,21 @@ def from_vrf_response_hash( message_hash=message_hash, ) + @classmethod + def from_published_message( + cls, message: PostMessage + ) -> PublishedVRFRandomNumberHash: + vrf_response_hash = VRFRandomNumberHash.parse_obj(message.content.content) + return cls( + nb_bytes=vrf_response_hash.nb_bytes, + nonce=vrf_response_hash.nonce, + request_id=vrf_response_hash.request_id, + execution_id=vrf_response_hash.execution_id, + vrf_request=vrf_response_hash.vrf_request, + random_number_hash=vrf_response_hash.random_number_hash, + message_hash=message.item_hash, + ) + def get_random_number_hash_from_message( message: PostMessage, @@ -120,7 +135,7 @@ class PublishedVRFRandomNumber(VRFRandomNumber): @classmethod def from_vrf_random_number( cls, vrf_random_number: VRFRandomNumber, message_hash: ItemHash - ) -> "PublishedVRFRandomNumber": + ) -> PublishedVRFRandomNumber: return cls( request_id=vrf_random_number.request_id, execution_id=vrf_random_number.execution_id, @@ -130,6 +145,18 @@ def from_vrf_random_number( message_hash=message_hash, ) + @classmethod + def from_published_message(cls, message: PostMessage) -> PublishedVRFRandomNumber: + vrf_random_number = VRFRandomNumber.parse_obj(message.content.content) + return cls( + request_id=vrf_random_number.request_id, + execution_id=vrf_random_number.execution_id, + vrf_request=vrf_random_number.vrf_request, + random_number=vrf_random_number.random_number, + random_number_hash=vrf_random_number.random_number_hash, + message_hash=message.item_hash, + ) + class ExecutorVRFResponse(BaseModel): url: str @@ -156,7 +183,7 @@ class PublishedVRFResponse(VRFResponse): @classmethod def from_vrf_response( cls, vrf_response: VRFResponse, message_hash: ItemHash - ) -> "PublishedVRFResponse": + ) -> PublishedVRFResponse: return cls( nb_bytes=vrf_response.nb_bytes, nb_executors=vrf_response.nb_executors, @@ -168,6 +195,20 @@ def from_vrf_response( message_hash=message_hash, ) + @classmethod + def from_vrf_post_message(cls, post_message: PostMessage) -> PublishedVRFResponse: + vrf_response = VRFResponse.parse_obj(post_message.content.content) + return cls( + nb_bytes=vrf_response.nb_bytes, + nb_executors=vrf_response.nb_executors, + nonce=vrf_response.nonce, + vrf_function=vrf_response.vrf_function, + request_id=vrf_response.request_id, + executors=vrf_response.executors, + random_number=vrf_response.random_number, + message_hash=post_message.item_hash, + ) + M = TypeVar("M", bound=BaseModel) diff --git a/src/aleph_vrf/utils.py b/src/aleph_vrf/utils.py index dc6d79b..8330e97 100644 --- a/src/aleph_vrf/utils.py +++ b/src/aleph_vrf/utils.py @@ -6,7 +6,6 @@ from aleph_vrf.types import Nonce - # Used for compatibility with HexBytes or any class that inherits bytes BytesLike = TypeVar("BytesLike", bound="bytes")