Skip to content

Commit

Permalink
updating tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwellflitton committed Feb 4, 2025
1 parent a831fef commit 369a641
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 70 deletions.
5 changes: 5 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ services:
- SURREAL_PASS=root
- SURREAL_INSECURE_FORWARD_ACCESS_ERRORS=true
- SURREAL_LOG=debug
- SURREAL_CAPS_ALLOW_GUESTS=true
ports:
- 8000:8000

Expand All @@ -18,6 +19,7 @@ services:
- SURREAL_USER=root
- SURREAL_PASS=root
- SURREAL_LOG=trace
- SURREAL_CAPS_ALLOW_GUESTS=true
ports:
- 8121:8000

Expand All @@ -28,6 +30,7 @@ services:
- SURREAL_USER=root
- SURREAL_PASS=root
- SURREAL_LOG=trace
- SURREAL_CAPS_ALLOW_GUESTS=true
ports:
- 8120:8000

Expand All @@ -38,6 +41,7 @@ services:
- SURREAL_USER=root
- SURREAL_PASS=root
- SURREAL_LOG=trace
- SURREAL_CAPS_ALLOW_GUESTS=true
ports:
- 8101:8000

Expand All @@ -48,5 +52,6 @@ services:
- SURREAL_USER=root
- SURREAL_PASS=root
- SURREAL_LOG=trace
- SURREAL_CAPS_ALLOW_GUESTS=true
ports:
- 8111:8000
1 change: 0 additions & 1 deletion src/surrealdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from surrealdb.data.types.geometry import Geometry
from surrealdb.data.types.range import Range
from surrealdb.data.types.record_id import RecordID
from surrealdb.data.types.datetime import DatetimeWrapper
from surrealdb.data.types.datetime import IsoDateTimeWrapper

class AsyncSurrealDBMeta(type):
Expand Down
2 changes: 1 addition & 1 deletion src/surrealdb/connections/async_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def authenticate(self) -> None:
message = RequestMessage(
self.id,
RequestMethod.AUTHENTICATE,
token=token
token=self.token
)
return await self._send(message, "authenticating")

Expand Down
24 changes: 5 additions & 19 deletions src/surrealdb/data/cbor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from datetime import datetime, timedelta, timezone

import cbor2

