Skip to content

Commit

Permalink
[BUGFIX] Argilla server: looking for records with external_id or `i…
Browse files Browse the repository at this point in the history
…d` on bulk operations (#5014)

<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR fixes the problem when updating records in bulk with a wrong
`external_id` but a correct `id`.

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
frascuchon and pre-commit-ci[bot] authored Jun 17, 2024
1 parent 0927fa5 commit cff6e42
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 6 deletions.
1 change: 1 addition & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ These are the section headers that we use:

### Fixed

- Fixed error when updating records in bulk with wrong `external_id` but correct record `id`. ([#5014](https://github.com/argilla-io/argilla/pull/5014))
- Fixed error when searching all record response values. ([#5003](https://github.com/argilla-io/argilla/pull/5003))

## [1.29.0](https://github.com/argilla-io/argilla/compare/v1.28.0...v1.29.0)
Expand Down
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp
records = []
async with self._db.begin_nested():
for record_upsert in bulk_upsert.items:
record = found_records.get(record_upsert.external_id or record_upsert.id)
record = found_records.get(record_upsert.id) or found_records.get(record_upsert.external_id)
if not record:
record = Record(
fields=record_upsert.fields,
Expand Down
5 changes: 3 additions & 2 deletions argilla-server/src/argilla_server/validators/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ def _validate_all_bulk_records(self, dataset: Dataset, records_upsert: List[Reco
for idx, record_upsert in enumerate(records_upsert):
try:
record = self._existing_records_by_external_id_or_record_id.get(
record_upsert.external_id or record_upsert.id
)
record_upsert.id
) or self._existing_records_by_external_id_or_record_id.get(record_upsert.external_id)

if record:
RecordUpdateValidator(RecordUpdate.parse_obj(record_upsert)).validate_for(dataset)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from uuid import UUID

import pytest
Expand Down Expand Up @@ -105,7 +105,7 @@ async def test_update_record_metadata_by_id(
self, async_client: AsyncClient, db: AsyncSession, owner_auth_header: dict, metadata: dict
) -> None:
dataset = await self.test_dataset()
records = await RecordFactory.create_batch(dataset=dataset, size=100)
records = await RecordFactory.create_batch(dataset=dataset, size=10)

response = await async_client.put(
self.url(dataset.id),
Expand All @@ -124,7 +124,7 @@ async def test_update_record_metadata_by_external_id(
self, async_client: AsyncClient, db: AsyncSession, owner_auth_header: dict, metadata: dict
):
dataset = await self.test_dataset()
records = await RecordFactory.create_batch(dataset=dataset, size=100)
records = await RecordFactory.create_batch(dataset=dataset, size=10)

response = await async_client.put(
self.url(dataset.id),
Expand All @@ -140,6 +140,30 @@ async def test_update_record_metadata_by_external_id(
for record in updated_records:
assert record.metadata_ == metadata

async def test_update_record_metadata_with_invalid_external_id_but_correct_id(
self, async_client: AsyncClient, db: AsyncSession, owner_auth_header: dict
):
dataset = await self.test_dataset()
records = await RecordFactory.create_batch(dataset=dataset, size=10)

new_metadata = {"whatever": "whatever"}
response = await async_client.put(
self.url(dataset.id),
headers=owner_auth_header,
json={
"items": [
{"id": str(record.id), "external_id": str(uuid.uuid4()), "metadata": new_metadata}
for record in records
],
},
)

assert response.status_code == 200, response.json()
assert (await db.execute(select(func.count(Record.id)))).scalar_one() == len(records)
updated_records = (await db.execute(select(Record))).scalars().all()
for record in updated_records:
assert record.metadata_ == new_metadata

async def test_update_record_for_other_dataset(
self, async_client: AsyncClient, db: AsyncSession, owner_auth_header: dict
):
Expand Down

0 comments on commit cff6e42

Please sign in to comment.