from surrealdb.data.types import constants
from surrealdb.data.types.datetime import IsoDateTimeWrapper
from surrealdb.data.types.duration import Duration
from surrealdb.data.types.future import Future
from surrealdb.data.types.geometry import (
Expand All @@ -15,9 +18,6 @@
from surrealdb.data.types.range import BoundIncluded, BoundExcluded, Range
from surrealdb.data.types.record_id import RecordID
from surrealdb.data.types.table import Table
from surrealdb.data.types.datetime import DatetimeWrapper, IsoDateTimeWrapper
from datetime import datetime, timedelta, timezone
import pytz


@cbor2.shareable_encoder
Expand Down Expand Up @@ -68,14 +68,6 @@ def default_encoder(encoder, obj):
elif isinstance(obj, Duration):
tagged = cbor2.CBORTag(constants.TAG_DURATION, obj.get_seconds_and_nano())

elif isinstance(obj, DatetimeWrapper):
if obj.dt.tzinfo is None: # Make sure it's timezone-aware
obj.dt = obj.dt.replace(tzinfo=timezone.utc)

tagged = cbor2.CBORTag(
constants.TAG_DATETIME_COMPACT,
[int(obj.dt.timestamp()), obj.dt.microsecond * 1000]
)
elif isinstance(obj, IsoDateTimeWrapper):
tagged = cbor2.CBORTag(constants.TAG_DATETIME, obj.dt)
else:
Expand Down Expand Up @@ -134,21 +126,15 @@ def tag_decoder(decoder, tag, shareable_index=None):
seconds = tag.value[0]
nanoseconds = tag.value[1]
microseconds = nanoseconds // 1000 # Convert nanoseconds to microseconds
return DatetimeWrapper(
datetime.fromtimestamp(seconds) + timedelta(microseconds=microseconds)
)

elif tag.tag == constants.TAG_DATETIME:
dt_obj = datetime.fromisoformat(tag.value)
return DatetimeWrapper(dt_obj)# String (ISO 8601 datetime)
return datetime.fromtimestamp(seconds) + timedelta(microseconds=microseconds)

else:
raise BufferError("no decoder for tag", tag.tag)



def encode(obj):
return cbor2.dumps(obj, default=default_encoder)
return cbor2.dumps(obj, default=default_encoder, timezone=timezone.utc)


def decode(data):
Expand Down
18 changes: 9 additions & 9 deletions src/surrealdb/data/types/datetime.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from datetime import datetime
# from datetime import datetime


class DatetimeWrapper:

def __init__(self, dt: datetime):
self.dt = dt

@staticmethod
def now() -> "DatetimeWrapper":
return DatetimeWrapper(datetime.now())
# class DatetimeWrapper:
#
# def __init__(self, dt: datetime):
# self.dt = dt
#
# @staticmethod
# def now() -> "DatetimeWrapper":
# return DatetimeWrapper(datetime.now())


class IsoDateTimeWrapper:
Expand Down
34 changes: 28 additions & 6 deletions tests/unit_tests/connections/invalidate/test_async_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,17 @@ async def asyncSetUp(self):
_ = await self.connection.signin(self.vars_params)
_ = await self.connection.use(namespace=self.namespace, database=self.database_name)

async def test_invalidate(self):
async def test_run_test(self):
if os.environ.get("NO_GUEST_MODE") == "True":
await self.invalidate_test_for_no_guest_mode()
else:
await self.invalidate_with_guest_mode_on()

async def invalidate_with_guest_mode_on(self):
"""
This test only works if the SURREAL_CAPS_ALLOW_GUESTS=false is set in the docker container
"""
outcome = await self.connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
outcome = await self.main_connection.query("SELECT * FROM user;")
Expand All @@ -37,18 +47,30 @@ async def test_invalidate(self):
self.assertEqual(0, len(outcome))
outcome = await self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
await self.main_connection.query("DELETE user;")

async def invalidate_test_for_no_guest_mode(self):
"""
This test asserts that there is an error thrown due to no guest mode being allowed
Only run this test if SURREAL_CAPS_ALLOW_GUESTS=false is set in the docker container
"""
outcome = await self.connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
outcome = await self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

_ = await self.connection.invalidate()

'''
# Exceptions are raised only when SurrealDB doesn't allow guest mode
with self.assertRaises(Exception) as context:
_ = await self.connection.query("CREATE user:jaime SET name = 'Jaime';")
_ = await self.connection.query("SELECT * FROM user;")
self.assertEqual(
"IAM error: Not enough permissions" in str(context.exception),
True
)
'''

outcome = await self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
await self.main_connection.query("DELETE user;")


if __name__ == "__main__":
main()
25 changes: 20 additions & 5 deletions tests/unit_tests/connections/invalidate/test_async_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ async def asyncSetUp(self):
_ = await self.connection.signin(self.vars_params)
_ = await self.connection.use(namespace=self.namespace, database=self.database_name)

async def test_invalidate(self):
async def test_run_test(self):
if os.environ.get("NO_GUEST_MODE") == "True":
await self.invalidate_test_for_no_guest_mode()
else:
await self.invalidate_with_guest_mode_on()

async def invalidate_with_guest_mode_on(self):
outcome = await self.connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
outcome = await self.main_connection.query("SELECT * FROM user;")
Expand All @@ -37,16 +43,25 @@ async def test_invalidate(self):
self.assertEqual(0, len(outcome))
outcome = await self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
await self.main_connection.query("DELETE user;")

async def invalidate_test_for_no_guest_mode(self):
outcome = await self.connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
outcome = await self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

_ = await self.connection.invalidate()

'''
# Exceptions are raised only when SurrealDB doesn't allow guest mode
with self.assertRaises(Exception) as context:
_ = await self.connection.query("CREATE user:jaime SET name = 'Jaime';")
_ = await self.connection.query("SELECT * FROM user;")

self.assertEqual(
"IAM error: Not enough permissions" in str(context.exception),
True
)
'''
outcome = await self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

await self.main_connection.query("DELETE user;")
await self.main_connection.close()
Expand Down
36 changes: 26 additions & 10 deletions tests/unit_tests/connections/invalidate/test_blocking_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,44 @@ def setUp(self):
_ = self.connection.signin(self.vars_params)
_ = self.connection.use(namespace=self.namespace, database=self.database_name)

def test_invalidate(self):
def test_run_test(self):
if os.environ.get("NO_GUEST_MODE") == "True":
self.invalidate_test_for_no_guest_mode()
else:
self.invalidate_with_guest_mode_on()

def invalidate_test_for_no_guest_mode(self):
outcome = self.connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

_ = self.connection.invalidate()

outcome = self.connection.query("SELECT * FROM user;")
self.assertEqual(0, len(outcome))
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

'''
# Exceptions are raised only when SurrealDB doesn't allow guest mode
with self.assertRaises(Exception) as context:
_ = self.connection.query("CREATE user:jaime SET name = 'Jaime';")
_ = self.connection.query("SELECT * FROM user;")

self.assertEqual(
"IAM error: Not enough permissions" in str(context.exception),
True
)
'''
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

self.main_connection.query("DELETE user;")

def invalidate_with_guest_mode_on(self):
outcome = self.connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

_ = self.connection.invalidate()

outcome = self.connection.query("SELECT * FROM user;")
self.assertEqual(0, len(outcome))
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

self.main_connection.query("DELETE user;")

Expand Down
39 changes: 28 additions & 11 deletions tests/unit_tests/connections/invalidate/test_blocking_ws.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest import main, TestCase
import os
from unittest import main, TestCase

from surrealdb.connections.blocking_ws import BlockingWsSurrealConnection


Expand All @@ -25,28 +26,44 @@ def setUp(self):
_ = self.connection.signin(self.vars_params)
_ = self.connection.use(namespace=self.namespace, database=self.database_name)

def test_invalidate(self):
def test_run_test(self):
if os.environ.get("NO_GUEST_MODE") == "True":
self.invalidate_test_for_no_guest_mode()
else:
self.invalidate_with_guest_mode_on()

def invalidate_test_for_no_guest_mode(self):
outcome = self.connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

_ = self.connection.invalidate()

outcome = self.connection.query("SELECT * FROM user;")
self.assertEqual(0, len(outcome))
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

'''
# Exceptions are raised only when SurrealDB doesn't allow guest mode
with self.assertRaises(Exception) as context:
_ = self.connection.query("CREATE user:jaime SET name = 'Jaime';")
_ = self.connection.query("SELECT * FROM user;")

self.assertEqual(
"IAM error: Not enough permissions" in str(context.exception),
True
)
'''
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

self.main_connection.query("DELETE user;")

def invalidate_with_guest_mode_on(self):
outcome = self.connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

_ = self.connection.invalidate()

outcome = self.connection.query("SELECT * FROM user;")
self.assertEqual(0, len(outcome))
outcome = self.main_connection.query("SELECT * FROM user;")
self.assertEqual(1, len(outcome))

self.main_connection.query("DELETE user;")
self.main_connection.close()
Expand Down
Loading

0 comments on commit 369a641

Please sign in to comment.