From c9754a92234786b2b5c5e9ef1720a3e5ba2c4aaa Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 24 Sep 2024 17:44:15 +0300 Subject: [PATCH] Per segment chunks (#8272) ### Motivation and context - Changed chunk generation from per-task chunks to per-segment chunks - Fixed a memory leak in video reading on the server side (only in media_extractors, so there are several more left) - Fixed a potential hang in `import` worker or the server process on process shutdown - Disabled multithreading in video reading in endpoints (not in static chunk generation) - Refactored static chunk generation code (moved after job creation) - Refactored various server internal APIs for frame retrieval - Updated UI logic to access chunks, added support for non-sequential frames in chunks - Added a new server configuration option `CVAT_ALLOW_STATIC_CACHE` (boolean) to enable and disable static cache support. The option is disabled by default (it's changed from the previous behavior) - Added tests for the changes made - Added missing original chunk type field in job responses - Fixed invalid kvrocks cleanup in tests for Helm deployment - Added a new 0-based `index` parameter in `GET /api/jobs/{id}/data/?type=chunk` to simplify indexing - GT job chunks with non-sequential frames have no placeholders inside When this update is applied to the server, there will be a data storage setting migration for the tasks. Existing tasks using static chunks (`task.data.storage_method == FILE_SYSTEM`) will be switched to the dynamic cache (i.e. to `== CACHE)`). The remaining files should be removed manually, there will be a list of such tasks in the migration log file. After this update, you'll have an option to enable or disable static cache use during task creation. This allows, in particular, prohibit new tasks using the static cache. With this option, any tasks using static cache will use the dynamic cache instead on data access. User-observable changes: - Job chunk ids now start from 0 for each job instead of using parent task ids - The `use_cache = false` or `storage_method = filesystem` parameters in task creation can be ignored by the server - Task chunk access may be slower for some chunks (particularly, for tasks with overlap configured, for chunks on segment boundaries, and for tasks previously using static chunks) - The last chunk in a job will contain only the frames from the current job, even if there are more frames in the task ### How has this been tested? ### Checklist - [ ] I submit my changes into the `develop` branch - [ ] I have created a changelog fragment - [ ] I have updated the documentation accordingly - [ ] I have added tests to cover my changes - [ ] I have linked related issues (see [GitHub docs]( https://help.github.com/en/github/managing-your-work-on-github/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword)) - [ ] I have increased versions of npm packages if it is necessary ([cvat-canvas](https://github.com/cvat-ai/cvat/tree/develop/cvat-canvas#versioning), [cvat-core](https://github.com/cvat-ai/cvat/tree/develop/cvat-core#versioning), [cvat-data](https://github.com/cvat-ai/cvat/tree/develop/cvat-data#versioning) and [cvat-ui](https://github.com/cvat-ai/cvat/tree/develop/cvat-ui#versioning)) ### License - [ ] I submit _my code changes_ under the same [MIT License]( https://github.com/cvat-ai/cvat/blob/develop/LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern. ## Summary by CodeRabbit ## Summary by CodeRabbit - **New Features** - Introduced a new server setting to disable media chunks on the local filesystem. - Enhanced frame prefetching with a `startFrame` parameter for improved chunk calculations. - Added a new property, `data_original_chunk_type`, for enhanced job differentiation in the metadata. - **Bug Fixes** - Resolved memory management issues to prevent leaks during video processing. - Corrected naming inconsistencies related to the `prefetchAnalyzer`. - **Documentation** - Included configuration for code formatting tools to ensure consistent code quality across the project. - **Refactor** - Restructured classes and methods for improved clarity and maintainability, particularly in media handling and task processing. - **Chores** - Updated formatting scripts to include additional directories for automated code formatting. --- .github/workflows/full.yml | 2 + .github/workflows/main.yml | 3 +- .github/workflows/schedule.yml | 6 + .../20240812_161617_mzhiltso_job_chunks.md | 24 + cvat-core/src/frames.ts | 201 ++++- cvat-core/src/server-proxy.ts | 2 +- cvat-core/src/session-implementation.ts | 17 +- cvat-core/src/session.ts | 10 +- cvat-data/src/ts/cvat-data.ts | 103 ++- cvat-ui/src/actions/annotation-actions.ts | 10 +- .../top-bar/player-navigation.tsx | 17 +- cvat/apps/dataset_manager/bindings.py | 47 +- cvat/apps/dataset_manager/formats/cvat.py | 23 +- .../dataset_manager/tests/test_formats.py | 42 +- .../tests/test_rest_api_formats.py | 17 +- cvat/apps/engine/apps.py | 8 + cvat/apps/engine/cache.py | 847 +++++++++++++----- cvat/apps/engine/default_settings.py | 16 + cvat/apps/engine/frame_provider.py | 833 +++++++++++++---- cvat/apps/engine/log.py | 39 +- cvat/apps/engine/media_extractors.py | 584 ++++++++---- .../migrations/0083_move_to_segment_chunks.py | 118 +++ cvat/apps/engine/models.py | 29 +- cvat/apps/engine/pyproject.toml | 12 + cvat/apps/engine/serializers.py | 4 +- cvat/apps/engine/task.py | 646 +++++++------ cvat/apps/engine/tests/test_rest_api.py | 142 ++- cvat/apps/engine/tests/test_rest_api_3D.py | 18 +- cvat/apps/engine/tests/utils.py | 20 +- cvat/apps/engine/views.py | 233 ++--- cvat/apps/lambda_manager/tests/test_lambda.py | 45 +- cvat/apps/lambda_manager/views.py | 14 +- cvat/requirements/base.in | 6 + cvat/schema.yml | 13 +- dev/format_python_code.sh | 3 + docker-compose.yml | 1 + helm-chart/test.values.yaml | 6 + tests/python/rest_api/test_jobs.py | 92 +- tests/python/rest_api/test_queues.py | 2 +- .../rest_api/test_resource_import_export.py | 2 +- tests/python/rest_api/test_tasks.py | 541 ++++++++++- tests/python/sdk/test_auto_annotation.py | 1 + tests/python/sdk/test_datasets.py | 1 + tests/python/sdk/test_jobs.py | 1 + tests/python/sdk/test_projects.py | 1 + tests/python/sdk/test_pytorch.py | 1 + tests/python/sdk/test_tasks.py | 1 + tests/python/shared/assets/jobs.json | 27 + tests/python/shared/fixtures/init.py | 36 +- tests/python/shared/utils/helpers.py | 22 +- 50 files changed, 3516 insertions(+), 1373 deletions(-) create mode 100644 changelog.d/20240812_161617_mzhiltso_job_chunks.md create mode 100644 cvat/apps/engine/default_settings.py create mode 100644 cvat/apps/engine/migrations/0083_move_to_segment_chunks.py create mode 100644 cvat/apps/engine/pyproject.toml diff --git a/.github/workflows/full.yml b/.github/workflows/full.yml index 9502e7a0b185..c6369340b5c3 100644 --- a/.github/workflows/full.yml +++ b/.github/workflows/full.yml @@ -165,6 +165,8 @@ jobs: id: run_tests run: | pytest tests/python/ + ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py + CVAT_ALLOW_STATIC_CACHE="true" pytest -k "TestTaskData" tests/python - name: Creating a log file from cvat containers if: failure() && steps.run_tests.conclusion == 'failure' diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 75b200597f47..0c9211b0c4a5 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -177,8 +177,9 @@ jobs: COVERAGE_PROCESS_START: ".coveragerc" run: | pytest tests/python/ --cov --cov-report=json - for COVERAGE_FILE in `find -name "coverage*.json" -type f -printf "%f\n"`; do mv ${COVERAGE_FILE} "${COVERAGE_FILE%%.*}_0.json"; done ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py --cov --cov-report=json + CVAT_ALLOW_STATIC_CACHE="true" pytest -k "TestTaskData" tests/python --cov --cov-report=json + for COVERAGE_FILE in `find -name "coverage*.json" -type f -printf "%f\n"`; do mv ${COVERAGE_FILE} "${COVERAGE_FILE%%.*}_0.json"; done - name: Uploading code coverage results as an artifact uses: actions/upload-artifact@v4 diff --git a/.github/workflows/schedule.yml b/.github/workflows/schedule.yml index d8e514cbb449..c2071cd85d13 100644 --- a/.github/workflows/schedule.yml +++ b/.github/workflows/schedule.yml @@ -170,6 +170,12 @@ jobs: pytest tests/python/ pytest tests/python/ --stop-services + ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py + pytest tests/python/ --stop-services + + CVAT_ALLOW_STATIC_CACHE="true" pytest tests/python + pytest tests/python/ --stop-services + - name: Unit tests env: HOST_COVERAGE_DATA_DIR: ${{ github.workspace }} diff --git a/changelog.d/20240812_161617_mzhiltso_job_chunks.md b/changelog.d/20240812_161617_mzhiltso_job_chunks.md new file mode 100644 index 000000000000..af931641d6df --- /dev/null +++ b/changelog.d/20240812_161617_mzhiltso_job_chunks.md @@ -0,0 +1,24 @@ +### Added + +- A server setting to enable or disable storage of permanent media chunks on the server filesystem + () +- \[Server API\] `GET /api/jobs/{id}/data/?type=chunk&index=x` parameter combination. + The new `index` parameter allows to retrieve job chunks using 0-based index in each job, + instead of the `number` parameter, which used task chunk ids. + () + +### Changed + +- Job assignees will not receive frames from adjacent jobs in chunks + () + +### Deprecated + +- \[Server API\] `GET /api/jobs/{id}/data/?type=chunk&number=x` parameter combination + () + + +### Fixed + +- Various memory leaks in video reading on the server + () diff --git a/cvat-core/src/frames.ts b/cvat-core/src/frames.ts index 96295af7d57d..dda847cf7a72 100644 --- a/cvat-core/src/frames.ts +++ b/cvat-core/src/frames.ts @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -import _ from 'lodash'; +import _, { range, sortedIndexOf } from 'lodash'; import { FrameDecoder, BlockType, DimensionType, ChunkQuality, decodeContextImages, RequestOutdatedError, } from 'cvat-data'; @@ -25,7 +25,7 @@ const frameDataCache: Record | null; activeContextRequest: Promise> | null; @@ -34,7 +34,7 @@ const frameDataCache: Record; - getChunk: (chunkNumber: number, quality: ChunkQuality) => Promise; + getChunk: (chunkIndex: number, quality: ChunkQuality) => Promise; }> = {}; // frame meta data storage by job id @@ -55,6 +55,8 @@ export class FramesMetaData { public size: number; public startFrame: number; public stopFrame: number; + public frameStep: number; + public chunkCount: number; #updateTrigger: FieldUpdateTrigger; @@ -103,6 +105,17 @@ export class FramesMetaData { } } + const frameStep: number = (() => { + if (data.frame_filter) { + const frameStepParts = data.frame_filter.split('=', 2); + if (frameStepParts.length !== 2) { + throw new ArgumentError(`Invalid frame filter '${data.frame_filter}'`); + } + return +frameStepParts[1]; + } + return 1; + })(); + Object.defineProperties( this, Object.freeze({ @@ -133,6 +146,20 @@ export class FramesMetaData { stopFrame: { get: () => data.stop_frame, }, + frameStep: { + get: () => frameStep, + }, + }), + ); + + const chunkCount: number = Math.ceil(this.getDataFrameNumbers().length / this.chunkSize); + + Object.defineProperties( + this, + Object.freeze({ + chunkCount: { + get: () => chunkCount, + }, }), ); } @@ -144,6 +171,40 @@ export class FramesMetaData { resetUpdated(): void { this.#updateTrigger.reset(); } + + getFrameIndex(dataFrameNumber: number): number { + // Here we use absolute (task source data) frame numbers. + // TODO: migrate from data frame numbers to local frame numbers to simplify code. + // Requires server changes in api/jobs/{id}/data/meta/ + // for included_frames, start_frame, stop_frame fields + + if (dataFrameNumber < this.startFrame || dataFrameNumber > this.stopFrame) { + throw new ArgumentError(`Frame number ${dataFrameNumber} doesn't belong to the job`); + } + + let frameIndex = null; + if (this.includedFrames) { + frameIndex = sortedIndexOf(this.includedFrames, dataFrameNumber); + if (frameIndex === -1) { + throw new ArgumentError(`Frame number ${dataFrameNumber} doesn't belong to the job`); + } + } else { + frameIndex = Math.floor((dataFrameNumber - this.startFrame) / this.frameStep); + } + return frameIndex; + } + + getFrameChunkIndex(dataFrameNumber: number): number { + return Math.floor(this.getFrameIndex(dataFrameNumber) / this.chunkSize); + } + + getDataFrameNumbers(): number[] { + if (this.includedFrames) { + return this.includedFrames; + } + + return range(this.startFrame, this.stopFrame + 1, this.frameStep); + } } export class FrameData { @@ -206,12 +267,14 @@ export class FrameData { } class PrefetchAnalyzer { - #chunkSize: number; #requestedFrames: number[]; + #meta: FramesMetaData; + #getDataFrameNumber: (frameNumber: number) => number; - constructor(chunkSize) { - this.#chunkSize = chunkSize; + constructor(meta: FramesMetaData, dataFrameNumberGetter: (frameNumber: number) => number) { this.#requestedFrames = []; + this.#meta = meta; + this.#getDataFrameNumber = dataFrameNumberGetter; } shouldPrefetchNext(current: number, isPlaying: boolean, isChunkCached: (chunk) => boolean): boolean { @@ -219,13 +282,16 @@ class PrefetchAnalyzer { return true; } - const currentChunk = Math.floor(current / this.#chunkSize); + const currentDataFrameNumber = this.#getDataFrameNumber(current); + const currentChunk = this.#meta.getFrameChunkIndex(currentDataFrameNumber); const { length } = this.#requestedFrames; const isIncreasingOrder = this.#requestedFrames .every((val, index) => index === 0 || val > this.#requestedFrames[index - 1]); if ( length && (isIncreasingOrder && current > this.#requestedFrames[length - 1]) && - (current % this.#chunkSize) >= Math.ceil(this.#chunkSize / 2) && + ( + this.#meta.getFrameIndex(currentDataFrameNumber) % this.#meta.chunkSize + ) >= Math.ceil(this.#meta.chunkSize / 2) && !isChunkCached(currentChunk + 1) ) { // is increasing order including the current frame @@ -247,13 +313,25 @@ class PrefetchAnalyzer { this.#requestedFrames.push(frame); // only half of chunk size is considered in this logic - const limit = Math.ceil(this.#chunkSize / 2); + const limit = Math.ceil(this.#meta.chunkSize / 2); if (this.#requestedFrames.length > limit) { this.#requestedFrames.shift(); } } } +function getDataStartFrame(meta: FramesMetaData, localStartFrame: number): number { + return meta.startFrame - localStartFrame * meta.frameStep; +} + +function getDataFrameNumber(frameNumber: number, dataStartFrame: number, step: number): number { + return frameNumber * step + dataStartFrame; +} + +function getFrameNumber(dataFrameNumber: number, dataStartFrame: number, step: number): number { + return (dataFrameNumber - dataStartFrame) / step; +} + Object.defineProperty(FrameData.prototype.data, 'implementation', { value(this: FrameData, onServerRequest) { return new Promise<{ @@ -262,40 +340,57 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { imageData: ImageBitmap | Blob; } | Blob>((resolve, reject) => { const { - provider, prefetchAnalizer, chunkSize, stopFrame, decodeForward, forwardStep, decodedBlocksCacheSize, + meta, provider, prefetchAnalyzer, chunkSize, startFrame, + decodeForward, forwardStep, decodedBlocksCacheSize, } = frameDataCache[this.jobID]; const requestId = +_.uniqueId(); - const chunkNumber = Math.floor(this.number / chunkSize); + const dataStartFrame = getDataStartFrame(meta, startFrame); + const requestedDataFrameNumber = getDataFrameNumber( + this.number, dataStartFrame, meta.frameStep, + ); + const chunkIndex = meta.getFrameChunkIndex(requestedDataFrameNumber); + const segmentFrameNumbers = meta.getDataFrameNumbers().map( + (dataFrameNumber: number) => getFrameNumber( + dataFrameNumber, dataStartFrame, meta.frameStep, + ), + ); const frame = provider.frame(this.number); - function findTheNextNotDecodedChunk(searchFrom: number): number { - let firstFrameInNextChunk = searchFrom + forwardStep; - let nextChunkNumber = Math.floor(firstFrameInNextChunk / chunkSize); - while (nextChunkNumber === chunkNumber) { - firstFrameInNextChunk += forwardStep; - nextChunkNumber = Math.floor(firstFrameInNextChunk / chunkSize); + function findTheNextNotDecodedChunk(currentFrameIndex: number): number | null { + const { chunkCount } = meta; + let nextFrameIndex = currentFrameIndex + forwardStep; + let nextChunkIndex = Math.floor(nextFrameIndex / chunkSize); + while (nextChunkIndex === chunkIndex) { + nextFrameIndex += forwardStep; + nextChunkIndex = Math.floor(nextFrameIndex / chunkSize); } - if (provider.isChunkCached(nextChunkNumber)) { - return findTheNextNotDecodedChunk(firstFrameInNextChunk); + if (nextChunkIndex < 0 || chunkCount <= nextChunkIndex) { + return null; } - return nextChunkNumber; + if (provider.isChunkCached(nextChunkIndex)) { + return findTheNextNotDecodedChunk(nextFrameIndex); + } + + return nextChunkIndex; } if (frame) { if ( - prefetchAnalizer.shouldPrefetchNext( + prefetchAnalyzer.shouldPrefetchNext( this.number, decodeForward, (chunk) => provider.isChunkCached(chunk), ) && decodedBlocksCacheSize > 1 && !frameDataCache[this.jobID].activeChunkRequest ) { - const nextChunkNumber = findTheNextNotDecodedChunk(this.number); + const nextChunkIndex = findTheNextNotDecodedChunk( + meta.getFrameIndex(requestedDataFrameNumber), + ); const predecodeChunksMax = Math.floor(decodedBlocksCacheSize / 2); - if (nextChunkNumber * chunkSize <= stopFrame && - nextChunkNumber <= chunkNumber + predecodeChunksMax + if (nextChunkIndex !== null && + nextChunkIndex <= chunkIndex + predecodeChunksMax ) { frameDataCache[this.jobID].activeChunkRequest = new Promise((resolveForward) => { const releasePromise = (): void => { @@ -304,7 +399,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { }; frameDataCache[this.jobID].getChunk( - nextChunkNumber, ChunkQuality.COMPRESSED, + nextChunkIndex, ChunkQuality.COMPRESSED, ).then((chunk: ArrayBuffer) => { if (!(this.jobID in frameDataCache)) { // check if frameDataCache still exist @@ -316,8 +411,11 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { provider.cleanup(1); provider.requestDecodeBlock( chunk, - nextChunkNumber * chunkSize, - Math.min(stopFrame, (nextChunkNumber + 1) * chunkSize - 1), + nextChunkIndex, + segmentFrameNumbers.slice( + nextChunkIndex * chunkSize, + (nextChunkIndex + 1) * chunkSize, + ), () => {}, releasePromise, releasePromise, @@ -334,7 +432,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { renderHeight: this.height, imageData: frame, }); - prefetchAnalizer.addRequested(this.number); + prefetchAnalyzer.addRequested(this.number); return; } @@ -355,7 +453,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { renderHeight: this.height, imageData: currentFrame, }); - prefetchAnalizer.addRequested(this.number); + prefetchAnalyzer.addRequested(this.number); return; } @@ -364,7 +462,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { ) => { let wasResolved = false; frameDataCache[this.jobID].getChunk( - chunkNumber, ChunkQuality.COMPRESSED, + chunkIndex, ChunkQuality.COMPRESSED, ).then((chunk: ArrayBuffer) => { try { if (!(this.jobID in frameDataCache)) { @@ -378,8 +476,11 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { provider .requestDecodeBlock( chunk, - chunkNumber * chunkSize, - Math.min(stopFrame, (chunkNumber + 1) * chunkSize - 1), + chunkIndex, + segmentFrameNumbers.slice( + chunkIndex * chunkSize, + (chunkIndex + 1) * chunkSize, + ), (_frame: number, bitmap: ImageBitmap | Blob) => { if (decodeForward) { // resolve immediately only if is not playing @@ -395,7 +496,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { renderHeight: this.height, imageData: bitmap, }); - prefetchAnalizer.addRequested(this.number); + prefetchAnalyzer.addRequested(this.number); } }, () => { frameDataCache[this.jobID].activeChunkRequest = null; @@ -592,7 +693,7 @@ export async function getFrame( isPlaying: boolean, step: number, dimension: DimensionType, - getChunk: (chunkNumber: number, quality: ChunkQuality) => Promise, + getChunk: (chunkIndex: number, quality: ChunkQuality) => Promise, ): Promise { if (!(jobID in frameDataCache)) { const blockType = chunkType === 'video' ? BlockType.MP4VIDEO : BlockType.ARCHIVE; @@ -608,6 +709,13 @@ export async function getFrame( const decodedBlocksCacheSize = Math.min( Math.floor((2048 * 1024 * 1024) / ((mean + stdDev) * 4 * chunkSize)) || 1, 10, ); + + // TODO: migrate to local frame numbers + const dataStartFrame = getDataStartFrame(meta, startFrame); + const dataFrameNumberGetter = (frameNumber: number): number => ( + getDataFrameNumber(frameNumber, dataStartFrame, meta.frameStep) + ); + frameDataCache[jobID] = { meta, chunkSize, @@ -618,11 +726,13 @@ export async function getFrame( forwardStep: step, provider: new FrameDecoder( blockType, - chunkSize, decodedBlocksCacheSize, + (frameNumber: number): number => ( + meta.getFrameChunkIndex(dataFrameNumberGetter(frameNumber)) + ), dimension, ), - prefetchAnalizer: new PrefetchAnalyzer(chunkSize), + prefetchAnalyzer: new PrefetchAnalyzer(meta, dataFrameNumberGetter), decodedBlocksCacheSize, activeChunkRequest: null, activeContextRequest: null, @@ -697,8 +807,11 @@ export async function findFrame( let lastUndeletedFrame = null; const check = (frame): boolean => { if (meta.includedFrames) { - return (meta.includedFrames.includes(frame)) && - (!filters.notDeleted || !(frame in meta.deletedFrames)); + // meta.includedFrames contains input frame numbers now + const dataStartFrame = meta.startFrame; // this is only true when includedFrames is set + return (meta.includedFrames.includes( + getDataFrameNumber(frame, dataStartFrame, meta.frameStep)) + ) && (!filters.notDeleted || !(frame in meta.deletedFrames)); } if (filters.notDeleted) { return !(frame in meta.deletedFrames); @@ -726,6 +839,18 @@ export function getCachedChunks(jobID): number[] { return frameDataCache[jobID].provider.cachedChunks(true); } +export function getJobFrameNumbers(jobID): number[] { + if (!(jobID in frameDataCache)) { + return []; + } + + const { meta, startFrame } = frameDataCache[jobID]; + const dataStartFrame = getDataStartFrame(meta, startFrame); + return meta.getDataFrameNumbers().map((dataFrameNumber: number): number => ( + getFrameNumber(dataFrameNumber, dataStartFrame, meta.frameStep) + )); +} + export function clear(jobID: number): void { if (jobID in frameDataCache) { frameDataCache[jobID].provider.close(); diff --git a/cvat-core/src/server-proxy.ts b/cvat-core/src/server-proxy.ts index 91dc52a71821..51309426198a 100644 --- a/cvat-core/src/server-proxy.ts +++ b/cvat-core/src/server-proxy.ts @@ -1438,7 +1438,7 @@ async function getData(jid: number, chunk: number, quality: ChunkQuality, retry ...enableOrganization(), quality, type: 'chunk', - number: chunk, + index: chunk, }, responseType: 'arraybuffer', }); diff --git a/cvat-core/src/session-implementation.ts b/cvat-core/src/session-implementation.ts index fa77c934abde..961771708726 100644 --- a/cvat-core/src/session-implementation.ts +++ b/cvat-core/src/session-implementation.ts @@ -18,6 +18,7 @@ import { deleteFrame, restoreFrame, getCachedChunks, + getJobFrameNumbers, clear as clearFrames, findFrame, getContextImage, @@ -189,7 +190,7 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { isPlaying, step, this.dimension, - (chunkNumber, quality) => this.frames.chunk(chunkNumber, quality), + (chunkIndex, quality) => this.frames.chunk(chunkIndex, quality), ); }, }); @@ -244,6 +245,14 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { }, }); + Object.defineProperty(Job.prototype.frames.frameNumbers, 'implementation', { + value: function includedFramesImplementation( + this: JobClass, + ): ReturnType { + return Promise.resolve(getJobFrameNumbers(this.id)); + }, + }); + Object.defineProperty(Job.prototype.frames.preview, 'implementation', { value: function previewImplementation( this: JobClass, @@ -273,10 +282,10 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { Object.defineProperty(Job.prototype.frames.chunk, 'implementation', { value: function chunkImplementation( this: JobClass, - chunkNumber: Parameters[0], + chunkIndex: Parameters[0], quality: Parameters[1], ): ReturnType { - return serverProxy.frames.getData(this.id, chunkNumber, quality); + return serverProxy.frames.getData(this.id, chunkIndex, quality); }, }); @@ -829,7 +838,7 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass { isPlaying, step, this.dimension, - (chunkNumber, quality) => job.frames.chunk(chunkNumber, quality), + (chunkIndex, quality) => job.frames.chunk(chunkIndex, quality), ); return result; }, diff --git a/cvat-core/src/session.ts b/cvat-core/src/session.ts index 1985a72b2683..54133ff6b667 100644 --- a/cvat-core/src/session.ts +++ b/cvat-core/src/session.ts @@ -233,6 +233,10 @@ function buildDuplicatedAPI(prototype) { const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.cachedChunks); return result; }, + async frameNumbers() { + const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.frameNumbers); + return result; + }, async preview() { const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.preview); return result; @@ -255,11 +259,11 @@ function buildDuplicatedAPI(prototype) { ); return result; }, - async chunk(chunkNumber, quality) { + async chunk(chunkIndex, quality) { const result = await PluginRegistry.apiWrapper.call( this, prototype.frames.chunk, - chunkNumber, + chunkIndex, quality, ); return result; @@ -380,6 +384,7 @@ export class Session { restore: (frame: number) => Promise; save: () => Promise; cachedChunks: () => Promise; + frameNumbers: () => Promise; preview: () => Promise; contextImage: (frame: number) => Promise>; search: ( @@ -443,6 +448,7 @@ export class Session { restore: Object.getPrototypeOf(this).frames.restore.bind(this), save: Object.getPrototypeOf(this).frames.save.bind(this), cachedChunks: Object.getPrototypeOf(this).frames.cachedChunks.bind(this), + frameNumbers: Object.getPrototypeOf(this).frames.frameNumbers.bind(this), preview: Object.getPrototypeOf(this).frames.preview.bind(this), search: Object.getPrototypeOf(this).frames.search.bind(this), contextImage: Object.getPrototypeOf(this).frames.contextImage.bind(this), diff --git a/cvat-data/src/ts/cvat-data.ts b/cvat-data/src/ts/cvat-data.ts index 2f832ac9d3f5..baf00ac443c1 100644 --- a/cvat-data/src/ts/cvat-data.ts +++ b/cvat-data/src/ts/cvat-data.ts @@ -72,8 +72,8 @@ export function decodeContextImages( decodeContextImages.mutex = new Mutex(); interface BlockToDecode { - start: number; - end: number; + chunkFrameNumbers: number[]; + chunkIndex: number; block: ArrayBuffer; onDecodeAll(): void; onDecode(frame: number, bitmap: ImageBitmap | Blob): void; @@ -82,7 +82,6 @@ interface BlockToDecode { export class FrameDecoder { private blockType: BlockType; - private chunkSize: number; /* ImageBitmap when decode zip or video chunks Blob when 3D dimension @@ -100,11 +99,12 @@ export class FrameDecoder { private renderHeight: number; private zipWorker: Worker | null; private videoWorker: Worker | null; + private getChunkIndex: (frame: number) => number; constructor( blockType: BlockType, - chunkSize: number, cachedBlockCount: number, + getChunkIndex: (frame: number) => number, dimension: DimensionType = DimensionType.DIMENSION_2D, ) { this.mutex = new Mutex(); @@ -117,7 +117,7 @@ export class FrameDecoder { this.renderWidth = 1920; this.renderHeight = 1080; - this.chunkSize = chunkSize; + this.getChunkIndex = getChunkIndex; this.blockType = blockType; this.decodedChunks = {}; @@ -125,8 +125,8 @@ export class FrameDecoder { this.chunkIsBeingDecoded = null; } - isChunkCached(chunkNumber: number): boolean { - return chunkNumber in this.decodedChunks; + isChunkCached(chunkIndex: number): boolean { + return chunkIndex in this.decodedChunks; } hasFreeSpace(): boolean { @@ -155,17 +155,37 @@ export class FrameDecoder { } } + private validateFrameNumbers(frameNumbers: number[]): void { + if (!Array.isArray(frameNumbers) || !frameNumbers.length) { + throw new Error('chunkFrameNumbers must not be empty'); + } + + // ensure is ordered + for (let i = 1; i < frameNumbers.length; ++i) { + const prev = frameNumbers[i - 1]; + const current = frameNumbers[i]; + if (current <= prev) { + throw new Error( + 'chunkFrameNumbers must be sorted in the ascending order, ' + + `got a (${prev}, ${current}) pair instead`, + ); + } + } + } + requestDecodeBlock( block: ArrayBuffer, - start: number, - end: number, + chunkIndex: number, + chunkFrameNumbers: number[], onDecode: (frame: number, bitmap: ImageBitmap | Blob) => void, onDecodeAll: () => void, onReject: (e: Error) => void, ): void { + this.validateFrameNumbers(chunkFrameNumbers); + if (this.requestedChunkToDecode !== null) { // a chunk was already requested to be decoded, but decoding didn't start yet - if (start === this.requestedChunkToDecode.start && end === this.requestedChunkToDecode.end) { + if (chunkIndex === this.requestedChunkToDecode.chunkIndex) { // it was the same chunk this.requestedChunkToDecode.onReject(new RequestOutdatedError()); @@ -175,12 +195,14 @@ export class FrameDecoder { // it was other chunk this.requestedChunkToDecode.onReject(new RequestOutdatedError()); } - } else if (this.chunkIsBeingDecoded === null || this.chunkIsBeingDecoded.start !== start) { + } else if (this.chunkIsBeingDecoded === null || + chunkIndex !== this.chunkIsBeingDecoded.chunkIndex + ) { // everything was decoded or decoding other chunk is in process this.requestedChunkToDecode = { + chunkFrameNumbers, + chunkIndex, block, - start, - end, onDecode, onDecodeAll, onReject, @@ -203,9 +225,9 @@ export class FrameDecoder { } frame(frameNumber: number): ImageBitmap | Blob | null { - const chunkNumber = Math.floor(frameNumber / this.chunkSize); - if (chunkNumber in this.decodedChunks) { - return this.decodedChunks[chunkNumber][frameNumber]; + const chunkIndex = this.getChunkIndex(frameNumber); + if (chunkIndex in this.decodedChunks) { + return this.decodedChunks[chunkIndex][frameNumber]; } return null; @@ -253,8 +275,8 @@ export class FrameDecoder { releaseMutex(); }; try { - const { start, end, block } = this.requestedChunkToDecode; - if (start !== blockToDecode.start) { + const { chunkFrameNumbers, chunkIndex, block } = this.requestedChunkToDecode; + if (chunkIndex !== blockToDecode.chunkIndex) { // request is not relevant, another block was already requested // it happens when A is being decoded, B comes and wait for mutex, C comes and wait for mutex // B is not necessary anymore, because C already was requested @@ -262,8 +284,11 @@ export class FrameDecoder { throw new RequestOutdatedError(); } - const chunkNumber = Math.floor(start / this.chunkSize); - this.orderedStack = [chunkNumber, ...this.orderedStack]; + const getFrameNumber = (chunkFrameIndex: number): number => ( + chunkFrameNumbers[chunkFrameIndex] + ); + + this.orderedStack = [chunkIndex, ...this.orderedStack]; this.cleanup(); const decodedFrames: Record = {}; this.chunkIsBeingDecoded = this.requestedChunkToDecode; @@ -273,7 +298,7 @@ export class FrameDecoder { this.videoWorker = new Worker( new URL('./3rdparty/Decoder.worker', import.meta.url), ); - let index = start; + let index = 0; this.videoWorker.onmessage = (e) => { if (e.data.consoleLog) { @@ -281,6 +306,7 @@ export class FrameDecoder { return; } const keptIndex = index; + const frameNumber = getFrameNumber(keptIndex); // do not use e.data.height and e.data.width because they might be not correct // instead, try to understand real height and width of decoded image via scale factor @@ -295,11 +321,11 @@ export class FrameDecoder { width, height, )).then((bitmap) => { - decodedFrames[keptIndex] = bitmap; - this.chunkIsBeingDecoded.onDecode(keptIndex, decodedFrames[keptIndex]); + decodedFrames[frameNumber] = bitmap; + this.chunkIsBeingDecoded.onDecode(frameNumber, decodedFrames[frameNumber]); - if (keptIndex === end) { - this.decodedChunks[chunkNumber] = decodedFrames; + if (keptIndex === chunkFrameNumbers.length - 1) { + this.decodedChunks[chunkIndex] = decodedFrames; this.chunkIsBeingDecoded.onDecodeAll(); this.chunkIsBeingDecoded = null; release(); @@ -343,7 +369,7 @@ export class FrameDecoder { this.zipWorker = this.zipWorker || new Worker( new URL('./unzip_imgs.worker', import.meta.url), ); - let index = start; + let decodedCount = 0; this.zipWorker.onmessage = async (event) => { if (event.data.error) { @@ -353,16 +379,18 @@ export class FrameDecoder { return; } - decodedFrames[event.data.index] = event.data.data as ImageBitmap | Blob; - this.chunkIsBeingDecoded.onDecode(event.data.index, decodedFrames[event.data.index]); + const frameNumber = getFrameNumber(event.data.index); + decodedFrames[frameNumber] = event.data.data as ImageBitmap | Blob; + this.chunkIsBeingDecoded.onDecode(frameNumber, decodedFrames[frameNumber]); - if (index === end) { - this.decodedChunks[chunkNumber] = decodedFrames; + if (decodedCount === chunkFrameNumbers.length - 1) { + this.decodedChunks[chunkIndex] = decodedFrames; this.chunkIsBeingDecoded.onDecodeAll(); this.chunkIsBeingDecoded = null; release(); } - index++; + + decodedCount++; }; this.zipWorker.onerror = (event: ErrorEvent) => { @@ -373,8 +401,8 @@ export class FrameDecoder { this.zipWorker.postMessage({ block, - start, - end, + start: 0, + end: chunkFrameNumbers.length - 1, dimension: this.dimension, dimension2D: DimensionType.DIMENSION_2D, }); @@ -400,9 +428,12 @@ export class FrameDecoder { } public cachedChunks(includeInProgress = false): number[] { - const chunkIsBeingDecoded = includeInProgress && this.chunkIsBeingDecoded ? - Math.floor(this.chunkIsBeingDecoded.start / this.chunkSize) : null; - return Object.keys(this.decodedChunks).map((chunkNumber: string) => +chunkNumber).concat( + const chunkIsBeingDecoded = ( + includeInProgress && this.chunkIsBeingDecoded ? + this.chunkIsBeingDecoded.chunkIndex : + null + ); + return Object.keys(this.decodedChunks).map((chunkIndex: string) => +chunkIndex).concat( ...(chunkIsBeingDecoded !== null ? [chunkIsBeingDecoded] : []), ).sort((a, b) => a - b); } diff --git a/cvat-ui/src/actions/annotation-actions.ts b/cvat-ui/src/actions/annotation-actions.ts index 31b73314a131..b3fa8b503aaa 100644 --- a/cvat-ui/src/actions/annotation-actions.ts +++ b/cvat-ui/src/actions/annotation-actions.ts @@ -587,12 +587,13 @@ export function confirmCanvasReadyAsync(): ThunkAction { const { instance: job } = state.annotation.job; const { changeFrameEvent } = state.annotation.player.frame; const chunks = await job.frames.cachedChunks() as number[]; - const { startFrame, stopFrame, dataChunkSize } = job; + const includedFrames = await job.frames.frameNumbers() as number[]; + const { frameCount, dataChunkSize } = job; const ranges = chunks.map((chunk) => ( [ - Math.max(startFrame, chunk * dataChunkSize), - Math.min(stopFrame, (chunk + 1) * dataChunkSize - 1), + includedFrames[chunk * dataChunkSize], + includedFrames[Math.min(frameCount - 1, (chunk + 1) * dataChunkSize - 1)], ] )).reduce>((acc, val) => { if (acc.length && acc[acc.length - 1][1] + 1 === val[0]) { @@ -905,7 +906,8 @@ export function getJobAsync({ // frame query parameter does not work for GT job const frameNumber = Number.isInteger(initialFrame) && gtJob?.id !== job.id ? - initialFrame as number : (await job.frames.search( + initialFrame as number : + (await job.frames.search( { notDeleted: !showDeletedFrames }, job.startFrame, job.stopFrame, )) || job.startFrame; diff --git a/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx b/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx index 2088d14d7ccf..f1a2e9cf2892 100644 --- a/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx +++ b/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx @@ -169,17 +169,14 @@ function PlayerNavigation(props: Props): JSX.Element { {!!ranges && ( {ranges.split(';').map((range) => { - const [start, end] = range.split(':').map((num) => +num); - const adjustedStart = Math.max(0, start - 1); - let totalSegments = stopFrame - startFrame; - if (totalSegments === 0) { - // corner case for jobs with one image - totalSegments = 1; - } + const [rangeStart, rangeStop] = range.split(':').map((num) => +num); + const totalSegments = stopFrame - startFrame + 1; const segmentWidth = 1000 / totalSegments; - const width = Math.max((end - adjustedStart), 1) * segmentWidth; - const offset = (Math.max((adjustedStart - startFrame), 0) / totalSegments) * 1000; - return (); + const width = (rangeStop - rangeStart + 1) * segmentWidth; + const offset = (Math.max((rangeStart - startFrame), 0) / totalSegments) * 1000; + return ( + + ); })} )} diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index 534f885449e1..eb8fdf26b52c 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -30,8 +30,8 @@ from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.dataset_manager.util import add_prefetch_fields -from cvat.apps.engine.frame_provider import FrameProvider -from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, DimensionType, Job, +from cvat.apps.engine.frame_provider import TaskFrameProvider, FrameQuality, FrameOutputType +from cvat.apps.engine.models import (AttributeSpec, AttributeType, DimensionType, Job, JobType, Label, LabelType, Project, SegmentType, ShapeType, Task) from cvat.apps.engine.rq_job_handler import RQJobMetaField @@ -301,7 +301,7 @@ def start(self) -> int: @property def stop(self) -> int: - return len(self) + return max(0, len(self) - 1) def _get_queryset(self): raise NotImplementedError() @@ -437,7 +437,7 @@ def _export_tag(self, tag): def _export_track(self, track, idx): track['shapes'] = list(filter(lambda x: not self._is_frame_deleted(x['frame']), track['shapes'])) tracked_shapes = TrackManager.get_interpolated_shapes( - track, 0, self.stop, self._annotation_ir.dimension) + track, 0, self.stop + 1, self._annotation_ir.dimension) for tracked_shape in tracked_shapes: tracked_shape["attributes"] += track["attributes"] tracked_shape["track_id"] = track["track_id"] if self._use_server_track_ids else idx @@ -493,7 +493,7 @@ def get_frame(idx): anno_manager = AnnotationManager(self._annotation_ir) for shape in sorted( - anno_manager.to_shapes(self.stop, self._annotation_ir.dimension, + anno_manager.to_shapes(self.stop + 1, self._annotation_ir.dimension, # Skip outside, deleted and excluded frames included_frames=included_frames, include_outside=False, @@ -840,7 +840,7 @@ def start(self) -> int: @property def stop(self) -> int: segment = self._db_job.segment - return segment.stop_frame + 1 + return segment.stop_frame @property def db_instance(self): @@ -1410,7 +1410,7 @@ def add_task(self, task, files): @attrs(frozen=True, auto_attribs=True) class ImageSource: - db_data: Data + db_task: Task is_video: bool = attrib(kw_only=True) class ImageProvider: @@ -1439,8 +1439,10 @@ def video_frame_loader(_): # optimization for videos: use numpy arrays instead of bytes # some formats or transforms can require image data return self._frame_provider.get_frame(frame_index, - quality=FrameProvider.Quality.ORIGINAL, - out_type=FrameProvider.Type.NUMPY_ARRAY)[0] + quality=FrameQuality.ORIGINAL, + out_type=FrameOutputType.NUMPY_ARRAY + ).data + return dm.Image(data=video_frame_loader, **image_kwargs) else: def image_loader(_): @@ -1448,8 +1450,10 @@ def image_loader(_): # for images use encoded data to avoid recoding return self._frame_provider.get_frame(frame_index, - quality=FrameProvider.Quality.ORIGINAL, - out_type=FrameProvider.Type.BUFFER)[0].getvalue() + quality=FrameQuality.ORIGINAL, + out_type=FrameOutputType.BUFFER + ).data.getvalue() + return dm.ByteImage(data=image_loader, **image_kwargs) def _load_source(self, source_id: int, source: ImageSource) -> None: @@ -1457,7 +1461,7 @@ def _load_source(self, source_id: int, source: ImageSource) -> None: return self._unload_source() - self._frame_provider = FrameProvider(source.db_data) + self._frame_provider = TaskFrameProvider(source.db_task) self._current_source_id = source_id def _unload_source(self) -> None: @@ -1473,7 +1477,7 @@ def __init__(self, sources: Dict[int, ImageSource]) -> None: self._images_per_source = { source_id: { image.id: image - for image in source.db_data.images.prefetch_related('related_files') + for image in source.db_task.data.images.prefetch_related('related_files') } for source_id, source in sources.items() } @@ -1482,7 +1486,7 @@ def get_image_for_frame(self, source_id: int, frame_id: int, **image_kwargs): source = self._sources[source_id] point_cloud_path = osp.join( - source.db_data.get_upload_dirname(), image_kwargs['path'], + source.db_task.data.get_upload_dirname(), image_kwargs['path'], ) image = self._images_per_source[source_id][frame_id] @@ -1595,11 +1599,18 @@ def __init__( is_video = instance_meta['mode'] == 'interpolation' ext = '' if is_video: - ext = FrameProvider.VIDEO_FRAME_EXT + ext = TaskFrameProvider.VIDEO_FRAME_EXT if dimension == DimensionType.DIM_3D or include_images: + if isinstance(instance_data, TaskData): + db_task = instance_data.db_instance + elif isinstance(instance_data, JobData): + db_task = instance_data.db_instance.segment.task + else: + assert False + self._image_provider = IMAGE_PROVIDERS_BY_DIMENSION[dimension]( - {0: ImageSource(instance_data.db_data, is_video=is_video)} + {0: ImageSource(db_task, is_video=is_video)} ) for frame_data in instance_data.group_by_frame(include_empty=True): @@ -1681,13 +1692,13 @@ def __init__( if self._dimension == DimensionType.DIM_3D or include_images: self._image_provider = IMAGE_PROVIDERS_BY_DIMENSION[self._dimension]( { - task.id: ImageSource(task.data, is_video=task.mode == 'interpolation') + task.id: ImageSource(task, is_video=task.mode == 'interpolation') for task in project_data.tasks } ) ext_per_task: Dict[int, str] = { - task.id: FrameProvider.VIDEO_FRAME_EXT if is_video else '' + task.id: TaskFrameProvider.VIDEO_FRAME_EXT if is_video else '' for task in project_data.tasks for is_video in [task.mode == 'interpolation'] } diff --git a/cvat/apps/dataset_manager/formats/cvat.py b/cvat/apps/dataset_manager/formats/cvat.py index 0191dfe1c8c4..4651fd398451 100644 --- a/cvat/apps/dataset_manager/formats/cvat.py +++ b/cvat/apps/dataset_manager/formats/cvat.py @@ -27,7 +27,7 @@ import_dm_annotations, match_dm_item) from cvat.apps.dataset_manager.util import make_zip_archive -from cvat.apps.engine.frame_provider import FrameProvider +from cvat.apps.engine.frame_provider import FrameQuality, FrameOutputType, make_frame_provider from .registry import dm_env, exporter, importer @@ -1371,16 +1371,19 @@ def dump_project_anno(dst_file: BufferedWriter, project_data: ProjectData, callb dumper.close_document() def dump_media_files(instance_data: CommonData, img_dir: str, project_data: ProjectData = None): + frame_provider = make_frame_provider(instance_data.db_instance) + ext = '' if instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation': - ext = FrameProvider.VIDEO_FRAME_EXT - - frame_provider = FrameProvider(instance_data.db_data) - frames = frame_provider.get_frames( - instance_data.start, instance_data.stop, - frame_provider.Quality.ORIGINAL, - frame_provider.Type.BUFFER) - for frame_id, (frame_data, _) in zip(instance_data.rel_range, frames): + ext = frame_provider.VIDEO_FRAME_EXT + + frames = frame_provider.iterate_frames( + start_frame=instance_data.start, + stop_frame=instance_data.stop, + quality=FrameQuality.ORIGINAL, + out_type=FrameOutputType.BUFFER, + ) + for frame_id, frame in zip(instance_data.rel_range, frames): if (project_data is not None and (instance_data.db_instance.id, frame_id) in project_data.deleted_frames) \ or frame_id in instance_data.deleted_frames: continue @@ -1389,7 +1392,7 @@ def dump_media_files(instance_data: CommonData, img_dir: str, project_data: Proj img_path = osp.join(img_dir, frame_name + ext) os.makedirs(osp.dirname(img_path), exist_ok=True) with open(img_path, 'wb') as f: - f.write(frame_data.getvalue()) + f.write(frame.data.getvalue()) def _export_task_or_job(dst_file, temp_dir, instance_data, anno_callback, save_images=False): with open(osp.join(temp_dir, 'annotations.xml'), 'wb') as f: diff --git a/cvat/apps/dataset_manager/tests/test_formats.py b/cvat/apps/dataset_manager/tests/test_formats.py index 6a03e41c8aac..42b2337304b4 100644 --- a/cvat/apps/dataset_manager/tests/test_formats.py +++ b/cvat/apps/dataset_manager/tests/test_formats.py @@ -1,6 +1,6 @@ # Copyright (C) 2020-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -14,10 +14,8 @@ from datumaro.components.dataset import Dataset, DatasetItem from datumaro.components.annotation import Mask from django.contrib.auth.models import Group, User -from PIL import Image from rest_framework import status -from rest_framework.test import APIClient, APITestCase import cvat.apps.dataset_manager as dm from cvat.apps.dataset_manager.annotation import AnnotationIR @@ -26,36 +24,13 @@ from cvat.apps.dataset_manager.task import TaskAnnotation from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.models import Task -from cvat.apps.engine.tests.utils import get_paginated_collection +from cvat.apps.engine.tests.utils import ( + get_paginated_collection, ForceLogin, generate_image_file, ApiTestBase +) - -def generate_image_file(filename, size=(100, 100)): - f = BytesIO() - image = Image.new('RGB', size=size) - image.save(f, 'jpeg') - f.name = filename - f.seek(0) - return f - -class ForceLogin: - def __init__(self, user, client): - self.user = user - self.client = client - - def __enter__(self): - if self.user: - self.client.force_login(self.user, - backend='django.contrib.auth.backends.ModelBackend') - - return self - - def __exit__(self, exception_type, exception_value, traceback): - if self.user: - self.client.logout() - -class _DbTestBase(APITestCase): +class _DbTestBase(ApiTestBase): def setUp(self): - self.client = APIClient() + super().setUp() @classmethod def setUpTestData(cls): @@ -94,6 +69,11 @@ def _create_task(self, data, image_data): response = self.client.post("/api/tasks/%s/data" % tid, data=image_data) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") response = self.client.get("/api/tasks/%s" % tid) diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py index de961318a5ab..a0717c1ef111 100644 --- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py +++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py @@ -171,6 +171,11 @@ def _create_task(self, data, image_data): response = self.client.post("/api/tasks/%s/data" % tid, data=image_data) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") response = self.client.get("/api/tasks/%s" % tid) @@ -412,7 +417,7 @@ def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self): url = self._generate_url_dump_tasks_annotations(task_id) for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') @@ -520,7 +525,7 @@ def test_api_v2_dump_annotations_with_objects_type_is_track(self): url = self._generate_url_dump_tasks_annotations(task_id) for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') @@ -605,7 +610,7 @@ def test_api_v2_dump_tag_annotations(self): for user, edata in list(expected.items()): with self.subTest(format=f"{edata['name']}"): with TestDir() as test_dir: - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] url = self._generate_url_dump_tasks_annotations(task_id) @@ -847,7 +852,7 @@ def test_api_v2_export_dataset(self): # dump annotations url = self._generate_url_dump_task_dataset(task_id) for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') @@ -2107,7 +2112,7 @@ def test_api_v2_export_import_dataset(self): self._create_annotations(task, dump_format_name, "random") for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') @@ -2170,7 +2175,7 @@ def test_api_v2_export_annotations(self): url = self._generate_url_dump_project_annotations(project['id'], dump_format_name) for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') diff --git a/cvat/apps/engine/apps.py b/cvat/apps/engine/apps.py index 326920e8b494..bcad84510f5d 100644 --- a/cvat/apps/engine/apps.py +++ b/cvat/apps/engine/apps.py @@ -10,6 +10,14 @@ class EngineConfig(AppConfig): name = 'cvat.apps.engine' def ready(self): + from django.conf import settings + + from . import default_settings + + for key in dir(default_settings): + if key.isupper() and not hasattr(settings, key): + setattr(settings, key, getattr(default_settings, key)) + # Required to define signals in application import cvat.apps.engine.signals # Required in order to silent "unused-import" in pyflake diff --git a/cvat/apps/engine/cache.py b/cvat/apps/engine/cache.py index 2603c2fd5a13..bc4c8616bd7f 100644 --- a/cvat/apps/engine/cache.py +++ b/cvat/apps/engine/cache.py @@ -1,349 +1,700 @@ # Copyright (C) 2020-2022 Intel Corporation -# Copyright (C) 2022-2023 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT +from __future__ import annotations + import io import os -import zipfile -from datetime import datetime, timezone -from io import BytesIO -import shutil +import os.path +import pickle # nosec import tempfile +import zipfile import zlib - -from typing import Optional, Tuple - +from contextlib import ExitStack, closing +from datetime import datetime, timezone +from itertools import groupby, pairwise +from typing import ( + Any, + Callable, + Collection, + Generator, + Iterator, + Optional, + Sequence, + Tuple, + Type, + Union, + overload, +) + +import av import cv2 import PIL.Image -import pickle # nosec -from django.conf import settings +import PIL.ImageOps from django.core.cache import caches from rest_framework.exceptions import NotFound, ValidationError -from cvat.apps.engine.cloud_provider import (Credentials, - db_storage_to_storage_instance, - get_cloud_storage_instance) +from cvat.apps.engine import models +from cvat.apps.engine.cloud_provider import ( + Credentials, + db_storage_to_storage_instance, + get_cloud_storage_instance, +) from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.media_extractors import (ImageDatasetManifestReader, - Mpeg4ChunkWriter, - Mpeg4CompressedChunkWriter, - VideoDatasetManifestReader, - ZipChunkWriter, - ZipCompressedChunkWriter) -from cvat.apps.engine.mime_types import mimetypes -from cvat.apps.engine.models import (DataChoice, DimensionType, Job, Image, - StorageChoice, CloudStorage) +from cvat.apps.engine.media_extractors import ( + FrameQuality, + IChunkWriter, + ImageReaderWithManifest, + Mpeg4ChunkWriter, + Mpeg4CompressedChunkWriter, + VideoReader, + VideoReaderWithManifest, + ZipChunkWriter, + ZipCompressedChunkWriter, +) from cvat.apps.engine.utils import md5_hash, preload_images from utils.dataset_manifest import ImageManifestManager slogger = ServerLogManager(__name__) + +DataWithMime = Tuple[io.BytesIO, str] +_CacheItem = Tuple[io.BytesIO, str, int] + + class MediaCache: - def __init__(self, dimension=DimensionType.DIM_2D): - self._dimension = dimension - self._cache = caches['media'] - - def _get_or_set_cache_item(self, key, create_function): - def create_item(): - slogger.glob.info(f'Starting to prepare chunk: key {key}') - item = create_function() - slogger.glob.info(f'Ending to prepare chunk: key {key}') - - if item[0]: - item = (item[0], item[1], zlib.crc32(item[0].getbuffer())) + def __init__(self) -> None: + self._cache = caches["media"] + + def _get_checksum(self, value: bytes) -> int: + return zlib.crc32(value) + + def _get_or_set_cache_item( + self, key: str, create_callback: Callable[[], DataWithMime] + ) -> _CacheItem: + def create_item() -> _CacheItem: + slogger.glob.info(f"Starting to prepare chunk: key {key}") + item_data = create_callback() + slogger.glob.info(f"Ending to prepare chunk: key {key}") + + item_data_bytes = item_data[0].getvalue() + item = (item_data[0], item_data[1], self._get_checksum(item_data_bytes)) + if item_data_bytes: self._cache.set(key, item) return item - slogger.glob.info(f'Starting to get chunk from cache: key {key}') - try: - item = self._cache.get(key) - except pickle.UnpicklingError: - slogger.glob.error(f'Unable to get item from cache: key {key}', exc_info=True) - item = None - slogger.glob.info(f'Ending to get chunk from cache: key {key}, is_cached {bool(item)}') - + item = self._get_cache_item(key) if not item: item = create_item() else: # compare checksum item_data = item[0].getbuffer() if isinstance(item[0], io.BytesIO) else item[0] item_checksum = item[2] if len(item) == 3 else None - if item_checksum != zlib.crc32(item_data): - slogger.glob.info(f'Recreating cache item {key} due to checksum mismatch') + if item_checksum != self._get_checksum(item_data): + slogger.glob.info(f"Recreating cache item {key} due to checksum mismatch") item = create_item() - return item[0], item[1] + return item - def get_task_chunk_data_with_mime(self, chunk_number, quality, db_data): - item = self._get_or_set_cache_item( - key=f'{db_data.id}_{chunk_number}_{quality}', - create_function=lambda: self._prepare_task_chunk(db_data, quality, chunk_number), - ) + def _get_cache_item(self, key: str) -> Optional[_CacheItem]: + slogger.glob.info(f"Starting to get chunk from cache: key {key}") + try: + item = self._cache.get(key) + except pickle.UnpicklingError: + slogger.glob.error(f"Unable to get item from cache: key {key}", exc_info=True) + item = None + slogger.glob.info(f"Ending to get chunk from cache: key {key}, is_cached {bool(item)}") return item - def get_selective_job_chunk_data_with_mime(self, chunk_number, quality, job): - item = self._get_or_set_cache_item( - key=f'job_{job.id}_{chunk_number}_{quality}', - create_function=lambda: self.prepare_selective_job_chunk(job, quality, chunk_number), - ) + def _has_key(self, key: str) -> bool: + return self._cache.has_key(key) + + def _make_cache_key_prefix( + self, obj: Union[models.Task, models.Segment, models.Job, models.CloudStorage] + ) -> str: + if isinstance(obj, models.Task): + return f"task_{obj.id}" + elif isinstance(obj, models.Segment): + return f"segment_{obj.id}" + elif isinstance(obj, models.Job): + return f"job_{obj.id}" + elif isinstance(obj, models.CloudStorage): + return f"cloudstorage_{obj.id}" + else: + assert False, f"Unexpected object type {type(obj)}" - return item + def _make_chunk_key( + self, + db_obj: Union[models.Task, models.Segment, models.Job], + chunk_number: int, + *, + quality: FrameQuality, + ) -> str: + return f"{self._make_cache_key_prefix(db_obj)}_chunk_{chunk_number}_{quality}" + + def _make_preview_key(self, db_obj: Union[models.Segment, models.CloudStorage]) -> str: + return f"{self._make_cache_key_prefix(db_obj)}_preview" - def get_local_preview_with_mime(self, frame_number, db_data): - item = self._get_or_set_cache_item( - key=f'data_{db_data.id}_{frame_number}_preview', - create_function=lambda: self._prepare_local_preview(frame_number, db_data), + def _make_segment_task_chunk_key( + self, + db_obj: models.Segment, + chunk_number: int, + *, + quality: FrameQuality, + ) -> str: + return f"{self._make_cache_key_prefix(db_obj)}_task_chunk_{chunk_number}_{quality}" + + def _make_context_image_preview_key(self, db_data: models.Data, frame_number: int) -> str: + return f"context_image_{db_data.id}_{frame_number}_preview" + + @overload + def _to_data_with_mime(self, cache_item: _CacheItem) -> DataWithMime: ... + + @overload + def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]: ... + + def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]: + if not cache_item: + return None + + return cache_item[:2] + + def get_or_set_segment_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_chunk_key(db_segment, chunk_number, quality=quality), + create_callback=lambda: self.prepare_segment_chunk( + db_segment, chunk_number, quality=quality + ), + ) ) - return item + def get_task_chunk( + self, db_task: models.Task, chunk_number: int, *, quality: FrameQuality + ) -> Optional[DataWithMime]: + return self._to_data_with_mime( + self._get_cache_item(key=self._make_chunk_key(db_task, chunk_number, quality=quality)) + ) - def get_cloud_preview_with_mime( + def get_or_set_task_chunk( self, - db_storage: CloudStorage, - ) -> Optional[Tuple[io.BytesIO, str]]: - key = f'cloudstorage_{db_storage.id}_preview' - return self._cache.get(key) + db_task: models.Task, + chunk_number: int, + *, + quality: FrameQuality, + set_callback: Callable[[], DataWithMime], + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_chunk_key(db_task, chunk_number, quality=quality), + create_callback=set_callback, + ) + ) - def get_or_set_cloud_preview_with_mime( - self, - db_storage: CloudStorage, - ) -> Tuple[io.BytesIO, str]: - key = f'cloudstorage_{db_storage.id}_preview' + def get_segment_task_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> Optional[DataWithMime]: + return self._to_data_with_mime( + self._get_cache_item( + key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality) + ) + ) - item = self._get_or_set_cache_item( - key, create_function=lambda: self._prepare_cloud_preview(db_storage) + def get_or_set_segment_task_chunk( + self, + db_segment: models.Segment, + chunk_number: int, + *, + quality: FrameQuality, + set_callback: Callable[[], DataWithMime], + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality), + create_callback=set_callback, + ) ) - return item + def get_or_set_selective_job_chunk( + self, db_job: models.Job, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_chunk_key(db_job, chunk_number, quality=quality), + create_callback=lambda: self.prepare_masked_range_segment_chunk( + db_job.segment, chunk_number, quality=quality + ), + ) + ) - def get_frame_context_images(self, db_data, frame_number): - item = self._get_or_set_cache_item( - key=f'context_image_{db_data.id}_{frame_number}', - create_function=lambda: self._prepare_context_image(db_data, frame_number) + def get_or_set_segment_preview(self, db_segment: models.Segment) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + self._make_preview_key(db_segment), + create_callback=lambda: self._prepare_segment_preview(db_segment), + ) ) - return item + def get_cloud_preview(self, db_storage: models.CloudStorage) -> Optional[DataWithMime]: + return self._to_data_with_mime(self._get_cache_item(self._make_preview_key(db_storage))) - @staticmethod - def _get_frame_provider_class(): - from cvat.apps.engine.frame_provider import \ - FrameProvider # TODO: remove circular dependency - return FrameProvider - - from contextlib import contextmanager - - @staticmethod - @contextmanager - def _get_images(db_data, chunk_number, dimension): - images = [] - tmp_dir = None - upload_dir = { - StorageChoice.LOCAL: db_data.get_upload_dirname(), - StorageChoice.SHARE: settings.SHARE_ROOT, - StorageChoice.CLOUD_STORAGE: db_data.get_upload_dirname(), - }[db_data.storage] + def get_or_set_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + self._make_preview_key(db_storage), + create_callback=lambda: self._prepare_cloud_preview(db_storage), + ) + ) - try: - if hasattr(db_data, 'video'): - source_path = os.path.join(upload_dir, db_data.video.path) - - reader = VideoDatasetManifestReader(manifest_path=db_data.get_manifest_path(), - source_path=source_path, chunk_number=chunk_number, - chunk_size=db_data.chunk_size, start=db_data.start_frame, - stop=db_data.stop_frame, step=db_data.get_frame_step()) - for frame in reader: - images.append((frame, source_path, None)) - else: - reader = ImageDatasetManifestReader(manifest_path=db_data.get_manifest_path(), - chunk_number=chunk_number, chunk_size=db_data.chunk_size, - start=db_data.start_frame, stop=db_data.stop_frame, - step=db_data.get_frame_step()) - if db_data.storage == StorageChoice.CLOUD_STORAGE: - db_cloud_storage = db_data.cloud_storage - assert db_cloud_storage, 'Cloud storage instance was deleted' - credentials = Credentials() - credentials.convert_from_db({ - 'type': db_cloud_storage.credentials_type, - 'value': db_cloud_storage.credentials, - }) - details = { - 'resource': db_cloud_storage.resource, - 'credentials': credentials, - 'specific_attributes': db_cloud_storage.get_specific_attributes() + def get_or_set_frame_context_images_chunk( + self, db_data: models.Data, frame_number: int + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_context_image_preview_key(db_data, frame_number), + create_callback=lambda: self.prepare_context_images_chunk(db_data, frame_number), + ) + ) + + def _read_raw_images( + self, + db_task: models.Task, + frame_ids: Sequence[int], + *, + manifest_path: str, + ): + db_data = db_task.data + + if os.path.isfile(manifest_path) and db_data.storage == models.StorageChoice.CLOUD_STORAGE: + reader = ImageReaderWithManifest(manifest_path) + with ExitStack() as es: + db_cloud_storage = db_data.cloud_storage + assert db_cloud_storage, "Cloud storage instance was deleted" + credentials = Credentials() + credentials.convert_from_db( + { + "type": db_cloud_storage.credentials_type, + "value": db_cloud_storage.credentials, } - cloud_storage_instance = get_cloud_storage_instance(cloud_provider=db_cloud_storage.provider_type, **details) + ) + details = { + "resource": db_cloud_storage.resource, + "credentials": credentials, + "specific_attributes": db_cloud_storage.get_specific_attributes(), + } + cloud_storage_instance = get_cloud_storage_instance( + cloud_provider=db_cloud_storage.provider_type, **details + ) + + tmp_dir = es.enter_context(tempfile.TemporaryDirectory(prefix="cvat")) + files_to_download = [] + checksums = [] + media = [] + for item in reader.iterate_frames(frame_ids): + file_name = f"{item['name']}{item['extension']}" + fs_filename = os.path.join(tmp_dir, file_name) + + files_to_download.append(file_name) + checksums.append(item.get("checksum", None)) + media.append((fs_filename, fs_filename, None)) + + cloud_storage_instance.bulk_download_to_dir( + files=files_to_download, upload_dir=tmp_dir + ) + media = preload_images(media) + + for checksum, (_, fs_filename, _) in zip(checksums, media): + if checksum and not md5_hash(fs_filename) == checksum: + slogger.cloud_storage[db_cloud_storage.id].warning( + "Hash sums of files {} do not match".format(file_name) + ) + + yield from media + else: + requested_frame_iter = iter(frame_ids) + next_requested_frame_id = next(requested_frame_iter, None) + if next_requested_frame_id is None: + return + + # TODO: find a way to use prefetched results, if provided + db_images = ( + db_data.images.order_by("frame") + .filter(frame__gte=frame_ids[0], frame__lte=frame_ids[-1]) + .values_list("frame", "path") + .all() + ) - tmp_dir = tempfile.mkdtemp(prefix='cvat') - files_to_download = [] - checksums = [] - for item in reader: - file_name = f"{item['name']}{item['extension']}" - fs_filename = os.path.join(tmp_dir, file_name) + raw_data_dir = db_data.get_raw_data_dirname() + media = [] + for frame_id, frame_path in db_images: + if frame_id == next_requested_frame_id: + source_path = os.path.join(raw_data_dir, frame_path) + media.append((source_path, source_path, None)) - files_to_download.append(file_name) - checksums.append(item.get('checksum', None)) - images.append((fs_filename, fs_filename, None)) + next_requested_frame_id = next(requested_frame_iter, None) - cloud_storage_instance.bulk_download_to_dir(files=files_to_download, upload_dir=tmp_dir) - images = preload_images(images) + if next_requested_frame_id is None: + break - for checksum, (_, fs_filename, _) in zip(checksums, images): - if checksum and not md5_hash(fs_filename) == checksum: - slogger.cloud_storage[db_cloud_storage.id].warning('Hash sums of files {} do not match'.format(file_name)) - else: - for item in reader: - source_path = os.path.join(upload_dir, f"{item['name']}{item['extension']}") - images.append((source_path, source_path, None)) - if dimension == DimensionType.DIM_2D: - images = preload_images(images) - - yield images - finally: - if db_data.storage == StorageChoice.CLOUD_STORAGE and tmp_dir is not None: - shutil.rmtree(tmp_dir) - - def _prepare_task_chunk(self, db_data, quality, chunk_number): - FrameProvider = self._get_frame_provider_class() - - writer_classes = { - FrameProvider.Quality.COMPRESSED : Mpeg4CompressedChunkWriter if db_data.compressed_chunk_type == DataChoice.VIDEO else ZipCompressedChunkWriter, - FrameProvider.Quality.ORIGINAL : Mpeg4ChunkWriter if db_data.original_chunk_type == DataChoice.VIDEO else ZipChunkWriter, - } - - image_quality = 100 if writer_classes[quality] in [Mpeg4ChunkWriter, ZipChunkWriter] else db_data.image_quality - mime_type = 'video/mp4' if writer_classes[quality] in [Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter] else 'application/zip' - - kwargs = {} - if self._dimension == DimensionType.DIM_3D: - kwargs["dimension"] = DimensionType.DIM_3D - writer = writer_classes[quality](image_quality, **kwargs) - - buff = BytesIO() - with self._get_images(db_data, chunk_number, self._dimension) as images: - writer.save_as_chunk(images, buff) - buff.seek(0) + assert next_requested_frame_id is None - return buff, mime_type + if db_task.dimension == models.DimensionType.DIM_2D: + media = preload_images(media) - def prepare_selective_job_chunk(self, db_job: Job, quality, chunk_number: int): - db_data = db_job.segment.task.data + yield from media - FrameProvider = self._get_frame_provider_class() - frame_provider = FrameProvider(db_data, self._dimension) + def _read_raw_frames( + self, db_task: models.Task, frame_ids: Sequence[int] + ) -> Generator[Tuple[Union[av.VideoFrame, PIL.Image.Image], str, str], None, None]: + for prev_frame, cur_frame in pairwise(frame_ids): + assert ( + prev_frame <= cur_frame + ), f"Requested frame ids must be sorted, got a ({prev_frame}, {cur_frame}) pair" - frame_set = db_job.segment.frame_set - frame_step = db_data.get_frame_step() - chunk_frames = [] + db_data = db_task.data - writer = ZipCompressedChunkWriter(db_data.image_quality, dimension=self._dimension) - dummy_frame = BytesIO() - PIL.Image.new('RGB', (1, 1)).save(dummy_frame, writer.IMAGE_EXT) + manifest_path = db_data.get_manifest_path() - if hasattr(db_data, 'video'): - frame_size = (db_data.video.width, db_data.video.height) - else: - frame_size = None + if hasattr(db_data, "video"): + source_path = os.path.join(db_data.get_raw_data_dirname(), db_data.video.path) - for frame_idx in range(db_data.chunk_size): - frame_idx = ( - db_data.start_frame + chunk_number * db_data.chunk_size + frame_idx * frame_step + reader = VideoReaderWithManifest( + manifest_path=manifest_path, + source_path=source_path, + allow_threading=False, ) - if db_data.stop_frame < frame_idx: - break - - frame_bytes = None - - if frame_idx in frame_set: - frame_bytes = frame_provider.get_frame(frame_idx, quality=quality)[0] + if not os.path.isfile(manifest_path): + try: + reader.manifest.link(source_path, force=True) + reader.manifest.create() + except Exception as e: + slogger.task[db_task.id].warning( + f"Failed to create video manifest: {e}", exc_info=True + ) + reader = None + + if reader: + for frame in reader.iterate_frames(frame_filter=frame_ids): + yield (frame, source_path, None) + else: + reader = VideoReader([source_path], allow_threading=False) - if frame_size is not None: - # Decoded video frames can have different size, restore the original one + for frame_tuple in reader.iterate_frames(frame_filter=frame_ids): + yield frame_tuple + else: + yield from self._read_raw_images(db_task, frame_ids, manifest_path=manifest_path) + + def prepare_segment_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + if db_segment.type == models.SegmentType.RANGE: + return self.prepare_range_segment_chunk(db_segment, chunk_number, quality=quality) + elif db_segment.type == models.SegmentType.SPECIFIC_FRAMES: + return self.prepare_masked_range_segment_chunk( + db_segment, chunk_number, quality=quality + ) + else: + assert False, f"Unknown segment type {db_segment.type}" + + def prepare_range_segment_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + db_task = db_segment.task + db_data = db_task.data + + chunk_size = db_data.chunk_size + chunk_frame_ids = list(db_segment.frame_set)[ + chunk_size * chunk_number : chunk_size * (chunk_number + 1) + ] + + return self.prepare_custom_range_segment_chunk(db_task, chunk_frame_ids, quality=quality) + + def prepare_custom_range_segment_chunk( + self, db_task: models.Task, frame_ids: Sequence[int], *, quality: FrameQuality + ) -> DataWithMime: + with closing(self._read_raw_frames(db_task, frame_ids=frame_ids)) as frame_iter: + return prepare_chunk(frame_iter, quality=quality, db_task=db_task) + + def prepare_masked_range_segment_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + db_task = db_segment.task + db_data = db_task.data + + chunk_size = db_data.chunk_size + chunk_frame_ids = sorted(db_segment.frame_set)[ + chunk_size * chunk_number : chunk_size * (chunk_number + 1) + ] + + return self.prepare_custom_masked_range_segment_chunk( + db_task, chunk_frame_ids, chunk_number, quality=quality + ) - frame = PIL.Image.open(frame_bytes) - if frame.size != frame_size: - frame = frame.resize(frame_size) + def prepare_custom_masked_range_segment_chunk( + self, + db_task: models.Task, + frame_ids: Collection[int], + chunk_number: int, + *, + quality: FrameQuality, + insert_placeholders: bool = False, + ) -> DataWithMime: + db_data = db_task.data - frame_bytes = BytesIO() - frame.save(frame_bytes, writer.IMAGE_EXT) - frame_bytes.seek(0) + frame_step = db_data.get_frame_step() - else: - # Populate skipped frames with placeholder data, - # this is required for video chunk decoding implementation in UI - frame_bytes = BytesIO(dummy_frame.getvalue()) + image_quality = 100 if quality == FrameQuality.ORIGINAL else db_data.image_quality + writer = ZipCompressedChunkWriter(image_quality, dimension=db_task.dimension) + + dummy_frame = io.BytesIO() + PIL.Image.new("RGB", (1, 1)).save(dummy_frame, writer.IMAGE_EXT) + + # Optimize frame access if all the required frames are already cached + # Otherwise we might need to download files. + # This is not needed for video tasks, as it will reduce performance + from cvat.apps.engine.frame_provider import FrameOutputType, TaskFrameProvider + + task_frame_provider = TaskFrameProvider(db_task) + + use_cached_data = False + if db_task.mode != "interpolation": + required_frame_set = set(frame_ids) + available_chunks = [ + self._has_key(self._make_chunk_key(db_segment, chunk_number, quality=quality)) + for db_segment in db_task.segment_set.filter(type=models.SegmentType.RANGE).all() + for chunk_number, _ in groupby( + sorted(required_frame_set.intersection(db_segment.frame_set)), + key=lambda frame: frame // db_data.chunk_size, + ) + ] + use_cached_data = bool(available_chunks) and all(available_chunks) + + if hasattr(db_data, "video"): + frame_size = (db_data.video.width, db_data.video.height) + else: + frame_size = None - if frame_bytes is not None: - chunk_frames.append((frame_bytes, None, None)) + def get_frames(): + with ExitStack() as es: + es.callback(task_frame_provider.unload) + + if insert_placeholders: + frame_range = ( + ( + db_data.start_frame + + chunk_number * db_data.chunk_size + + chunk_frame_idx * frame_step + ) + for chunk_frame_idx in range(db_data.chunk_size) + ) + else: + frame_range = frame_ids + + if not use_cached_data: + frames_gen = self._read_raw_frames(db_task, frame_ids) + frames_iter = iter(es.enter_context(closing(frames_gen))) + + for abs_frame_idx in frame_range: + if db_data.stop_frame < abs_frame_idx: + break + + if abs_frame_idx in frame_ids: + if use_cached_data: + frame_data = task_frame_provider.get_frame( + task_frame_provider.get_rel_frame_number(abs_frame_idx), + quality=quality, + out_type=FrameOutputType.BUFFER, + ) + frame = frame_data.data + else: + frame, _, _ = next(frames_iter) + + if hasattr(db_data, "video"): + # Decoded video frames can have different size, restore the original one + + if isinstance(frame, av.VideoFrame): + frame = frame.to_image() + else: + frame = PIL.Image.open(frame) + + if frame.size != frame_size: + frame = frame.resize(frame_size) + else: + # Populate skipped frames with placeholder data, + # this is required for video chunk decoding implementation in UI + frame = io.BytesIO(dummy_frame.getvalue()) + + yield (frame, None, None) + + buff = io.BytesIO() + with closing(get_frames()) as frame_iter: + writer.save_as_chunk( + frame_iter, + buff, + zip_compress_level=1, + # there are likely to be many skips with repeated placeholder frames + # in SPECIFIC_FRAMES segments, it makes sense to compress the archive + ) - buff = BytesIO() - writer.save_as_chunk(chunk_frames, buff, compress_frames=False, - zip_compress_level=1 # these are likely to be many skips in SPECIFIC_FRAMES segments - ) buff.seek(0) + return buff, get_chunk_mime_type_for_writer(writer) - return buff, 'application/zip' + def _prepare_segment_preview(self, db_segment: models.Segment) -> DataWithMime: + if db_segment.task.dimension == models.DimensionType.DIM_3D: + # TODO + preview = PIL.Image.open( + os.path.join(os.path.dirname(__file__), "assets/3d_preview.jpeg") + ) + else: + from cvat.apps.engine.frame_provider import ( # avoid circular import + FrameOutputType, + make_frame_provider, + ) - def _prepare_local_preview(self, frame_number, db_data): - FrameProvider = self._get_frame_provider_class() - frame_provider = FrameProvider(db_data, self._dimension) - buff, mime_type = frame_provider.get_preview(frame_number) + task_frame_provider = make_frame_provider(db_segment.task) + segment_frame_provider = make_frame_provider(db_segment) + preview = segment_frame_provider.get_frame( + task_frame_provider.get_rel_frame_number(min(db_segment.frame_set)), + quality=FrameQuality.COMPRESSED, + out_type=FrameOutputType.PIL, + ).data - return buff, mime_type + return prepare_preview_image(preview) - def _prepare_cloud_preview(self, db_storage): + def _prepare_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime: storage = db_storage_to_storage_instance(db_storage) if not db_storage.manifests.count(): - raise ValidationError('Cannot get the cloud storage preview. There is no manifest file') + raise ValidationError("Cannot get the cloud storage preview. There is no manifest file") + preview_path = None - for manifest_model in db_storage.manifests.all(): - manifest_prefix = os.path.dirname(manifest_model.filename) - full_manifest_path = os.path.join(db_storage.get_storage_dirname(), manifest_model.filename) - if not os.path.exists(full_manifest_path) or \ - datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) < storage.get_file_last_modified(manifest_model.filename): - storage.download_file(manifest_model.filename, full_manifest_path) + for db_manifest in db_storage.manifests.all(): + manifest_prefix = os.path.dirname(db_manifest.filename) + + full_manifest_path = os.path.join( + db_storage.get_storage_dirname(), db_manifest.filename + ) + if not os.path.exists(full_manifest_path) or datetime.fromtimestamp( + os.path.getmtime(full_manifest_path), tz=timezone.utc + ) < storage.get_file_last_modified(db_manifest.filename): + storage.download_file(db_manifest.filename, full_manifest_path) + manifest = ImageManifestManager( - os.path.join(db_storage.get_storage_dirname(), manifest_model.filename), - db_storage.get_storage_dirname() + os.path.join(db_storage.get_storage_dirname(), db_manifest.filename), + db_storage.get_storage_dirname(), ) # need to update index manifest.set_index() if not len(manifest): continue + preview_info = manifest[0] - preview_filename = ''.join([preview_info['name'], preview_info['extension']]) + preview_filename = "".join([preview_info["name"], preview_info["extension"]]) preview_path = os.path.join(manifest_prefix, preview_filename) break + if not preview_path: - msg = 'Cloud storage {} does not contain any images'.format(db_storage.pk) + msg = "Cloud storage {} does not contain any images".format(db_storage.pk) slogger.cloud_storage[db_storage.pk].info(msg) raise NotFound(msg) buff = storage.download_fileobj(preview_path) - mime_type = mimetypes.guess_type(preview_path)[0] + image = PIL.Image.open(buff) + return prepare_preview_image(image) - return buff, mime_type + def prepare_context_images_chunk(self, db_data: models.Data, frame_number: int) -> DataWithMime: + zip_buffer = io.BytesIO() - def _prepare_context_image(self, db_data, frame_number): - zip_buffer = BytesIO() - try: - image = Image.objects.get(data_id=db_data.id, frame=frame_number) - except Image.DoesNotExist: - return None, None - with zipfile.ZipFile(zip_buffer, 'a', zipfile.ZIP_DEFLATED, False) as zip_file: - if not image.related_files.count(): - return None, None - common_path = os.path.commonpath(list(map(lambda x: str(x.path), image.related_files.all()))) - for i in image.related_files.all(): + related_images = db_data.related_files.filter(primary_image__frame=frame_number).all() + if not related_images: + return zip_buffer, "" + + with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file: + common_path = os.path.commonpath(list(map(lambda x: str(x.path), related_images))) + for i in related_images: path = os.path.realpath(str(i.path)) name = os.path.relpath(str(i.path), common_path) image = cv2.imread(path) - success, result = cv2.imencode('.JPEG', image) + success, result = cv2.imencode(".JPEG", image) if not success: raise Exception('Failed to encode image to ".jpeg" format') - zip_file.writestr(f'{name}.jpg', result.tobytes()) - mime_type = 'application/zip' + zip_file.writestr(f"{name}.jpg", result.tobytes()) + zip_buffer.seek(0) + mime_type = "application/zip" return zip_buffer, mime_type + + +def prepare_preview_image(image: PIL.Image.Image) -> DataWithMime: + PREVIEW_SIZE = (256, 256) + PREVIEW_MIME = "image/jpeg" + + image = PIL.ImageOps.exif_transpose(image) + image.thumbnail(PREVIEW_SIZE) + + output_buf = io.BytesIO() + image.convert("RGB").save(output_buf, format="JPEG") + return output_buf, PREVIEW_MIME + + +def prepare_chunk( + task_chunk_frames: Iterator[Tuple[Any, str, int]], + *, + quality: FrameQuality, + db_task: models.Task, + dump_unchanged: bool = False, +) -> DataWithMime: + # TODO: refactor all chunk building into another class + + db_data = db_task.data + + writer_classes: dict[FrameQuality, Type[IChunkWriter]] = { + FrameQuality.COMPRESSED: ( + Mpeg4CompressedChunkWriter + if db_data.compressed_chunk_type == models.DataChoice.VIDEO + else ZipCompressedChunkWriter + ), + FrameQuality.ORIGINAL: ( + Mpeg4ChunkWriter + if db_data.original_chunk_type == models.DataChoice.VIDEO + else ZipChunkWriter + ), + } + + writer_class = writer_classes[quality] + + image_quality = 100 if quality == FrameQuality.ORIGINAL else db_data.image_quality + + writer_kwargs = {} + if db_task.dimension == models.DimensionType.DIM_3D: + writer_kwargs["dimension"] = models.DimensionType.DIM_3D + merged_chunk_writer = writer_class(image_quality, **writer_kwargs) + + writer_kwargs = {} + if dump_unchanged and isinstance(merged_chunk_writer, ZipCompressedChunkWriter): + writer_kwargs = dict(compress_frames=False, zip_compress_level=1) + + buffer = io.BytesIO() + merged_chunk_writer.save_as_chunk(task_chunk_frames, buffer, **writer_kwargs) + + buffer.seek(0) + return buffer, get_chunk_mime_type_for_writer(writer_class) + + +def get_chunk_mime_type_for_writer(writer: Union[IChunkWriter, Type[IChunkWriter]]) -> str: + if isinstance(writer, IChunkWriter): + writer_class = type(writer) + else: + writer_class = writer + + if issubclass(writer_class, ZipChunkWriter): + return "application/zip" + elif issubclass(writer_class, Mpeg4ChunkWriter): + return "video/mp4" + else: + assert False, f"Unknown chunk writer class {writer_class}" diff --git a/cvat/apps/engine/default_settings.py b/cvat/apps/engine/default_settings.py new file mode 100644 index 000000000000..826fe1c9bef2 --- /dev/null +++ b/cvat/apps/engine/default_settings.py @@ -0,0 +1,16 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import os + +from attrs.converters import to_bool + +MEDIA_CACHE_ALLOW_STATIC_CACHE = to_bool(os.getenv("CVAT_ALLOW_STATIC_CACHE", False)) +""" +Allow or disallow static media cache. +If disabled, CVAT will only use the dynamic media cache. New tasks requesting static media cache +will be automatically switched to the dynamic cache. +When enabled, this option can increase data access speed and reduce server load, +but significantly increase disk space occupied by tasks. +""" diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py index 4e2f42ef7933..ea14b40a75ad 100644 --- a/cvat/apps/engine/frame_provider.py +++ b/cvat/apps/engine/frame_provider.py @@ -3,226 +3,693 @@ # # SPDX-License-Identifier: MIT +from __future__ import annotations + +import io +import itertools import math -from enum import Enum +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from enum import Enum, auto from io import BytesIO -import os - +from typing import ( + Any, + Callable, + Generic, + Iterator, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import av import cv2 import numpy as np -from PIL import Image, ImageOps +from datumaro.util import take_by +from django.conf import settings +from PIL import Image +from rest_framework.exceptions import ValidationError -from cvat.apps.engine.cache import MediaCache -from cvat.apps.engine.media_extractors import VideoReader, ZipReader +from cvat.apps.engine import models +from cvat.apps.engine.cache import DataWithMime, MediaCache, prepare_chunk +from cvat.apps.engine.media_extractors import ( + FrameQuality, + IMediaReader, + RandomAccessIterator, + VideoReader, + ZipReader, +) from cvat.apps.engine.mime_types import mimetypes -from cvat.apps.engine.models import DataChoice, StorageMethodChoice, DimensionType -from rest_framework.exceptions import ValidationError -class RandomAccessIterator: - def __init__(self, iterable): - self.iterable = iterable - self.iterator = None - self.pos = -1 - - def __iter__(self): - return self - - def __next__(self): - return self[self.pos + 1] - - def __getitem__(self, idx): - assert 0 <= idx - if self.iterator is None or idx <= self.pos: - self.reset() - v = None - while self.pos < idx: - # NOTE: don't keep the last item in self, it can be expensive - v = next(self.iterator) - self.pos += 1 - return v - - def reset(self): - self.close() - self.iterator = iter(self.iterable) - - def close(self): - if self.iterator is not None: - if close := getattr(self.iterator, 'close', None): - close() - self.iterator = None - self.pos = -1 - -class FrameProvider: - VIDEO_FRAME_EXT = '.PNG' - VIDEO_FRAME_MIME = 'image/png' - - class Quality(Enum): - COMPRESSED = 0 - ORIGINAL = 100 - - class Type(Enum): - BUFFER = 0 - PIL = 1 - NUMPY_ARRAY = 2 - - class ChunkLoader: - def __init__(self, reader_class, path_getter): - self.chunk_id = None +_T = TypeVar("_T") + + +class _ChunkLoader(metaclass=ABCMeta): + def __init__( + self, + reader_class: Type[IMediaReader], + *, + reader_params: Optional[dict] = None, + ) -> None: + self.chunk_id: Optional[int] = None + self.chunk_reader: Optional[RandomAccessIterator] = None + self.reader_class = reader_class + self.reader_params = reader_params + + def load(self, chunk_id: int) -> RandomAccessIterator[Tuple[Any, str, int]]: + if self.chunk_id != chunk_id: + self.unload() + + self.chunk_id = chunk_id + self.chunk_reader = RandomAccessIterator( + self.reader_class( + [self.read_chunk(chunk_id)[0]], + **(self.reader_params or {}), + ) + ) + return self.chunk_reader + + def unload(self): + self.chunk_id = None + if self.chunk_reader: + self.chunk_reader.close() self.chunk_reader = None - self.reader_class = reader_class - self.get_chunk_path = path_getter - - def load(self, chunk_id): - if self.chunk_id != chunk_id: - self.unload() - - self.chunk_id = chunk_id - self.chunk_reader = RandomAccessIterator( - self.reader_class([self.get_chunk_path(chunk_id)])) - return self.chunk_reader - - def unload(self): - self.chunk_id = None - if self.chunk_reader: - self.chunk_reader.close() - self.chunk_reader = None - - class BuffChunkLoader(ChunkLoader): - def __init__(self, reader_class, path_getter, quality, db_data): - super().__init__(reader_class, path_getter) - self.quality = quality - self.db_data = db_data - - def load(self, chunk_id): - if self.chunk_id != chunk_id: - self.chunk_id = chunk_id - self.chunk_reader = RandomAccessIterator( - self.reader_class([self.get_chunk_path(chunk_id, self.quality, self.db_data)[0]])) - return self.chunk_reader - - def __init__(self, db_data, dimension=DimensionType.DIM_2D): - self._db_data = db_data - self._dimension = dimension - self._loaders = {} - - reader_class = { - DataChoice.IMAGESET: ZipReader, - DataChoice.VIDEO: VideoReader, - } - if db_data.storage_method == StorageMethodChoice.CACHE: - cache = MediaCache(dimension=dimension) - - self._loaders[self.Quality.COMPRESSED] = self.BuffChunkLoader( - reader_class[db_data.compressed_chunk_type], - cache.get_task_chunk_data_with_mime, - self.Quality.COMPRESSED, - self._db_data) - self._loaders[self.Quality.ORIGINAL] = self.BuffChunkLoader( - reader_class[db_data.original_chunk_type], - cache.get_task_chunk_data_with_mime, - self.Quality.ORIGINAL, - self._db_data) - else: - self._loaders[self.Quality.COMPRESSED] = self.ChunkLoader( - reader_class[db_data.compressed_chunk_type], - db_data.get_compressed_chunk_path) - self._loaders[self.Quality.ORIGINAL] = self.ChunkLoader( - reader_class[db_data.original_chunk_type], - db_data.get_original_chunk_path) + @abstractmethod + def read_chunk(self, chunk_id: int) -> DataWithMime: ... - def __len__(self): - return self._db_data.size - def unload(self): - for loader in self._loaders.values(): - loader.unload() +class _FileChunkLoader(_ChunkLoader): + def __init__( + self, + reader_class: Type[IMediaReader], + get_chunk_path_callback: Callable[[int], str], + *, + reader_params: Optional[dict] = None, + ) -> None: + super().__init__(reader_class, reader_params=reader_params) + self.get_chunk_path = get_chunk_path_callback + + def read_chunk(self, chunk_id: int) -> DataWithMime: + chunk_path = self.get_chunk_path(chunk_id) + with open(chunk_path, "rb") as f: + return ( + io.BytesIO(f.read()), + mimetypes.guess_type(chunk_path)[0], + ) + + +class _BufferChunkLoader(_ChunkLoader): + def __init__( + self, + reader_class: Type[IMediaReader], + get_chunk_callback: Callable[[int], DataWithMime], + *, + reader_params: Optional[dict] = None, + ) -> None: + super().__init__(reader_class, reader_params=reader_params) + self.get_chunk = get_chunk_callback + + def read_chunk(self, chunk_id: int) -> DataWithMime: + return self.get_chunk(chunk_id) + - def _validate_frame_number(self, frame_number): - frame_number_ = int(frame_number) - if frame_number_ < 0 or frame_number_ >= self._db_data.size: - raise ValidationError('Incorrect requested frame number: {}'.format(frame_number_)) +class FrameOutputType(Enum): + BUFFER = auto() + PIL = auto() + NUMPY_ARRAY = auto() - chunk_number = frame_number_ // self._db_data.chunk_size - frame_offset = frame_number_ % self._db_data.chunk_size - return frame_number_, chunk_number, frame_offset +Frame2d = Union[BytesIO, np.ndarray, Image.Image] +Frame3d = BytesIO +AnyFrame = Union[Frame2d, Frame3d] - def get_chunk_number(self, frame_number): - return int(frame_number) // self._db_data.chunk_size - def _validate_chunk_number(self, chunk_number): - chunk_number_ = int(chunk_number) - if chunk_number_ < 0 or chunk_number_ >= math.ceil(self._db_data.size / self._db_data.chunk_size): - raise ValidationError('requested chunk does not exist') +@dataclass +class DataWithMeta(Generic[_T]): + data: _T + mime: str - return chunk_number_ + +class IFrameProvider(metaclass=ABCMeta): + VIDEO_FRAME_EXT = ".PNG" + VIDEO_FRAME_MIME = "image/png" + + def unload(self): + pass @classmethod - def _av_frame_to_png_bytes(cls, av_frame): + def _av_frame_to_png_bytes(cls, av_frame: av.VideoFrame) -> BytesIO: ext = cls.VIDEO_FRAME_EXT - image = av_frame.to_ndarray(format='bgr24') + image = av_frame.to_ndarray(format="bgr24") success, result = cv2.imencode(ext, image) if not success: - raise RuntimeError("Failed to encode image to '%s' format" % (ext)) + raise RuntimeError(f"Failed to encode image to '{ext}' format") return BytesIO(result.tobytes()) - def _convert_frame(self, frame, reader_class, out_type): - if out_type == self.Type.BUFFER: - return self._av_frame_to_png_bytes(frame) if reader_class is VideoReader else frame - elif out_type == self.Type.PIL: - return frame.to_image() if reader_class is VideoReader else Image.open(frame) - elif out_type == self.Type.NUMPY_ARRAY: - if reader_class is VideoReader: - image = frame.to_ndarray(format='bgr24') + def _convert_frame( + self, frame: Any, reader_class: Type[IMediaReader], out_type: FrameOutputType + ) -> AnyFrame: + if out_type == FrameOutputType.BUFFER: + return ( + self._av_frame_to_png_bytes(frame) + if issubclass(reader_class, VideoReader) + else frame + ) + elif out_type == FrameOutputType.PIL: + return frame.to_image() if issubclass(reader_class, VideoReader) else Image.open(frame) + elif out_type == FrameOutputType.NUMPY_ARRAY: + if issubclass(reader_class, VideoReader): + image = frame.to_ndarray(format="bgr24") else: image = np.array(Image.open(frame)) if len(image.shape) == 3 and image.shape[2] in {3, 4}: - image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR + image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR return image else: - raise RuntimeError('unsupported output type') + raise RuntimeError("unsupported output type") + + @abstractmethod + def validate_frame_number(self, frame_number: int) -> int: ... + + @abstractmethod + def validate_chunk_number(self, chunk_number: int) -> int: ... + + @abstractmethod + def get_chunk_number(self, frame_number: int) -> int: ... + + @abstractmethod + def get_preview(self) -> DataWithMeta[BytesIO]: ... + + @abstractmethod + def get_chunk( + self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL + ) -> DataWithMeta[BytesIO]: ... + + @abstractmethod + def get_frame( + self, + frame_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> DataWithMeta[AnyFrame]: ... + + @abstractmethod + def get_frame_context_images_chunk( + self, + frame_number: int, + ) -> Optional[DataWithMeta[BytesIO]]: ... + + @abstractmethod + def iterate_frames( + self, + *, + start_frame: Optional[int] = None, + stop_frame: Optional[int] = None, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> Iterator[DataWithMeta[AnyFrame]]: ... + + def _get_abs_frame_number(self, db_data: models.Data, rel_frame_number: int) -> int: + return db_data.start_frame + rel_frame_number * db_data.get_frame_step() + + def _get_rel_frame_number(self, db_data: models.Data, abs_frame_number: int) -> int: + return (abs_frame_number - db_data.start_frame) // db_data.get_frame_step() + + +class TaskFrameProvider(IFrameProvider): + def __init__(self, db_task: models.Task) -> None: + self._db_task = db_task + + def validate_frame_number(self, frame_number: int) -> int: + if frame_number not in range(0, self._db_task.data.size): + raise ValidationError( + f"Invalid frame '{frame_number}'. " + f"The frame number should be in the [0, {self._db_task.data.size}] range" + ) + + return frame_number + + def validate_chunk_number(self, chunk_number: int) -> int: + last_chunk = math.ceil(self._db_task.data.size / self._db_task.data.chunk_size) - 1 + if not 0 <= chunk_number <= last_chunk: + raise ValidationError( + f"Invalid chunk number '{chunk_number}'. " + f"The chunk number should be in the [0, {last_chunk}] range" + ) + + return chunk_number + + def get_chunk_number(self, frame_number: int) -> int: + return int(frame_number) // self._db_task.data.chunk_size + + def get_abs_frame_number(self, rel_frame_number: int) -> int: + "Returns absolute frame number in the task (in the range [start, stop, step])" + return super()._get_abs_frame_number(self._db_task.data, rel_frame_number) + + def get_rel_frame_number(self, abs_frame_number: int) -> int: + """ + Returns relative frame number in the task (in the range [0, task_size - 1]). + This is the "normal" frame number, expected in other methods. + """ + return super()._get_rel_frame_number(self._db_task.data, abs_frame_number) + + def get_preview(self) -> DataWithMeta[BytesIO]: + return self._get_segment_frame_provider(0).get_preview() + + def get_chunk( + self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL + ) -> DataWithMeta[BytesIO]: + return_type = DataWithMeta[BytesIO] + chunk_number = self.validate_chunk_number(chunk_number) + + cache = MediaCache() + cached_chunk = cache.get_task_chunk(self._db_task, chunk_number, quality=quality) + if cached_chunk: + return return_type(cached_chunk[0], cached_chunk[1]) + + db_data = self._db_task.data + step = db_data.get_frame_step() + task_chunk_start_frame = chunk_number * db_data.chunk_size + task_chunk_stop_frame = (chunk_number + 1) * db_data.chunk_size - 1 + task_chunk_frame_set = set( + range( + db_data.start_frame + task_chunk_start_frame * step, + min(db_data.start_frame + task_chunk_stop_frame * step, db_data.stop_frame) + step, + step, + ) + ) + + matching_segments: list[models.Segment] = sorted( + [ + s + for s in self._db_task.segment_set.all() + if s.type == models.SegmentType.RANGE + if not task_chunk_frame_set.isdisjoint(s.frame_set) + ], + key=lambda s: s.start_frame, + ) + assert matching_segments + + # Don't put this into set_callback to avoid data duplication in the cache + + if len(matching_segments) == 1: + segment_frame_provider = SegmentFrameProvider(matching_segments[0]) + matching_chunk_index = segment_frame_provider.find_matching_chunk( + sorted(task_chunk_frame_set) + ) + if matching_chunk_index is not None: + # The requested frames match one of the job chunks, we can use it directly + return segment_frame_provider.get_chunk(matching_chunk_index, quality=quality) + + def _set_callback() -> DataWithMime: + # Create and return a joined / cleaned chunk + task_chunk_frames = {} + for db_segment in matching_segments: + segment_frame_provider = SegmentFrameProvider(db_segment) + segment_frame_set = db_segment.frame_set + + for task_chunk_frame_id in sorted(task_chunk_frame_set): + if ( + task_chunk_frame_id not in segment_frame_set + or task_chunk_frame_id in task_chunk_frames + ): + continue + + frame, frame_name, _ = segment_frame_provider._get_raw_frame( + self.get_rel_frame_number(task_chunk_frame_id), quality=quality + ) + task_chunk_frames[task_chunk_frame_id] = (frame, frame_name, None) + + return prepare_chunk( + task_chunk_frames.values(), + quality=quality, + db_task=self._db_task, + dump_unchanged=True, + ) + + buffer, mime_type = cache.get_or_set_task_chunk( + self._db_task, chunk_number, quality=quality, set_callback=_set_callback + ) + + return return_type(data=buffer, mime=mime_type) + + def get_frame( + self, + frame_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> DataWithMeta[AnyFrame]: + return self._get_segment_frame_provider(frame_number).get_frame( + frame_number, quality=quality, out_type=out_type + ) + + def get_frame_context_images_chunk( + self, + frame_number: int, + ) -> Optional[DataWithMeta[BytesIO]]: + return self._get_segment_frame_provider(frame_number).get_frame_context_images_chunk( + frame_number + ) + + def iterate_frames( + self, + *, + start_frame: Optional[int] = None, + stop_frame: Optional[int] = None, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> Iterator[DataWithMeta[AnyFrame]]: + frame_range = itertools.count(start_frame, self._db_task.data.get_frame_step()) + if stop_frame: + frame_range = itertools.takewhile(lambda x: x <= stop_frame, frame_range) + + db_segment = None + db_segment_frame_set = None + db_segment_frame_provider = None + for idx in frame_range: + if db_segment and idx not in db_segment_frame_set: + db_segment = None + db_segment_frame_set = None + db_segment_frame_provider = None + + if not db_segment: + db_segment = self._get_segment(idx) + db_segment_frame_set = set(db_segment.frame_set) + db_segment_frame_provider = SegmentFrameProvider(db_segment) + + yield db_segment_frame_provider.get_frame(idx, quality=quality, out_type=out_type) + + def _get_segment(self, validated_frame_number: int) -> models.Segment: + if not self._db_task.data or not self._db_task.data.size: + raise ValidationError("Task has no data") + + abs_frame_number = self.get_abs_frame_number(validated_frame_number) + + return next( + s + for s in self._db_task.segment_set.all() + if s.type == models.SegmentType.RANGE + if abs_frame_number in s.frame_set + ) + + def _get_segment_frame_provider(self, frame_number: int) -> SegmentFrameProvider: + return SegmentFrameProvider(self._get_segment(self.validate_frame_number(frame_number))) + + +class SegmentFrameProvider(IFrameProvider): + def __init__(self, db_segment: models.Segment) -> None: + super().__init__() + self._db_segment = db_segment + + db_data = db_segment.task.data + + reader_class: dict[models.DataChoice, Tuple[Type[IMediaReader], Optional[dict]]] = { + models.DataChoice.IMAGESET: (ZipReader, None), + models.DataChoice.VIDEO: ( + VideoReader, + { + "allow_threading": False + # disable threading to avoid unpredictable server + # resource consumption during reading in endpoints + # can be enabled for other clients + }, + ), + } + + self._loaders: dict[FrameQuality, _ChunkLoader] = {} + if ( + db_data.storage_method == models.StorageMethodChoice.CACHE + or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE + # TODO: separate handling, extract cache creation logic from media cache + ): + cache = MediaCache() + + self._loaders[FrameQuality.COMPRESSED] = _BufferChunkLoader( + reader_class=reader_class[db_data.compressed_chunk_type][0], + reader_params=reader_class[db_data.compressed_chunk_type][1], + get_chunk_callback=lambda chunk_idx: cache.get_or_set_segment_chunk( + db_segment, chunk_idx, quality=FrameQuality.COMPRESSED + ), + ) + + self._loaders[FrameQuality.ORIGINAL] = _BufferChunkLoader( + reader_class=reader_class[db_data.original_chunk_type][0], + reader_params=reader_class[db_data.original_chunk_type][1], + get_chunk_callback=lambda chunk_idx: cache.get_or_set_segment_chunk( + db_segment, chunk_idx, quality=FrameQuality.ORIGINAL + ), + ) + else: + self._loaders[FrameQuality.COMPRESSED] = _FileChunkLoader( + reader_class=reader_class[db_data.compressed_chunk_type][0], + reader_params=reader_class[db_data.compressed_chunk_type][1], + get_chunk_path_callback=lambda chunk_idx: db_data.get_compressed_segment_chunk_path( + chunk_idx, segment_id=db_segment.id + ), + ) + + self._loaders[FrameQuality.ORIGINAL] = _FileChunkLoader( + reader_class=reader_class[db_data.original_chunk_type][0], + reader_params=reader_class[db_data.original_chunk_type][1], + get_chunk_path_callback=lambda chunk_idx: db_data.get_original_segment_chunk_path( + chunk_idx, segment_id=db_segment.id + ), + ) + + def unload(self): + for loader in self._loaders.values(): + loader.unload() + + def __len__(self): + return self._db_segment.frame_count + + def validate_frame_number(self, frame_number: int) -> Tuple[int, int, int]: + frame_sequence = list(self._db_segment.frame_set) + abs_frame_number = self._get_abs_frame_number(self._db_segment.task.data, frame_number) + if abs_frame_number not in frame_sequence: + raise ValidationError(f"Incorrect requested frame number: {frame_number}") + + # TODO: maybe optimize search + chunk_number, frame_position = divmod( + frame_sequence.index(abs_frame_number), self._db_segment.task.data.chunk_size + ) + return frame_number, chunk_number, frame_position + + def get_chunk_number(self, frame_number: int) -> int: + return int(frame_number) // self._db_segment.task.data.chunk_size + + def find_matching_chunk(self, frames: Sequence[int]) -> Optional[int]: + return next( + ( + i + for i, chunk_frames in enumerate( + take_by( + sorted(self._db_segment.frame_set), self._db_segment.task.data.chunk_size + ) + ) + if frames == set(chunk_frames) + ), + None, + ) + + def validate_chunk_number(self, chunk_number: int) -> int: + segment_size = self._db_segment.frame_count + last_chunk = math.ceil(segment_size / self._db_segment.task.data.chunk_size) - 1 + if not 0 <= chunk_number <= last_chunk: + raise ValidationError( + f"Invalid chunk number '{chunk_number}'. " + f"The chunk number should be in the [0, {last_chunk}] range" + ) + + return chunk_number + + def get_preview(self) -> DataWithMeta[BytesIO]: + cache = MediaCache() + preview, mime = cache.get_or_set_segment_preview(self._db_segment) + return DataWithMeta[BytesIO](preview, mime=mime) + + def get_chunk( + self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL + ) -> DataWithMeta[BytesIO]: + chunk_number = self.validate_chunk_number(chunk_number) + chunk_data, mime = self._loaders[quality].read_chunk(chunk_number) + return DataWithMeta[BytesIO](chunk_data, mime=mime) + + def _get_raw_frame( + self, + frame_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + ) -> Tuple[Any, str, Type[IMediaReader]]: + _, chunk_number, frame_offset = self.validate_frame_number(frame_number) + loader = self._loaders[quality] + chunk_reader = loader.load(chunk_number) + frame, frame_name, _ = chunk_reader[frame_offset] + return frame, frame_name, loader.reader_class + + def get_frame( + self, + frame_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> DataWithMeta[AnyFrame]: + return_type = DataWithMeta[AnyFrame] - def get_preview(self, frame_number): - PREVIEW_SIZE = (256, 256) - PREVIEW_MIME = 'image/jpeg' + frame, frame_name, reader_class = self._get_raw_frame(frame_number, quality=quality) - if self._dimension == DimensionType.DIM_3D: - # TODO - preview = Image.open(os.path.join(os.path.dirname(__file__), 'assets/3d_preview.jpeg')) + frame = self._convert_frame(frame, reader_class, out_type) + if issubclass(reader_class, VideoReader): + return return_type(frame, mime=self.VIDEO_FRAME_MIME) + + return return_type(frame, mime=mimetypes.guess_type(frame_name)[0]) + + def get_frame_context_images_chunk( + self, + frame_number: int, + ) -> Optional[DataWithMeta[BytesIO]]: + self.validate_frame_number(frame_number) + + db_data = self._db_segment.task.data + + cache = MediaCache() + if db_data.storage_method == models.StorageMethodChoice.CACHE: + data, mime = cache.get_or_set_frame_context_images_chunk(db_data, frame_number) else: - preview, _ = self.get_frame(frame_number, self.Quality.COMPRESSED, self.Type.PIL) + data, mime = cache.prepare_context_images_chunk(db_data, frame_number) + + if not data.getvalue(): + return None + + return DataWithMeta[BytesIO](data, mime=mime) + + def iterate_frames( + self, + *, + start_frame: Optional[int] = None, + stop_frame: Optional[int] = None, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> Iterator[DataWithMeta[AnyFrame]]: + frame_range = itertools.count(start_frame) + if stop_frame: + frame_range = itertools.takewhile(lambda x: x <= stop_frame, frame_range) + + segment_frame_set = set(self._db_segment.frame_set) + for idx in frame_range: + if self._get_abs_frame_number(self._db_segment.task.data, idx) in segment_frame_set: + yield self.get_frame(idx, quality=quality, out_type=out_type) + + +class JobFrameProvider(SegmentFrameProvider): + def __init__(self, db_job: models.Job) -> None: + super().__init__(db_job.segment) + + def get_chunk( + self, + chunk_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + is_task_chunk: bool = False, + ) -> DataWithMeta[BytesIO]: + if not is_task_chunk: + return super().get_chunk(chunk_number, quality=quality) + + # Backward compatibility for the "number" parameter + # Reproduce the task chunks, limited by this job + return_type = DataWithMeta[BytesIO] + + task_frame_provider = TaskFrameProvider(self._db_segment.task) + segment_start_chunk = task_frame_provider.get_chunk_number(self._db_segment.start_frame) + segment_stop_chunk = task_frame_provider.get_chunk_number(self._db_segment.stop_frame) + if not segment_start_chunk <= chunk_number <= segment_stop_chunk: + raise ValidationError( + f"Invalid chunk number '{chunk_number}'. " + "The chunk number should be in the " + f"[{segment_start_chunk}, {segment_stop_chunk}] range" + ) + + cache = MediaCache() + cached_chunk = cache.get_segment_task_chunk(self._db_segment, chunk_number, quality=quality) + if cached_chunk: + return return_type(cached_chunk[0], cached_chunk[1]) + + db_data = self._db_segment.task.data + step = db_data.get_frame_step() + task_chunk_start_frame = chunk_number * db_data.chunk_size + task_chunk_stop_frame = (chunk_number + 1) * db_data.chunk_size - 1 + task_chunk_frame_set = set( + range( + db_data.start_frame + task_chunk_start_frame * step, + min(db_data.start_frame + task_chunk_stop_frame * step, db_data.stop_frame) + step, + step, + ) + ) + + # Don't put this into set_callback to avoid data duplication in the cache + matching_chunk = self.find_matching_chunk(sorted(task_chunk_frame_set)) + if matching_chunk is not None: + return self.get_chunk(matching_chunk, quality=quality) + + def _set_callback() -> DataWithMime: + # Create and return a joined / cleaned chunk + segment_chunk_frame_ids = sorted( + task_chunk_frame_set.intersection(self._db_segment.frame_set) + ) + + if self._db_segment.type == models.SegmentType.RANGE: + return cache.prepare_custom_range_segment_chunk( + db_task=self._db_segment.task, + frame_ids=segment_chunk_frame_ids, + quality=quality, + ) + elif self._db_segment.type == models.SegmentType.SPECIFIC_FRAMES: + return cache.prepare_custom_masked_range_segment_chunk( + db_task=self._db_segment.task, + frame_ids=segment_chunk_frame_ids, + chunk_number=chunk_number, + quality=quality, + insert_placeholders=True, + ) + else: + assert False - preview = ImageOps.exif_transpose(preview) - preview.thumbnail(PREVIEW_SIZE) + buffer, mime_type = cache.get_or_set_segment_task_chunk( + self._db_segment, chunk_number, quality=quality, set_callback=_set_callback + ) - output_buf = BytesIO() - preview.convert('RGB').save(output_buf, format="JPEG") + return return_type(data=buffer, mime=mime_type) - return output_buf, PREVIEW_MIME - def get_chunk(self, chunk_number, quality=Quality.ORIGINAL): - chunk_number = self._validate_chunk_number(chunk_number) - if self._db_data.storage_method == StorageMethodChoice.CACHE: - return self._loaders[quality].get_chunk_path(chunk_number, quality, self._db_data) - return self._loaders[quality].get_chunk_path(chunk_number) +@overload +def make_frame_provider(data_source: models.Job) -> JobFrameProvider: ... - def get_frame(self, frame_number, quality=Quality.ORIGINAL, - out_type=Type.BUFFER): - _, chunk_number, frame_offset = self._validate_frame_number(frame_number) - loader = self._loaders[quality] - chunk_reader = loader.load(chunk_number) - frame, frame_name, _ = chunk_reader[frame_offset] - frame = self._convert_frame(frame, loader.reader_class, out_type) - if loader.reader_class is VideoReader: - return (frame, self.VIDEO_FRAME_MIME) - return (frame, mimetypes.guess_type(frame_name)[0]) +@overload +def make_frame_provider(data_source: models.Segment) -> SegmentFrameProvider: ... + + +@overload +def make_frame_provider(data_source: models.Task) -> TaskFrameProvider: ... + - def get_frames(self, start_frame, stop_frame, quality=Quality.ORIGINAL, out_type=Type.BUFFER): - for idx in range(start_frame, stop_frame): - yield self.get_frame(idx, quality=quality, out_type=out_type) +def make_frame_provider( + data_source: Union[models.Job, models.Segment, models.Task, Any] +) -> IFrameProvider: + if isinstance(data_source, models.Task): + frame_provider = TaskFrameProvider(data_source) + elif isinstance(data_source, models.Segment): + frame_provider = SegmentFrameProvider(data_source) + elif isinstance(data_source, models.Job): + frame_provider = JobFrameProvider(data_source) + else: + raise TypeError(f"Unexpected data source type {type(data_source)}") - @property - def data_id(self): - return self._db_data.id + return frame_provider diff --git a/cvat/apps/engine/log.py b/cvat/apps/engine/log.py index 5f123d33eef8..6f1740e74fd4 100644 --- a/cvat/apps/engine/log.py +++ b/cvat/apps/engine/log.py @@ -59,24 +59,31 @@ def get_logger(logger_name, log_file): vlogger = logging.getLogger('vector') + +def get_migration_log_dir() -> str: + return settings.MIGRATIONS_LOGS_ROOT + +def get_migration_log_file_path(migration_name: str) -> str: + return osp.join(get_migration_log_dir(), f'{migration_name}.log') + @contextmanager def get_migration_logger(migration_name): - migration_log_file = '{}.log'.format(migration_name) + migration_log_file_path = get_migration_log_file_path(migration_name) stdout = sys.stdout stderr = sys.stderr + # redirect all stdout to the file - log_file_object = open(osp.join(settings.MIGRATIONS_LOGS_ROOT, migration_log_file), 'w') - sys.stdout = log_file_object - sys.stderr = log_file_object - - log = logging.getLogger(migration_name) - log.addHandler(logging.StreamHandler(stdout)) - log.addHandler(logging.StreamHandler(log_file_object)) - log.setLevel(logging.INFO) - - try: - yield log - finally: - log_file_object.close() - sys.stdout = stdout - sys.stderr = stderr + with open(migration_log_file_path, 'w') as log_file_object: + sys.stdout = log_file_object + sys.stderr = log_file_object + + log = logging.getLogger(migration_name) + log.addHandler(logging.StreamHandler(stdout)) + log.addHandler(logging.StreamHandler(log_file_object)) + log.setLevel(logging.INFO) + + try: + yield log + finally: + sys.stdout = stdout + sys.stderr = stderr diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py index 9a352c3b930c..9ddbad10e3a8 100644 --- a/cvat/apps/engine/media_extractors.py +++ b/cvat/apps/engine/media_extractors.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: MIT +from __future__ import annotations + import os import sysconfig import tempfile @@ -11,12 +13,20 @@ import io import itertools import struct -from enum import IntEnum from abc import ABC, abstractmethod -from contextlib import closing -from typing import Iterable +from bisect import bisect +from contextlib import ExitStack, closing, contextmanager +from dataclasses import dataclass +from enum import IntEnum +from typing import ( + Any, Callable, ContextManager, Generator, Iterable, Iterator, Optional, Protocol, + Sequence, Tuple, TypeVar, Union +) import av +import av.codec +import av.container +import av.video.stream import numpy as np from natsort import os_sorted from pyunpack import Archive @@ -45,6 +55,10 @@ class ORIENTATION(IntEnum): MIRROR_HORIZONTAL_90_ROTATED=7 NORMAL_270_ROTATED=8 +class FrameQuality(IntEnum): + COMPRESSED = 0 + ORIGINAL = 100 + def get_mime(name): for type_name, type_def in MEDIA_TYPES.items(): if type_def['has_mime_type'](name): @@ -78,21 +92,126 @@ def sort(images, sorting_method=SortingMethod.LEXICOGRAPHICAL, func=None): else: raise NotImplementedError() -def image_size_within_orientation(img: Image): +def image_size_within_orientation(img: Image.Image): orientation = img.getexif().get(ORIENTATION_EXIF_TAG, ORIENTATION.NORMAL_HORIZONTAL) if orientation > 4: return img.height, img.width return img.width, img.height -def has_exif_rotation(img: Image): +def has_exif_rotation(img: Image.Image): return img.getexif().get(ORIENTATION_EXIF_TAG, ORIENTATION.NORMAL_HORIZONTAL) != ORIENTATION.NORMAL_HORIZONTAL +_T = TypeVar("_T") + + +class RandomAccessIterator(Iterator[_T]): + def __init__(self, iterable: Iterable[_T]): + self.iterable: Iterable[_T] = iterable + self.iterator: Optional[Iterator[_T]] = None + self.pos: int = -1 + + def __iter__(self): + return self + + def __next__(self): + return self[self.pos + 1] + + def __getitem__(self, idx: int) -> Optional[_T]: + assert 0 <= idx + if self.iterator is None or idx <= self.pos: + self.reset() + v = None + while self.pos < idx: + # NOTE: don't keep the last item in self, it can be expensive + v = next(self.iterator) + self.pos += 1 + return v + + def reset(self): + self.close() + self.iterator = iter(self.iterable) + + def close(self): + if self.iterator is not None: + if close := getattr(self.iterator, "close", None): + close() + self.iterator = None + self.pos = -1 + + +class Sized(Protocol): + def get_size(self) -> int: ... + +_MediaT = TypeVar("_MediaT", bound=Sized) + +class CachingMediaIterator(RandomAccessIterator[_MediaT]): + @dataclass + class _CacheItem: + value: _MediaT + size: int + + def __init__( + self, + iterable: Iterable, + *, + max_cache_memory: int, + max_cache_entries: int, + object_size_callback: Optional[Callable[[_MediaT], int]] = None, + ): + super().__init__(iterable) + self.max_cache_entries = max_cache_entries + self.max_cache_memory = max_cache_memory + self._get_object_size_callback = object_size_callback + self.used_cache_memory = 0 + self._cache: dict[int, self._CacheItem] = {} + + def _get_object_size(self, obj: _MediaT) -> int: + if self._get_object_size_callback: + return self._get_object_size_callback(obj) + + return obj.get_size() + + def __getitem__(self, idx: int): + cache_item = self._cache.get(idx) + if cache_item: + return cache_item.value + + value = super().__getitem__(idx) + value_size = self._get_object_size(value) + + while ( + len(self._cache) + 1 > self.max_cache_entries or + self.used_cache_memory + value_size > self.max_cache_memory + ): + min_key = min(self._cache.keys()) + self._cache.pop(min_key) + + if self.used_cache_memory + value_size <= self.max_cache_memory: + self._cache[idx] = self._CacheItem(value, value_size) + + return value + + class IMediaReader(ABC): - def __init__(self, source_path, step, start, stop, dimension): + def __init__( + self, + source_path, + *, + start: int = 0, + stop: Optional[int] = None, + step: int = 1, + dimension: DimensionType = DimensionType.DIM_2D + ): self._source_path = source_path + self._step = step + self._start = start + "The first included index" + self._stop = stop + "The last included index" + self._dimension = dimension @abstractmethod @@ -140,30 +259,25 @@ def _get_preview(obj): def get_image_size(self, i): pass - def __len__(self): - return len(self.frame_range) - - @property - def frame_range(self): - return range(self._start, self._stop, self._step) - class ImageListReader(IMediaReader): def __init__(self, - source_path, - step=1, - start=0, - stop=None, - dimension=DimensionType.DIM_2D, - sorting_method=SortingMethod.LEXICOGRAPHICAL): + source_path, + step: int = 1, + start: int = 0, + stop: Optional[int] = None, + dimension: DimensionType = DimensionType.DIM_2D, + sorting_method: SortingMethod = SortingMethod.LEXICOGRAPHICAL, + ): if not source_path: raise Exception('No image found') if not stop: - stop = len(source_path) + stop = len(source_path) - 1 else: - stop = min(len(source_path), stop + 1) + stop = min(len(source_path) - 1, stop) + step = max(step, 1) - assert stop > start + assert stop >= start super().__init__( source_path=sort(source_path, sorting_method), @@ -176,7 +290,7 @@ def __init__(self, self._sorting_method = sorting_method def __iter__(self): - for i in range(self._start, self._stop, self._step): + for i in self.frame_range: yield (self.get_image(i), self.get_path(i), i) def __contains__(self, media_file): @@ -189,7 +303,7 @@ def filter(self, callback): source_path, step=self._step, start=self._start, - stop=self._stop - 1, + stop=self._stop, dimension=self._dimension, sorting_method=self._sorting_method ) @@ -201,7 +315,7 @@ def get_image(self, i): return self._source_path[i] def get_progress(self, pos): - return (pos - self._start + 1) / (self._stop - self._start) + return (pos + 1) / (len(self.frame_range) or 1) def get_preview(self, frame): if self._dimension == DimensionType.DIM_3D: @@ -233,6 +347,13 @@ def reconcile(self, source_files, step=1, start=0, stop=None, dimension=Dimensio def absolute_source_paths(self): return [self.get_path(idx) for idx, _ in enumerate(self._source_path)] + def __len__(self): + return len(self.frame_range) + + @property + def frame_range(self): + return range(self._start, self._stop + 1, self._step) + class DirectoryReader(ImageListReader): def __init__(self, source_path, @@ -403,57 +524,149 @@ def extract(self): if not self.extract_dir: os.remove(self._zip_source.filename) +class _AvVideoReading: + @contextmanager + def read_av_container(self, source: Union[str, io.BytesIO]) -> av.container.InputContainer: + if isinstance(source, io.BytesIO): + source.seek(0) # required for re-reading + + container = av.open(source) + try: + yield container + finally: + # fixes a memory leak in input container closing + # https://github.com/PyAV-Org/PyAV/issues/1117 + for stream in container.streams: + context = stream.codec_context + if context and context.is_open: + context.close() + + if container.open_files: + container.close() + + def decode_stream( + self, container: av.container.Container, video_stream: av.video.stream.VideoStream + ) -> Generator[av.VideoFrame, None, None]: + demux_iter = container.demux(video_stream) + try: + for packet in demux_iter: + yield from packet.decode() + finally: + # av v9.2.0 seems to have a memory corruption or a deadlock + # in exception handling for demux() in the multithreaded mode. + # Instead of breaking the iteration, we iterate over packets till the end. + # Fixed in av v12.2.0. + if av.__version__ == "9.2.0" and video_stream.thread_type == 'AUTO': + exhausted = object() + while next(demux_iter, exhausted) is not exhausted: + pass + class VideoReader(IMediaReader): - def __init__(self, source_path, step=1, start=0, stop=None, dimension=DimensionType.DIM_2D): + def __init__( + self, + source_path: Union[str, io.BytesIO], + step: int = 1, + start: int = 0, + stop: Optional[int] = None, + dimension: DimensionType = DimensionType.DIM_2D, + *, + allow_threading: bool = True, + ): super().__init__( source_path=source_path, step=step, start=start, - stop=stop + 1 if stop is not None else stop, + stop=stop, dimension=dimension, ) - def _has_frame(self, i): - if i >= self._start: - if (i - self._start) % self._step == 0: - if self._stop is None or i < self._stop: - return True + self.allow_threading = allow_threading + self._frame_count: Optional[int] = None + self._frame_size: Optional[tuple[int, int]] = None # (w, h) - return False + def iterate_frames( + self, + *, + frame_filter: Union[bool, Iterable[int]] = True, + video_stream: Optional[av.video.stream.VideoStream] = None, + ) -> Iterator[Tuple[av.VideoFrame, str, int]]: + """ + If provided, frame_filter must be an ordered sequence in the ascending order. + 'True' means using the frames configured in the reader object. + 'False' or 'None' means returning all the video frames. + """ - def __iter__(self): - with self._get_av_container() as container: - stream = container.streams.video[0] - stream.thread_type = 'AUTO' - frame_num = 0 - for packet in container.demux(stream): - for image in packet.decode(): - frame_num += 1 - if self._has_frame(frame_num - 1): - if packet.stream.metadata.get('rotate'): - pts = image.pts - image = av.VideoFrame().from_ndarray( + if frame_filter is True: + frame_filter = itertools.count(self._start, self._step) + if self._stop: + frame_filter = itertools.takewhile(lambda x: x <= self._stop, frame_filter) + elif not frame_filter: + frame_filter = itertools.count() + + frame_filter_iter = iter(frame_filter) + next_frame_filter_frame = next(frame_filter_iter, None) + if next_frame_filter_frame is None: + return + + es = ExitStack() + + needs_init = video_stream is None + if needs_init: + container = es.enter_context(self._read_av_container()) + else: + container = video_stream.container + + with es: + if needs_init: + video_stream = container.streams.video[0] + + if self.allow_threading: + video_stream.thread_type = 'AUTO' + + frame_counter = itertools.count() + with closing(self._decode_stream(container, video_stream)) as stream_decoder: + for frame, frame_number in zip(stream_decoder, frame_counter): + if frame_number == next_frame_filter_frame: + if video_stream.metadata.get('rotate'): + pts = frame.pts + frame = av.VideoFrame().from_ndarray( rotate_image( - image.to_ndarray(format='bgr24'), - 360 - int(stream.metadata.get('rotate')) + frame.to_ndarray(format='bgr24'), + 360 - int(video_stream.metadata.get('rotate')) ), format ='bgr24' ) - image.pts = pts - yield (image, self._source_path[0], image.pts) + frame.pts = pts + + if self._frame_size is None: + self._frame_size = (frame.width, frame.height) + + yield (frame, self._source_path[0], frame.pts) + + next_frame_filter_frame = next(frame_filter_iter, None) + + if next_frame_filter_frame is None: + return + + def __iter__(self) -> Iterator[Tuple[av.VideoFrame, str, int]]: + return self.iterate_frames() def get_progress(self, pos): duration = self._get_duration() return pos / duration if duration else None - def _get_av_container(self): - if isinstance(self._source_path[0], io.BytesIO): - self._source_path[0].seek(0) # required for re-reading - return av.open(self._source_path[0]) + def _read_av_container(self) -> ContextManager[av.container.InputContainer]: + return _AvVideoReading().read_av_container(self._source_path[0]) + + def _decode_stream( + self, container: av.container.Container, video_stream: av.video.stream.VideoStream + ) -> Generator[av.VideoFrame, None, None]: + return _AvVideoReading().decode_stream(container, video_stream) def _get_duration(self): - with self._get_av_container() as container: + with self._read_av_container() as container: stream = container.streams.video[0] + duration = None if stream.duration: duration = stream.duration @@ -468,122 +681,128 @@ def _get_duration(self): return duration def get_preview(self, frame): - with self._get_av_container() as container: + with self._read_av_container() as container: stream = container.streams.video[0] + tb_denominator = stream.time_base.denominator needed_time = int((frame / stream.guessed_rate) * tb_denominator) container.seek(offset=needed_time, stream=stream) - for packet in container.demux(stream): - for frame in packet.decode(): - return self._get_preview(frame.to_image() if not stream.metadata.get('rotate') \ - else av.VideoFrame().from_ndarray( - rotate_image( - frame.to_ndarray(format='bgr24'), - 360 - int(container.streams.video[0].metadata.get('rotate')) - ), - format ='bgr24' - ).to_image() - ) + + with closing(self.iterate_frames(video_stream=stream)) as frame_iter: + return self._get_preview(next(frame_iter)) def get_image_size(self, i): - image = (next(iter(self)))[0] - return image.width, image.height + if self._frame_size is not None: + return self._frame_size -class FragmentMediaReader: - def __init__(self, chunk_number, chunk_size, start, stop, step=1): - self._start = start - self._stop = stop + 1 # up to the last inclusive - self._step = step - self._chunk_number = chunk_number - self._chunk_size = chunk_size - self._start_chunk_frame_number = \ - self._start + self._chunk_number * self._chunk_size * self._step - self._end_chunk_frame_number = min(self._start_chunk_frame_number \ - + (self._chunk_size - 1) * self._step + 1, self._stop) - self._frame_range = self._get_frame_range() + with closing(iter(self)) as frame_iter: + frame = next(frame_iter)[0] + self._frame_size = (frame.width, frame.height) - @property - def frame_range(self): - return self._frame_range + return self._frame_size - def _get_frame_range(self): - frame_range = [] - for idx in range(self._start, self._stop, self._step): - if idx < self._start_chunk_frame_number: - continue - elif idx < self._end_chunk_frame_number and \ - not (idx - self._start_chunk_frame_number) % self._step: - frame_range.append(idx) - elif (idx - self._start_chunk_frame_number) % self._step: - continue - else: - break - return frame_range + def get_frame_count(self) -> int: + """ + Returns total frame count in the video + + Note that not all videos provide length / duration metainfo, so the + result may require full video decoding. + + The total count is NOT affected by the frame filtering options of the object, + i.e. start frame, end frame and frame step. + """ + # It's possible to retrieve frame count from the stream.frames, + # but the number may be incorrect. + # https://superuser.com/questions/1512575/why-total-frame-count-is-different-in-ffmpeg-than-ffprobe + if self._frame_count is not None: + return self._frame_count + + frame_count = 0 + for _ in self.iterate_frames(frame_filter=False): + frame_count += 1 + + self._frame_count = frame_count -class ImageDatasetManifestReader(FragmentMediaReader): - def __init__(self, manifest_path, **kwargs): - super().__init__(**kwargs) + return frame_count + + +class ImageReaderWithManifest: + def __init__(self, manifest_path: str): self._manifest = ImageManifestManager(manifest_path) self._manifest.init_index() - def __iter__(self): - for idx in self._frame_range: + def iterate_frames(self, frame_ids: Iterable[int]): + for idx in frame_ids: yield self._manifest[idx] -class VideoDatasetManifestReader(FragmentMediaReader): - def __init__(self, manifest_path, **kwargs): - self.source_path = kwargs.pop('source_path') - super().__init__(**kwargs) - self._manifest = VideoManifestManager(manifest_path) - self._manifest.init_index() +class VideoReaderWithManifest: + # TODO: merge this class with VideoReader + + def __init__(self, manifest_path: str, source_path: str, *, allow_threading: bool = False): + self.source_path = source_path + self.manifest = VideoManifestManager(manifest_path) + if self.manifest.exists: + self.manifest.init_index() + + self.allow_threading = allow_threading + + def _read_av_container(self) -> ContextManager[av.container.InputContainer]: + return _AvVideoReading().read_av_container(self.source_path) + + def _decode_stream( + self, container: av.container.Container, video_stream: av.video.stream.VideoStream + ) -> Generator[av.VideoFrame, None, None]: + return _AvVideoReading().decode_stream(container, video_stream) - def _get_nearest_left_key_frame(self): - if self._start_chunk_frame_number >= \ - self._manifest[len(self._manifest) - 1].get('number'): - left_border = len(self._manifest) - 1 + def _get_nearest_left_key_frame(self, frame_id: int) -> tuple[int, int]: + nearest_left_keyframe_pos = bisect( + self.manifest, frame_id, key=lambda entry: entry.get('number') + ) + if nearest_left_keyframe_pos: + frame_number = self.manifest[nearest_left_keyframe_pos - 1].get('number') + timestamp = self.manifest[nearest_left_keyframe_pos - 1].get('pts') else: - left_border = 0 - delta = len(self._manifest) - while delta: - step = delta // 2 - cur_position = left_border + step - if self._manifest[cur_position].get('number') < self._start_chunk_frame_number: - cur_position += 1 - left_border = cur_position - delta -= step + 1 - else: - delta = step - if self._manifest[cur_position].get('number') > self._start_chunk_frame_number: - left_border -= 1 - frame_number = self._manifest[left_border].get('number') - timestamp = self._manifest[left_border].get('pts') + frame_number = 0 + timestamp = 0 return frame_number, timestamp - def __iter__(self): - start_decode_frame_number, start_decode_timestamp = self._get_nearest_left_key_frame() - with closing(av.open(self.source_path, mode='r')) as container: - video_stream = next(stream for stream in container.streams if stream.type == 'video') - video_stream.thread_type = 'AUTO' + def iterate_frames(self, *, frame_filter: Iterable[int]) -> Iterable[av.VideoFrame]: + "frame_ids must be an ordered sequence in the ascending order" + + frame_filter_iter = iter(frame_filter) + next_frame_filter_frame = next(frame_filter_iter, None) + if next_frame_filter_frame is None: + return + + start_decode_frame_number, start_decode_timestamp = self._get_nearest_left_key_frame( + next_frame_filter_frame + ) + + with self._read_av_container() as container: + video_stream = container.streams.video[0] + if self.allow_threading: + video_stream.thread_type = 'AUTO' container.seek(offset=start_decode_timestamp, stream=video_stream) - frame_number = start_decode_frame_number - 1 - for packet in container.demux(video_stream): - for frame in packet.decode(): - frame_number += 1 - if frame_number in self._frame_range: + frame_counter = itertools.count(start_decode_frame_number) + with closing(self._decode_stream(container, video_stream)) as stream_decoder: + for frame, frame_number in zip(stream_decoder, frame_counter): + if frame_number == next_frame_filter_frame: if video_stream.metadata.get('rotate'): frame = av.VideoFrame().from_ndarray( rotate_image( frame.to_ndarray(format='bgr24'), - 360 - int(container.streams.video[0].metadata.get('rotate')) + 360 - int(video_stream.metadata.get('rotate')) ), format ='bgr24' ) + yield frame - elif frame_number < self._frame_range[-1]: - continue - else: + + next_frame_filter_frame = next(frame_filter_iter, None) + + if next_frame_filter_frame is None: return class IChunkWriter(ABC): @@ -648,33 +867,37 @@ class ZipChunkWriter(IChunkWriter): POINT_CLOUD_EXT = 'pcd' def _write_pcd_file(self, image: str|io.BytesIO) -> tuple[io.BytesIO, str, int, int]: - image_buf = open(image, "rb") if isinstance(image, str) else image - try: + with ExitStack() as es: + if isinstance(image, str): + image_buf = es.enter_context(open(image, "rb")) + else: + image_buf = image + properties = ValidateDimension.get_pcd_properties(image_buf) w, h = int(properties["WIDTH"]), int(properties["HEIGHT"]) image_buf.seek(0, 0) return io.BytesIO(image_buf.read()), self.POINT_CLOUD_EXT, w, h - finally: - if isinstance(image, str): - image_buf.close() - def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, str]], chunk_path: str): + def save_as_chunk(self, images: Iterator[tuple[Image.Image|io.IOBase|str, str, str]], chunk_path: str): with zipfile.ZipFile(chunk_path, 'x') as zip_chunk: for idx, (image, path, _) in enumerate(images): ext = os.path.splitext(path)[1].replace('.', '') - output = io.BytesIO() + if self._dimension == DimensionType.DIM_2D: # current version of Pillow applies exif rotation immediately when TIFF image opened # and it removes rotation tag after that # so, has_exif_rotation(image) will return False for TIFF images even if they were actually rotated # and original files will be added to the archive (without applied rotation) # that is why we need the second part of the condition - if has_exif_rotation(image) or image.format == 'TIFF': + if isinstance(image, Image.Image) and ( + has_exif_rotation(image) or image.format == 'TIFF' + ): + output = io.BytesIO() rot_image = ImageOps.exif_transpose(image) try: if image.format == 'TIFF': # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html - # use loseless lzw compression for tiff images + # use lossless lzw compression for tiff images rot_image.save(output, format='TIFF', compression='tiff_lzw') else: rot_image.save( @@ -686,16 +909,22 @@ def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, s ) finally: rot_image.close() + elif isinstance(image, io.IOBase): + output = image else: output = path else: - output, ext = self._write_pcd_file(path)[0:2] - arcname = '{:06d}.{}'.format(idx, ext) + if isinstance(image, io.BytesIO): + output, ext = self._write_pcd_file(image)[0:2] + else: + output, ext = self._write_pcd_file(path)[0:2] + arcname = '{:06d}.{}'.format(idx, ext) if isinstance(output, io.BytesIO): zip_chunk.writestr(arcname, output.getvalue()) else: zip_chunk.write(filename=output, arcname=arcname) + # return empty list because ZipChunkWriter write files as is # and does not decode it to know img size. return [] @@ -703,7 +932,7 @@ def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, s class ZipCompressedChunkWriter(ZipChunkWriter): def save_as_chunk( self, - images: Iterable[tuple[Image.Image|io.IOBase|str, str, str]], + images: Iterator[tuple[Image.Image|io.IOBase|str, str, str]], chunk_path: str, *, compress_frames: bool = True, zip_compress_level: int = 0 ): image_sizes = [] @@ -719,7 +948,11 @@ def save_as_chunk( w, h = img.size extension = self.IMAGE_EXT else: - image_buf, extension, w, h = self._write_pcd_file(path) + if isinstance(image, io.BytesIO): + image_buf, extension, w, h = self._write_pcd_file(image) + else: + image_buf, extension, w, h = self._write_pcd_file(path) + image_sizes.append((w, h)) arcname = '{:06d}.{}'.format(idx, extension) zip_chunk.writestr(arcname, image_buf.getvalue()) @@ -751,7 +984,7 @@ def __init__(self, quality=67): "preset": "ultrafast", } - def _add_video_stream(self, container, w, h, rate, options): + def _add_video_stream(self, container: av.container.OutputContainer, w, h, rate, options): # x264 requires width and height must be divisible by 2 for yuv420p if h % 2: h += 1 @@ -772,12 +1005,28 @@ def _add_video_stream(self, container, w, h, rate, options): return video_stream - def save_as_chunk(self, images, chunk_path): - if not images: + FrameDescriptor = Tuple[av.VideoFrame, Any, Any] + + def _peek_first_frame( + self, frame_iter: Iterator[FrameDescriptor] + ) -> Tuple[Optional[FrameDescriptor], Iterator[FrameDescriptor]]: + "Gets the first frame and returns the same full iterator" + + if not hasattr(frame_iter, '__next__'): + frame_iter = iter(frame_iter) + + first_frame = next(frame_iter, None) + return first_frame, itertools.chain((first_frame, ), frame_iter) + + def save_as_chunk( + self, images: Iterator[FrameDescriptor], chunk_path: str + ) -> Sequence[Tuple[int, int]]: + first_frame, images = self._peek_first_frame(images) + if not first_frame: raise Exception('no images to save') - input_w = images[0][0].width - input_h = images[0][0].height + input_w = first_frame[0].width + input_h = first_frame[0].height with av.open(chunk_path, 'w', format=self.FORMAT) as output_container: output_v_stream = self._add_video_stream( @@ -788,11 +1037,15 @@ def save_as_chunk(self, images, chunk_path): options=self._codec_opts, ) - self._encode_images(images, output_container, output_v_stream) + with closing(output_v_stream): + self._encode_images(images, output_container, output_v_stream) + return [(input_w, input_h)] @staticmethod - def _encode_images(images, container, stream): + def _encode_images( + images, container: av.container.OutputContainer, stream: av.video.stream.VideoStream + ): for frame, _, _ in images: # let libav set the correct pts and time_base frame.pts = None @@ -818,11 +1071,12 @@ def __init__(self, quality): } def save_as_chunk(self, images, chunk_path): - if not images: + first_frame, images = self._peek_first_frame(images) + if not first_frame: raise Exception('no images to save') - input_w = images[0][0].width - input_h = images[0][0].height + input_w = first_frame[0].width + input_h = first_frame[0].height downscale_factor = 1 while input_h / downscale_factor >= 1080: @@ -840,7 +1094,9 @@ def save_as_chunk(self, images, chunk_path): options=self._codec_opts, ) - self._encode_images(images, output_container, output_v_stream) + with closing(output_v_stream): + self._encode_images(images, output_container, output_v_stream) + return [(input_w, input_h)] def _is_archive(path): diff --git a/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py b/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py new file mode 100644 index 000000000000..8ef887d4c54b --- /dev/null +++ b/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py @@ -0,0 +1,118 @@ +# Generated by Django 4.2.13 on 2024-08-12 09:49 + +import os +from itertools import islice +from typing import Iterable, TypeVar + +from django.db import migrations + +from cvat.apps.engine.log import get_migration_log_dir, get_migration_logger + +T = TypeVar("T") + + +def take_by(iterable: Iterable[T], count: int) -> Iterable[T]: + """ + Returns elements from the input iterable by batches of N items. + ('abcdefg', 3) -> ['a', 'b', 'c'], ['d', 'e', 'f'], ['g'] + """ + + it = iter(iterable) + while True: + batch = list(islice(it, count)) + if len(batch) == 0: + break + + yield batch + + +def get_migration_name() -> str: + return os.path.splitext(os.path.basename(__file__))[0] + + +def get_updated_ids_filename(log_dir: str, migration_name: str) -> str: + return os.path.join(log_dir, migration_name + "-data_ids.log") + + +MIGRATION_LOG_HEADER = ( + 'The following Data ids have been switched from using "filesystem" chunk storage ' 'to "cache":' +) + + +def switch_tasks_with_static_chunks_to_dynamic_chunks(apps, schema_editor): + migration_name = get_migration_name() + migration_log_dir = get_migration_log_dir() + with get_migration_logger(migration_name) as common_logger: + Data = apps.get_model("engine", "Data") + + data_with_static_cache_query = Data.objects.filter(storage_method="file_system") + + data_with_static_cache_ids = list( + v[0] + for v in ( + data_with_static_cache_query.order_by("id") + .values_list("id") + .iterator(chunk_size=100000) + ) + ) + + data_with_static_cache_query.update(storage_method="cache") + + updated_ids_filename = get_updated_ids_filename(migration_log_dir, migration_name) + with open(updated_ids_filename, "w") as data_ids_file: + print(MIGRATION_LOG_HEADER, file=data_ids_file) + + for data_id in data_with_static_cache_ids: + print(data_id, file=data_ids_file) + + common_logger.info( + "Information about migrated tasks is available in the migration log file: " + "{}. You will need to remove data manually for these tasks.".format( + updated_ids_filename + ) + ) + + +def revert_switch_tasks_with_static_chunks_to_dynamic_chunks(apps, schema_editor): + migration_name = get_migration_name() + migration_log_dir = get_migration_log_dir() + + updated_ids_filename = get_updated_ids_filename(migration_log_dir, migration_name) + if not os.path.isfile(updated_ids_filename): + raise FileNotFoundError( + "Can't revert the migration: can't file forward migration logfile at " + f"'{updated_ids_filename}'." + ) + + with open(updated_ids_filename, "r") as data_ids_file: + header = data_ids_file.readline().strip() + if header != MIGRATION_LOG_HEADER: + raise ValueError( + "Can't revert the migration: the migration log file has unexpected header" + ) + + forward_updated_ids = tuple(map(int, data_ids_file)) + + if not forward_updated_ids: + return + + Data = apps.get_model("engine", "Data") + + for id_batch in take_by(forward_updated_ids, 1000): + Data.objects.filter(storage_method="cache", id__in=id_batch).update( + storage_method="file_system" + ) + + +class Migration(migrations.Migration): + + dependencies = [ + ("engine", "0082_alter_labeledimage_job_and_more"), + ] + + operations = [ + migrations.RunPython( + switch_tasks_with_static_chunks_to_dynamic_chunks, + reverse_code=revert_switch_tasks_with_static_chunks_to_dynamic_chunks, + ) + ] diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index eda765e6bebb..c57eb0371d5e 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -252,6 +252,13 @@ def get_data_dirname(self): def get_upload_dirname(self): return os.path.join(self.get_data_dirname(), "raw") + def get_raw_data_dirname(self) -> str: + return { + StorageChoice.LOCAL: self.get_upload_dirname(), + StorageChoice.SHARE: settings.SHARE_ROOT, + StorageChoice.CLOUD_STORAGE: self.get_upload_dirname(), + }[self.storage] + def get_compressed_cache_dirname(self): return os.path.join(self.get_data_dirname(), "compressed") @@ -259,7 +266,7 @@ def get_original_cache_dirname(self): return os.path.join(self.get_data_dirname(), "original") @staticmethod - def _get_chunk_name(chunk_number, chunk_type): + def _get_chunk_name(segment_id: int, chunk_number: int, chunk_type: DataChoice | str) -> str: if chunk_type == DataChoice.VIDEO: ext = 'mp4' elif chunk_type == DataChoice.IMAGESET: @@ -267,21 +274,21 @@ def _get_chunk_name(chunk_number, chunk_type): else: ext = 'list' - return '{}.{}'.format(chunk_number, ext) + return 'segment_{}-{}.{}'.format(segment_id, chunk_number, ext) - def _get_compressed_chunk_name(self, chunk_number): - return self._get_chunk_name(chunk_number, self.compressed_chunk_type) + def _get_compressed_chunk_name(self, segment_id: int, chunk_number: int) -> str: + return self._get_chunk_name(segment_id, chunk_number, self.compressed_chunk_type) - def _get_original_chunk_name(self, chunk_number): - return self._get_chunk_name(chunk_number, self.original_chunk_type) + def _get_original_chunk_name(self, segment_id: int, chunk_number: int) -> str: + return self._get_chunk_name(segment_id, chunk_number, self.original_chunk_type) - def get_original_chunk_path(self, chunk_number): + def get_original_segment_chunk_path(self, chunk_number: int, segment_id: int) -> str: return os.path.join(self.get_original_cache_dirname(), - self._get_original_chunk_name(chunk_number)) + self._get_original_chunk_name(segment_id, chunk_number)) - def get_compressed_chunk_path(self, chunk_number): + def get_compressed_segment_chunk_path(self, chunk_number: int, segment_id: int) -> str: return os.path.join(self.get_compressed_cache_dirname(), - self._get_compressed_chunk_name(chunk_number)) + self._get_compressed_chunk_name(segment_id, chunk_number)) def get_manifest_path(self): return os.path.join(self.get_upload_dirname(), 'manifest.jsonl') @@ -600,7 +607,7 @@ def __str__(self): class Segment(models.Model): # Common fields - task = models.ForeignKey(Task, on_delete=models.CASCADE) + task = models.ForeignKey(Task, on_delete=models.CASCADE) # TODO: add related name start_frame = models.IntegerField() stop_frame = models.IntegerField() type = models.CharField(choices=SegmentType.choices(), default=SegmentType.RANGE, max_length=32) diff --git a/cvat/apps/engine/pyproject.toml b/cvat/apps/engine/pyproject.toml new file mode 100644 index 000000000000..567b78362580 --- /dev/null +++ b/cvat/apps/engine/pyproject.toml @@ -0,0 +1,12 @@ +[tool.isort] +profile = "black" +forced_separate = ["tests"] +line_length = 100 +skip_gitignore = true # align tool behavior with Black +known_first_party = ["cvat"] + +# Can't just use a pyproject in the root dir, so duplicate +# https://github.com/psf/black/issues/2863 +[tool.black] +line-length = 100 +target-version = ['py38'] diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 9d66b1716c17..ed937a993ffe 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -594,6 +594,7 @@ class JobReadSerializer(serializers.ModelSerializer): dimension = serializers.CharField(max_length=2, source='segment.task.dimension', read_only=True) data_chunk_size = serializers.ReadOnlyField(source='segment.task.data.chunk_size') organization = serializers.ReadOnlyField(source='segment.task.organization.id', allow_null=True) + data_original_chunk_type = serializers.ReadOnlyField(source='segment.task.data.original_chunk_type') data_compressed_chunk_type = serializers.ReadOnlyField(source='segment.task.data.compressed_chunk_type') mode = serializers.ReadOnlyField(source='segment.task.mode') bug_tracker = serializers.CharField(max_length=2000, source='get_bug_tracker', @@ -607,7 +608,8 @@ class Meta: model = models.Job fields = ('url', 'id', 'task_id', 'project_id', 'assignee', 'guide_id', 'dimension', 'bug_tracker', 'status', 'stage', 'state', 'mode', 'frame_count', - 'start_frame', 'stop_frame', 'data_chunk_size', 'data_compressed_chunk_type', + 'start_frame', 'stop_frame', + 'data_chunk_size', 'data_compressed_chunk_type', 'data_original_chunk_type', 'created_date', 'updated_date', 'issues', 'labels', 'type', 'organization', 'target_storage', 'source_storage', 'assignee_updated_date') read_only_fields = fields diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index 0db84cebc32b..f24cd686a587 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -1,32 +1,37 @@ # Copyright (C) 2018-2022 Intel Corporation -# Copyright (C) 2022-2023 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT +import concurrent.futures import itertools import fnmatch import os -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union, Iterable -from rest_framework.serializers import ValidationError -import rq import re +import rq import shutil +from contextlib import closing +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Union from urllib import parse as urlparse from urllib import request as urlrequest -import django_rq -import concurrent.futures -import queue +import av +import attrs +import django_rq from django.conf import settings from django.db import transaction from django.http import HttpRequest -from datetime import datetime, timezone -from pathlib import Path +from rest_framework.serializers import ValidationError from cvat.apps.engine import models from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.media_extractors import (MEDIA_TYPES, ImageListReader, Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter, - ValidateDimension, ZipChunkWriter, ZipCompressedChunkWriter, get_mime, sort) +from cvat.apps.engine.media_extractors import ( + MEDIA_TYPES, CachingMediaIterator, IMediaReader, ImageListReader, + Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter, RandomAccessIterator, + ValidateDimension, ZipChunkWriter, ZipCompressedChunkWriter, get_mime, sort +) from cvat.apps.engine.models import RequestAction, RequestTarget from cvat.apps.engine.utils import ( av_scan_paths,get_rq_job_meta, define_dependent_job, get_rq_lock_by_user, preload_images @@ -71,6 +76,8 @@ def create( class SegmentParams(NamedTuple): start_frame: int stop_frame: int + type: models.SegmentType = models.SegmentType.RANGE + frames: Optional[Sequence[int]] = [] class SegmentsParams(NamedTuple): segments: Iterator[SegmentParams] @@ -116,7 +123,7 @@ def _copy_data_from_share_point( os.makedirs(target_dir) shutil.copyfile(source_path, target_path) -def _get_task_segment_data( +def _generate_segment_params( db_task: models.Task, *, data_size: Optional[int] = None, @@ -127,10 +134,14 @@ def _segments(): # It is assumed here that files are already saved ordered in the task # Here we just need to create segments by the job sizes start_frame = 0 - for jf in job_file_mapping: - segment_size = len(jf) + for job_files in job_file_mapping: + segment_size = len(job_files) stop_frame = start_frame + segment_size - 1 - yield SegmentParams(start_frame, stop_frame) + yield SegmentParams( + start_frame=start_frame, + stop_frame=stop_frame, + type=models.SegmentType.RANGE, + ) start_frame = stop_frame + 1 @@ -153,31 +164,39 @@ def _segments(): ) segments = ( - SegmentParams(start_frame, min(start_frame + segment_size - 1, data_size - 1)) + SegmentParams( + start_frame=start_frame, + stop_frame=min(start_frame + segment_size - 1, data_size - 1), + type=models.SegmentType.RANGE + ) for start_frame in range(0, data_size - overlap, segment_size - overlap) ) return SegmentsParams(segments, segment_size, overlap) -def _save_task_to_db(db_task: models.Task, *, job_file_mapping: Optional[JobFileMapping] = None): - job = rq.get_current_job() - job.meta['status'] = 'Task is being saved in database' - job.save_meta() +def _create_segments_and_jobs( + db_task: models.Task, + *, + job_file_mapping: Optional[JobFileMapping] = None, +): + rq_job = rq.get_current_job() + rq_job.meta['status'] = 'Task is being saved in database' + rq_job.save_meta() - segments, segment_size, overlap = _get_task_segment_data( - db_task=db_task, job_file_mapping=job_file_mapping + segments, segment_size, overlap = _generate_segment_params( + db_task=db_task, job_file_mapping=job_file_mapping, ) db_task.segment_size = segment_size db_task.overlap = overlap - for segment_idx, (start_frame, stop_frame) in enumerate(segments): - slogger.glob.info("New segment for task #{}: idx = {}, start_frame = {}, \ - stop_frame = {}".format(db_task.id, segment_idx, start_frame, stop_frame)) + for segment_idx, segment_params in enumerate(segments): + slogger.glob.info( + "New segment for task #{task_id}: idx = {segment_idx}, start_frame = {start_frame}, " + "stop_frame = {stop_frame}".format( + task_id=db_task.id, segment_idx=segment_idx, **segment_params._asdict() + )) - db_segment = models.Segment() - db_segment.task = db_task - db_segment.start_frame = start_frame - db_segment.stop_frame = stop_frame + db_segment = models.Segment(task=db_task, **segment_params._asdict()) db_segment.save() db_job = models.Job(segment=db_segment) @@ -322,48 +341,28 @@ def _validate_manifest( *, is_in_cloud: bool, db_cloud_storage: Optional[Any], - data_storage_method: str, - data_sorting_method: str, - isBackupRestore: bool, ) -> Optional[str]: - if manifests: - if len(manifests) != 1: - raise ValidationError('Only one manifest file can be attached to data') - manifest_file = manifests[0] - full_manifest_path = os.path.join(root_dir, manifests[0]) - - if is_in_cloud: - cloud_storage_instance = db_storage_to_storage_instance(db_cloud_storage) - # check that cloud storage manifest file exists and is up to date - if not os.path.exists(full_manifest_path) or \ - datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) \ - < cloud_storage_instance.get_file_last_modified(manifest_file): - cloud_storage_instance.download_file(manifest_file, full_manifest_path) - - if is_manifest(full_manifest_path): - if not ( - data_sorting_method == models.SortingMethod.PREDEFINED or - (settings.USE_CACHE and data_storage_method == models.StorageMethodChoice.CACHE) or - isBackupRestore or is_in_cloud - ): - cache_disabled_message = "" - if data_storage_method == models.StorageMethodChoice.CACHE and not settings.USE_CACHE: - cache_disabled_message = ( - "This server doesn't allow to use cache for data. " - "Please turn 'use cache' off and try to recreate the task" - ) - slogger.glob.warning(cache_disabled_message) - - raise ValidationError( - "A manifest file can only be used with the 'use cache' option " - "or when 'sorting_method' is 'predefined'" + \ - (". " + cache_disabled_message if cache_disabled_message else "") - ) - return manifest_file + if not manifests: + return None + if len(manifests) != 1: + raise ValidationError('Only one manifest file can be attached to data') + manifest_file = manifests[0] + full_manifest_path = os.path.join(root_dir, manifests[0]) + + if is_in_cloud: + cloud_storage_instance = db_storage_to_storage_instance(db_cloud_storage) + # check that cloud storage manifest file exists and is up to date + if not os.path.exists(full_manifest_path) or ( + datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) \ + < cloud_storage_instance.get_file_last_modified(manifest_file) + ): + cloud_storage_instance.download_file(manifest_file, full_manifest_path) + + if not is_manifest(full_manifest_path): raise ValidationError('Invalid manifest was uploaded') - return None + return manifest_file def _validate_scheme(url): ALLOWED_SCHEMES = ['http', 'https'] @@ -522,18 +521,18 @@ def _create_thread( slogger.glob.info("create task #{}".format(db_task.id)) - job_file_mapping = _validate_job_file_mapping(db_task, data) - - db_data = db_task.data - upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT - is_data_in_cloud = db_data.storage == models.StorageChoice.CLOUD_STORAGE - job = rq.get_current_job() def _update_status(msg: str) -> None: job.meta['status'] = msg job.save_meta() + job_file_mapping = _validate_job_file_mapping(db_task, data) + + db_data = db_task.data + upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT + is_data_in_cloud = db_data.storage == models.StorageChoice.CLOUD_STORAGE + if data['remote_files'] and not isDatasetImport: data['remote_files'] = _download_data(data['remote_files'], upload_dir) @@ -551,14 +550,17 @@ def _update_status(msg: str) -> None: else: assert False, f"Unknown file storage {db_data.storage}" + if ( + db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM and + not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE + ): + db_data.storage_method = models.StorageMethodChoice.CACHE + manifest_file = _validate_manifest( manifest_files, manifest_root, is_in_cloud=is_data_in_cloud, db_cloud_storage=db_data.cloud_storage if is_data_in_cloud else None, - data_storage_method=db_data.storage_method, - data_sorting_method=data['sorting_method'], - isBackupRestore=isBackupRestore, ) manifest = None @@ -668,14 +670,16 @@ def _update_status(msg: str) -> None: is_media_sorted = False if is_data_in_cloud: - # first we need to filter files and keep only supported ones - if any([v for k, v in media.items() if k != 'image']) and db_data.storage_method == models.StorageMethodChoice.CACHE: - # FUTURE-FIXME: This is a temporary workaround for creating tasks - # with unsupported cloud storage data (video, archive, pdf) when use_cache is enabled - db_data.storage_method = models.StorageMethodChoice.FILE_SYSTEM - _update_status("The 'use cache' option is ignored") - - if db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or not settings.USE_CACHE: + if ( + # Download remote data if local storage is requested + # TODO: maybe move into cache building to fail faster on invalid task configurations + db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or + + # Packed media must be downloaded for task creation + any(v for k, v in media.items() if k != 'image') + ): + _update_status("Downloading input media") + filtered_data = [] for files in (i for i in media.values() if i): filtered_data.extend(files) @@ -690,9 +694,11 @@ def _update_status(msg: str) -> None: step = db_data.get_frame_step() if start_frame or step != 1 or stop_frame != len(filtered_data) - 1: media_to_download = filtered_data[start_frame : stop_frame + 1: step] + _download_data_from_cloud_storage(db_data.cloud_storage, media_to_download, upload_dir) del media_to_download del filtered_data + is_data_in_cloud = False db_data.storage = models.StorageChoice.LOCAL else: @@ -757,7 +763,7 @@ def _update_status(msg: str) -> None: ) # Extract input data - extractor = None + extractor: Optional[IMediaReader] = None manifest_index = _get_manifest_frame_indexer() for media_type, media_files in media.items(): if not media_files: @@ -917,38 +923,9 @@ def _update_status(msg: str) -> None: db_data.compressed_chunk_type = models.DataChoice.VIDEO if task_mode == 'interpolation' and not data['use_zip_chunks'] else models.DataChoice.IMAGESET db_data.original_chunk_type = models.DataChoice.VIDEO if task_mode == 'interpolation' else models.DataChoice.IMAGESET - def update_progress(progress): - progress_animation = '|/-\\' - if not hasattr(update_progress, 'call_counter'): - update_progress.call_counter = 0 - - status_message = 'CVAT is preparing data chunks' - if not progress: - status_message = '{} {}'.format(status_message, progress_animation[update_progress.call_counter]) - job.meta['status'] = status_message - job.meta['task_progress'] = progress or 0. - job.save_meta() - update_progress.call_counter = (update_progress.call_counter + 1) % len(progress_animation) - - compressed_chunk_writer_class = Mpeg4CompressedChunkWriter if db_data.compressed_chunk_type == models.DataChoice.VIDEO else ZipCompressedChunkWriter - if db_data.original_chunk_type == models.DataChoice.VIDEO: - original_chunk_writer_class = Mpeg4ChunkWriter - # Let's use QP=17 (that is 67 for 0-100 range) for the original chunks, which should be visually lossless or nearly so. - # A lower value will significantly increase the chunk size with a slight increase of quality. - original_quality = 67 - else: - original_chunk_writer_class = ZipChunkWriter - original_quality = 100 - - kwargs = {} - if validate_dimension.dimension == models.DimensionType.DIM_3D: - kwargs["dimension"] = validate_dimension.dimension - compressed_chunk_writer = compressed_chunk_writer_class(db_data.image_quality, **kwargs) - original_chunk_writer = original_chunk_writer_class(original_quality, **kwargs) - # calculate chunk size if it isn't specified if db_data.chunk_size is None: - if isinstance(compressed_chunk_writer, ZipCompressedChunkWriter): + if db_data.compressed_chunk_type == models.DataChoice.IMAGESET: first_image_idx = db_data.start_frame if not is_data_in_cloud: w, h = extractor.get_image_size(first_image_idx) @@ -960,206 +937,317 @@ def update_progress(progress): else: db_data.chunk_size = 36 - video_path = "" - video_size = (0, 0) + # TODO: try to pull up + # replace manifest file (e.g was uploaded 'subdir/manifest.jsonl' or 'some_manifest.jsonl') + if (manifest_file and not os.path.exists(db_data.get_manifest_path())): + shutil.copyfile(os.path.join(manifest_root, manifest_file), + db_data.get_manifest_path()) + if manifest_root and manifest_root.startswith(db_data.get_upload_dirname()): + os.remove(os.path.join(manifest_root, manifest_file)) + manifest_file = os.path.relpath(db_data.get_manifest_path(), upload_dir) - db_images = [] + # Create task frames from the metadata collected + video_path: str = "" + video_frame_size: tuple[int, int] = (0, 0) - if settings.USE_CACHE and db_data.storage_method == models.StorageMethodChoice.CACHE: - for media_type, media_files in media.items(): - if not media_files: - continue + images: list[models.Image] = [] - # replace manifest file (e.g was uploaded 'subdir/manifest.jsonl' or 'some_manifest.jsonl') - if manifest_file and not os.path.exists(db_data.get_manifest_path()): - shutil.copyfile(os.path.join(manifest_root, manifest_file), - db_data.get_manifest_path()) - if manifest_root and manifest_root.startswith(db_data.get_upload_dirname()): - os.remove(os.path.join(manifest_root, manifest_file)) - manifest_file = os.path.relpath(db_data.get_manifest_path(), upload_dir) + for media_type, media_files in media.items(): + if not media_files: + continue - if task_mode == MEDIA_TYPES['video']['mode']: + if task_mode == MEDIA_TYPES['video']['mode']: + if manifest_file: try: - manifest_is_prepared = False - if manifest_file: - try: - manifest = VideoManifestValidator(source_path=os.path.join(upload_dir, media_files[0]), - manifest_path=db_data.get_manifest_path()) - manifest.init_index() - manifest.validate_seek_key_frames() - assert len(manifest) > 0, 'No key frames.' - - all_frames = manifest.video_length - video_size = manifest.video_resolution - manifest_is_prepared = True - except Exception as ex: - manifest.remove() - if isinstance(ex, AssertionError): - base_msg = str(ex) - else: - base_msg = 'Invalid manifest file was upload.' - slogger.glob.warning(str(ex)) - _update_status('{} Start prepare a valid manifest file.'.format(base_msg)) - - if not manifest_is_prepared: - _update_status('Start prepare a manifest file') - manifest = VideoManifestManager(db_data.get_manifest_path()) - manifest.link( - media_file=media_files[0], - upload_dir=upload_dir, - chunk_size=db_data.chunk_size - ) - manifest.create() - _update_status('A manifest had been created') + _update_status('Validating the input manifest file') - all_frames = len(manifest.reader) - video_size = manifest.reader.resolution - manifest_is_prepared = True + manifest = VideoManifestValidator( + source_path=os.path.join(upload_dir, media_files[0]), + manifest_path=db_data.get_manifest_path() + ) + manifest.init_index() + manifest.validate_seek_key_frames() + + if not len(manifest): + raise ValidationError("No key frames found in the manifest") - db_data.size = len(range(db_data.start_frame, min(data['stop_frame'] + 1 \ - if data['stop_frame'] else all_frames, all_frames), db_data.get_frame_step())) - video_path = os.path.join(upload_dir, media_files[0]) except Exception as ex: - db_data.storage_method = models.StorageMethodChoice.FILE_SYSTEM manifest.remove() - del manifest - base_msg = str(ex) if isinstance(ex, AssertionError) \ - else "Uploaded video does not support a quick way of task creating." - _update_status("{} The task will be created using the old method".format(base_msg)) - else: # images, archive, pdf - db_data.size = len(extractor) - manifest = ImageManifestManager(db_data.get_manifest_path()) - - if not manifest.exists: + manifest = None + + if isinstance(ex, (ValidationError, AssertionError)): + base_msg = f"Invalid manifest file was uploaded: {ex}" + else: + base_msg = "Failed to parse the uploaded manifest file" + slogger.glob.warning(ex, exc_info=True) + + _update_status(base_msg) + else: + manifest = None + + if not manifest: + try: + _update_status('Preparing a manifest file') + + # TODO: maybe generate manifest in a temp directory + manifest = VideoManifestManager(db_data.get_manifest_path()) manifest.link( - sources=extractor.absolute_source_paths, - meta={ k: {'related_images': related_images[k] } for k in related_images }, - data_dir=upload_dir, - DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D), + media_file=media_files[0], + upload_dir=upload_dir, + chunk_size=db_data.chunk_size, # TODO: why it's needed here? + force=True ) manifest.create() - else: - manifest.init_index() - counter = itertools.count() - for _, chunk_frames in itertools.groupby(extractor.frame_range, lambda x: next(counter) // db_data.chunk_size): - chunk_paths = [(extractor.get_path(i), i) for i in chunk_frames] - img_sizes = [] - - for chunk_path, frame_id in chunk_paths: - properties = manifest[manifest_index(frame_id)] - - # check mapping - if not chunk_path.endswith(f"{properties['name']}{properties['extension']}"): - raise Exception('Incorrect file mapping to manifest content') - - if db_task.dimension == models.DimensionType.DIM_2D and ( - properties.get('width') is not None and - properties.get('height') is not None - ): - resolution = (properties['width'], properties['height']) - elif is_data_in_cloud: - raise Exception( - "Can't find image '{}' width or height info in the manifest" - .format(f"{properties['name']}{properties['extension']}") - ) - else: - resolution = extractor.get_image_size(frame_id) - img_sizes.append(resolution) - - db_images.extend([ - models.Image(data=db_data, - path=os.path.relpath(path, upload_dir), - frame=frame, width=w, height=h) - for (path, frame), (w, h) in zip(chunk_paths, img_sizes) - ]) - if db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or not settings.USE_CACHE: - counter = itertools.count() - generator = itertools.groupby(extractor, lambda _: next(counter) // db_data.chunk_size) - generator = ((idx, list(chunk_data)) for idx, chunk_data in generator) - - def save_chunks( - executor: concurrent.futures.ThreadPoolExecutor, - chunk_idx: int, - chunk_data: Iterable[tuple[str, str, str]]) -> list[tuple[str, int, tuple[int, int]]]: - nonlocal db_data, db_task, extractor, original_chunk_writer, compressed_chunk_writer - if (db_task.dimension == models.DimensionType.DIM_2D and - isinstance(extractor, ( - MEDIA_TYPES['image']['extractor'], - MEDIA_TYPES['zip']['extractor'], - MEDIA_TYPES['pdf']['extractor'], - MEDIA_TYPES['archive']['extractor'], - ))): - chunk_data = preload_images(chunk_data) - - fs_original = executor.submit( - original_chunk_writer.save_as_chunk, - images=chunk_data, - chunk_path=db_data.get_original_chunk_path(chunk_idx) - ) - fs_compressed = executor.submit( - compressed_chunk_writer.save_as_chunk, - images=chunk_data, - chunk_path=db_data.get_compressed_chunk_path(chunk_idx), - ) - fs_original.result() - image_sizes = fs_compressed.result() - # (path, frame, size) - return list((i[0][1], i[0][2], i[1]) for i in zip(chunk_data, image_sizes)) + _update_status('A manifest has been created') - def process_results(img_meta: list[tuple[str, int, tuple[int, int]]]): - nonlocal db_images, db_data, video_path, video_size + except Exception as ex: + manifest.remove() + manifest = None - if db_task.mode == 'annotation': - db_images.extend( - models.Image( - data=db_data, - path=os.path.relpath(frame_path, upload_dir), - frame=frame_number, - width=frame_size[0], - height=frame_size[1]) - for frame_path, frame_number, frame_size in img_meta) + if isinstance(ex, AssertionError): + base_msg = f": {ex}" + else: + base_msg = "" + slogger.glob.warning(ex, exc_info=True) + + _update_status( + f"Failed to create manifest for the uploaded video{base_msg}. " + "A manifest will not be used in this task" + ) + + if manifest: + video_frame_count = manifest.video_length + video_frame_size = manifest.video_resolution else: - video_size = img_meta[0][2] - video_path = img_meta[0][0] + video_frame_count = extractor.get_frame_count() + video_frame_size = extractor.get_image_size(0) + + db_data.size = len(range( + db_data.start_frame, + min( + data['stop_frame'] + 1 if data['stop_frame'] else video_frame_count, + video_frame_count, + ), + db_data.get_frame_step() + )) + video_path = os.path.join(upload_dir, media_files[0]) + else: # images, archive, pdf + db_data.size = len(extractor) - progress = extractor.get_progress(img_meta[-1][1]) - update_progress(progress) + manifest = ImageManifestManager(db_data.get_manifest_path()) + if not manifest.exists: + manifest.link( + sources=extractor.absolute_source_paths, + meta={ + k: {'related_images': related_images[k] } + for k in related_images + }, + data_dir=upload_dir, + DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D), + ) + manifest.create() + else: + manifest.init_index() + + for frame_id in extractor.frame_range: + image_path = extractor.get_path(frame_id) + image_size = None + + if manifest: + image_info = manifest[manifest_index(frame_id)] + + # check mapping + if not image_path.endswith(f"{image_info['name']}{image_info['extension']}"): + raise ValidationError('Incorrect file mapping to manifest content') + + if db_task.dimension == models.DimensionType.DIM_2D and ( + image_info.get('width') is not None and + image_info.get('height') is not None + ): + image_size = (image_info['width'], image_info['height']) + elif is_data_in_cloud: + raise ValidationError( + "Can't find image '{}' width or height info in the manifest" + .format(f"{image_info['name']}{image_info['extension']}") + ) - futures = queue.Queue(maxsize=settings.CVAT_CONCURRENT_CHUNK_PROCESSING) - with concurrent.futures.ThreadPoolExecutor(max_workers=2*settings.CVAT_CONCURRENT_CHUNK_PROCESSING) as executor: - for chunk_idx, chunk_data in generator: - db_data.size += len(chunk_data) - if futures.full(): - process_results(futures.get().result()) - futures.put(executor.submit(save_chunks, executor, chunk_idx, chunk_data)) + if not image_size: + image_size = extractor.get_image_size(frame_id) - while not futures.empty(): - process_results(futures.get().result()) + images.append( + models.Image( + data=db_data, + path=os.path.relpath(image_path, upload_dir), + frame=frame_id, + width=image_size[0], + height=image_size[1], + ) + ) if db_task.mode == 'annotation': - models.Image.objects.bulk_create(db_images) - created_images = models.Image.objects.filter(data_id=db_data.id) + models.Image.objects.bulk_create(images) + images = models.Image.objects.filter(data_id=db_data.id) db_related_files = [ models.RelatedFile(data=image.data, primary_image=image, path=os.path.join(upload_dir, related_file_path)) - for image in created_images + for image in images for related_file_path in related_images.get(image.path, []) ] models.RelatedFile.objects.bulk_create(db_related_files) - db_images = [] else: models.Video.objects.create( data=db_data, path=os.path.relpath(video_path, upload_dir), - width=video_size[0], height=video_size[1]) + width=video_frame_size[0], height=video_frame_size[1] + ) + # validate stop_frame if db_data.stop_frame == 0: db_data.stop_frame = db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step() else: - # validate stop_frame db_data.stop_frame = min(db_data.stop_frame, \ db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step()) slogger.glob.info("Found frames {} for Data #{}".format(db_data.size, db_data.id)) - _save_task_to_db(db_task, job_file_mapping=job_file_mapping) + _create_segments_and_jobs(db_task, job_file_mapping=job_file_mapping) + + if ( + settings.MEDIA_CACHE_ALLOW_STATIC_CACHE and + db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM + ): + _create_static_chunks(db_task, media_extractor=extractor) + +def _create_static_chunks(db_task: models.Task, *, media_extractor: IMediaReader): + @attrs.define + class _ChunkProgressUpdater: + _call_counter: int = attrs.field(default=0, init=False) + _rq_job: rq.job.Job = attrs.field(factory=rq.get_current_job) + + def update_progress(self, progress: float): + progress_animation = '|/-\\' + + status_message = 'CVAT is preparing data chunks' + if not progress: + status_message = '{} {}'.format( + status_message, progress_animation[self._call_counter] + ) + + self._rq_job.meta['status'] = status_message + self._rq_job.meta['task_progress'] = progress or 0. + self._rq_job.save_meta() + + self._call_counter = (self._call_counter + 1) % len(progress_animation) + + def save_chunks( + executor: concurrent.futures.ThreadPoolExecutor, + db_segment: models.Segment, + chunk_idx: int, + chunk_frame_ids: Sequence[int] + ): + chunk_data = [media_iterator[frame_idx] for frame_idx in chunk_frame_ids] + + if ( + db_task.dimension == models.DimensionType.DIM_2D and + isinstance(media_extractor, ( + MEDIA_TYPES['image']['extractor'], + MEDIA_TYPES['zip']['extractor'], + MEDIA_TYPES['pdf']['extractor'], + MEDIA_TYPES['archive']['extractor'], + )) + ): + chunk_data = preload_images(chunk_data) + + # TODO: extract into a class + + fs_original = executor.submit( + original_chunk_writer.save_as_chunk, + images=chunk_data, + chunk_path=db_data.get_original_segment_chunk_path( + chunk_idx, segment_id=db_segment.id + ), + ) + compressed_chunk_writer.save_as_chunk( + images=chunk_data, + chunk_path=db_data.get_compressed_segment_chunk_path( + chunk_idx, segment_id=db_segment.id + ), + ) + + fs_original.result() + + db_data = db_task.data + + if db_data.compressed_chunk_type == models.DataChoice.VIDEO: + compressed_chunk_writer_class = Mpeg4CompressedChunkWriter + else: + compressed_chunk_writer_class = ZipCompressedChunkWriter + + if db_data.original_chunk_type == models.DataChoice.VIDEO: + original_chunk_writer_class = Mpeg4ChunkWriter + + # Let's use QP=17 (that is 67 for 0-100 range) for the original chunks, + # which should be visually lossless or nearly so. + # A lower value will significantly increase the chunk size with a slight increase of quality. + original_quality = 67 # TODO: fix discrepancy in values in different parts of code + else: + original_chunk_writer_class = ZipChunkWriter + original_quality = 100 + + chunk_writer_kwargs = {} + if db_task.dimension == models.DimensionType.DIM_3D: + chunk_writer_kwargs["dimension"] = db_task.dimension + compressed_chunk_writer = compressed_chunk_writer_class( + db_data.image_quality, **chunk_writer_kwargs + ) + original_chunk_writer = original_chunk_writer_class(original_quality, **chunk_writer_kwargs) + + db_segments = db_task.segment_set.all() + + if isinstance(media_extractor, MEDIA_TYPES['video']['extractor']): + def _get_frame_size(frame_tuple: Tuple[av.VideoFrame, Any, Any]) -> int: + # There is no need to be absolutely precise here, + # just need to provide the reasonable upper boundary. + # Return bytes needed for 1 frame + frame = frame_tuple[0] + return frame.width * frame.height * (frame.format.padded_bits_per_pixel // 8) + + # Currently, we only optimize video creation for sequential + # chunks with potential overlap, so parallel processing is likely to + # help only for image datasets + media_iterator = CachingMediaIterator( + media_extractor, + max_cache_memory=2 ** 30, max_cache_entries=db_task.overlap, + object_size_callback=_get_frame_size + ) + else: + media_iterator = RandomAccessIterator(media_extractor) + + with closing(media_iterator): + progress_updater = _ChunkProgressUpdater() + + # TODO: remove 2 * or the configuration option + # TODO: maybe make real multithreading support, currently the code is limited by 1 + # video segment chunk, even if more threads are available + max_concurrency = 2 * settings.CVAT_CONCURRENT_CHUNK_PROCESSING if not isinstance( + media_extractor, MEDIA_TYPES['video']['extractor'] + ) else 2 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor: + frame_step = db_data.get_frame_step() + for segment_idx, db_segment in enumerate(db_segments): + frame_counter = itertools.count() + for chunk_idx, chunk_frame_ids in ( + (chunk_idx, list(chunk_frame_ids)) + for chunk_idx, chunk_frame_ids in itertools.groupby( + ( + # Convert absolute to relative ids (extractor output positions) + # Extractor will skip frames outside requested + (abs_frame_id - db_data.start_frame) // frame_step + for abs_frame_id in db_segment.frame_set + ), + lambda _: next(frame_counter) // db_data.chunk_size + ) + ): + save_chunks(executor, db_segment, chunk_idx, chunk_frame_ids) + + progress_updater.update_progress(segment_idx / len(db_segments)) diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index ae0200b6a2aa..e7ae8ae9ba7b 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -1422,7 +1422,13 @@ def _create_task(task_data, media_data): if isinstance(media, io.BytesIO): media.seek(0) response = cls.client.post("/api/tasks/{}/data".format(tid), data=media_data) - assert response.status_code == status.HTTP_202_ACCEPTED + assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = cls.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") + response = cls.client.get("/api/tasks/{}".format(tid)) data_id = response.data["data"] cls.tasks.append({ @@ -1766,6 +1772,12 @@ def _create_task(task_data, media_data): media.seek(0) response = self.client.post("/api/tasks/{}/data".format(tid), data=media_data) assert response.status_code == status.HTTP_202_ACCEPTED + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") + response = self.client.get("/api/tasks/{}".format(tid)) data_id = response.data["data"] self.tasks.append({ @@ -2882,6 +2894,12 @@ def _create_task(task_data, media_data): media.seek(0) response = self.client.post("/api/tasks/{}/data".format(tid), data=media_data) assert response.status_code == status.HTTP_202_ACCEPTED + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") + response = self.client.get("/api/tasks/{}".format(tid)) data_id = response.data["data"] self.tasks.append({ @@ -3433,7 +3451,7 @@ def _test_api_v2_tasks_id_data_spec(self, user, spec, data, expected_compressed_type, expected_original_type, expected_image_sizes, - expected_storage_method=StorageMethodChoice.FILE_SYSTEM, + expected_storage_method=None, expected_uploaded_data_location=StorageChoice.LOCAL, dimension=DimensionType.DIM_2D, expected_task_creation_status_state='Finished', @@ -3448,6 +3466,12 @@ def _test_api_v2_tasks_id_data_spec(self, user, spec, data, if get_status_callback is None: get_status_callback = self._get_task_creation_status + if expected_storage_method is None: + if settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: + expected_storage_method = StorageMethodChoice.FILE_SYSTEM + else: + expected_storage_method = StorageMethodChoice.CACHE + # create task response = self._create_task(user, spec) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -4007,7 +4031,7 @@ def _test_api_v2_tasks_id_data_create_can_use_chunked_local_video(self, user): image_sizes = self._share_image_sizes['test_rotated_90_video.mp4'] self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET, - self.ChunkType.VIDEO, image_sizes, StorageMethodChoice.FILE_SYSTEM) + self.ChunkType.VIDEO, image_sizes, StorageMethodChoice.CACHE) def _test_api_v2_tasks_id_data_create_can_use_chunked_cached_local_video(self, user): task_spec = { @@ -4104,7 +4128,6 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_and_manifest(self, u task_data = { "image_quality": 70, - "use_cache": True } manifest_name = "images_manifest_sorted.jsonl" @@ -4115,79 +4138,34 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_and_manifest(self, u for i, fn in enumerate(images + [manifest_name]) }) - for copy_data in [True, False]: - with self.subTest(current_function_name(), copy=copy_data): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' copy={copy_data}' - task_data['copy_data'] = copy_data - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, - StorageChoice.LOCAL if copy_data else StorageChoice.SHARE) - - with self.subTest(current_function_name() + ' file order mismatch'): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' mismatching file order' - task_data_copy = task_data.copy() - task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl" - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE, - expected_task_creation_status_state='Failed', - expected_task_creation_status_reason='Incorrect file mapping to manifest content') - - for copy_data in [True, False]: - with self.subTest(current_function_name(), copy=copy_data): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' copy={copy_data}' - task_data['copy_data'] = copy_data - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, - StorageChoice.LOCAL if copy_data else StorageChoice.SHARE) - - with self.subTest(current_function_name() + ' file order mismatch'): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' mismatching file order' - task_data_copy = task_data.copy() - task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl" - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE, - expected_task_creation_status_state='Failed', - expected_task_creation_status_reason='Incorrect file mapping to manifest content') - - for copy_data in [True, False]: - with self.subTest(current_function_name(), copy=copy_data): + for use_cache in [True, False]: + task_data['use_cache'] = use_cache + + for copy_data in [True, False]: + with self.subTest(current_function_name(), copy=copy_data, use_cache=use_cache): + task_spec = task_spec_common.copy() + task_spec['name'] = task_spec['name'] + f' copy={copy_data}' + task_data_copy = task_data.copy() + task_data_copy['copy_data'] = copy_data + self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, + self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, + image_sizes, + expected_uploaded_data_location=( + StorageChoice.LOCAL if copy_data else StorageChoice.SHARE + ) + ) + + with self.subTest(current_function_name() + ' file order mismatch', use_cache=use_cache): task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' copy={copy_data}' - task_data['copy_data'] = copy_data - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, + task_spec['name'] = task_spec['name'] + f' mismatching file order' + task_data_copy = task_data.copy() + task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl" + self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, - StorageChoice.LOCAL if copy_data else StorageChoice.SHARE) - - with self.subTest(current_function_name() + ' file order mismatch'): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' mismatching file order' - task_data_copy = task_data.copy() - task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl" - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE, - expected_task_creation_status_state='Failed', - expected_task_creation_status_reason='Incorrect file mapping to manifest content') - - with self.subTest(current_function_name() + ' without use cache'): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' manifest without cache' - task_data_copy = task_data.copy() - task_data_copy['use_cache'] = False - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE, - expected_task_creation_status_state='Failed', - expected_task_creation_status_reason="A manifest file can only be used with the 'use cache' option") + image_sizes, + expected_uploaded_data_location=StorageChoice.SHARE, + expected_task_creation_status_state='Failed', + expected_task_creation_status_reason='Incorrect file mapping to manifest content') def _test_api_v2_tasks_id_data_create_can_use_server_images_with_predefined_sorting(self, user): task_spec = { @@ -4219,7 +4197,7 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_with_predefined_sort task_data = task_data_common.copy() task_data["use_cache"] = caching_enabled - if caching_enabled: + if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: storage_method = StorageMethodChoice.CACHE else: storage_method = StorageMethodChoice.FILE_SYSTEM @@ -4278,7 +4256,7 @@ def _test_api_v2_tasks_id_data_create_can_use_local_images_with_predefined_sorti sorting_method=SortingMethod.PREDEFINED) task_data_common["use_cache"] = caching_enabled - if caching_enabled: + if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: storage_method = StorageMethodChoice.CACHE else: storage_method = StorageMethodChoice.FILE_SYSTEM @@ -4339,7 +4317,7 @@ def _test_api_v2_tasks_id_data_create_can_use_server_archive_with_predefined_sor task_data = task_data_common.copy() task_data["use_cache"] = caching_enabled - if caching_enabled: + if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: storage_method = StorageMethodChoice.CACHE else: storage_method = StorageMethodChoice.FILE_SYSTEM @@ -4412,7 +4390,7 @@ def _test_api_v2_tasks_id_data_create_can_use_local_archive_with_predefined_sort sorting_method=SortingMethod.PREDEFINED) task_data["use_cache"] = caching_enabled - if caching_enabled: + if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: storage_method = StorageMethodChoice.CACHE else: storage_method = StorageMethodChoice.FILE_SYSTEM @@ -4590,7 +4568,7 @@ def _send_data_and_fail(*args, **kwargs): self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL, + image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL, send_data_callback=_send_data) with self.subTest(current_function_name() + ' mismatching file sets - extra files'): @@ -4604,7 +4582,7 @@ def _send_data_and_fail(*args, **kwargs): with self.assertRaisesMessage(Exception, "(extra)"): self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL, + image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL, send_data_callback=_send_data_and_fail) with self.subTest(current_function_name() + ' mismatching file sets - missing files'): @@ -4618,7 +4596,7 @@ def _send_data_and_fail(*args, **kwargs): with self.assertRaisesMessage(Exception, "(missing)"): self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL, + image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL, send_data_callback=_send_data_and_fail) def _test_api_v2_tasks_id_data_create_can_use_server_rar(self, user): diff --git a/cvat/apps/engine/tests/test_rest_api_3D.py b/cvat/apps/engine/tests/test_rest_api_3D.py index a67a79109f33..9f000be5d218 100644 --- a/cvat/apps/engine/tests/test_rest_api_3D.py +++ b/cvat/apps/engine/tests/test_rest_api_3D.py @@ -86,9 +86,13 @@ def _create_task(self, data, image_data): assert response.status_code == status.HTTP_201_CREATED, response.status_code tid = response.data["id"] - response = self.client.post("/api/tasks/%s/data" % tid, - data=image_data) + response = self.client.post("/api/tasks/%s/data" % tid, data=image_data) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") response = self.client.get("/api/tasks/%s" % tid) @@ -527,7 +531,7 @@ def test_api_v2_dump_and_upload_annotation(self): for user, edata in list(self.expected_dump_upload.items()): with self.subTest(format=f"{format_name}_{edata['name']}_dump"): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations url = self._generate_url_dump_tasks_annotations(task_id) file_name = osp.join(test_dir, f"{format_name}_{edata['name']}.zip") @@ -718,7 +722,7 @@ def test_api_v2_export_dataset(self): for user, edata in list(self.expected_dump_upload.items()): with self.subTest(format=f"{format_name}_{edata['name']}_export"): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations url = self._generate_url_dump_dataset(task_id) file_name = osp.join(test_dir, f"{format_name}_{edata['name']}.zip") @@ -740,6 +744,8 @@ def test_api_v2_export_dataset(self): content = io.BytesIO(b"".join(response.streaming_content)) with open(file_name, "wb") as f: f.write(content.getvalue()) - self.assertEqual(osp.exists(file_name), edata['file_exists']) - self._check_dump_content(content, task_ann_prev.data, format_name,related_files=False) + self.assertEqual(osp.exists(file_name), edata['file_exists']) + self._check_dump_content( + content, task_ann_prev.data, format_name, related_files=False + ) diff --git a/cvat/apps/engine/tests/utils.py b/cvat/apps/engine/tests/utils.py index b884b3e9b4c4..3d2a533d1e97 100644 --- a/cvat/apps/engine/tests/utils.py +++ b/cvat/apps/engine/tests/utils.py @@ -13,7 +13,7 @@ from django.core.cache import caches from django.http.response import HttpResponse from PIL import Image -from rest_framework.test import APIClient, APITestCase +from rest_framework.test import APITestCase import av import django_rq import numpy as np @@ -92,14 +92,7 @@ def clear_rq_jobs(): class ApiTestBase(APITestCase): - def _clear_rq_jobs(self): - clear_rq_jobs() - - def setUp(self): - super().setUp() - self.client = APIClient() - - def tearDown(self): + def _clear_temp_data(self): # Clear server frame/chunk cache. # The parent class clears DB changes, and it can lead to under-cleaned task data, # which can affect other tests. @@ -112,7 +105,14 @@ def tearDown(self): # Clear any remaining RQ jobs produced by the tests executed self._clear_rq_jobs() - return super().tearDown() + def _clear_rq_jobs(self): + clear_rq_jobs() + + def setUp(self): + self._clear_temp_data() + + super().setUp() + self.client = self.client_class() def generate_image_file(filename, size=(100, 100)): diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 528c8314b677..3cb7e34c5c40 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -3,6 +3,7 @@ # # SPDX-License-Identifier: MIT +from abc import ABCMeta, abstractmethod import os import os.path as osp import re @@ -12,7 +13,7 @@ from contextlib import suppress from PIL import Image from types import SimpleNamespace -from typing import Optional, Any, Dict, List, cast, Callable, Mapping, Iterable +from typing import Optional, Any, Dict, List, Union, cast, Callable, Mapping, Iterable import traceback import textwrap from collections import namedtuple @@ -58,12 +59,14 @@ from cvat.apps.events.handlers import handle_dataset_import from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer -from cvat.apps.engine.frame_provider import FrameProvider +from cvat.apps.engine.frame_provider import ( + IFrameProvider, TaskFrameProvider, JobFrameProvider, FrameQuality +) from cvat.apps.engine.filters import NonModelSimpleFilter, NonModelOrderingFilter, NonModelJsonLogicFilter from cvat.apps.engine.media_extractors import get_mime from cvat.apps.engine.permissions import AnnotationGuidePermission, get_iam_context from cvat.apps.engine.models import ( - ClientFile, Job, JobType, Label, SegmentType, Task, Project, Issue, Data, + ClientFile, Job, JobType, Label, Task, Project, Issue, Data, Comment, StorageMethodChoice, StorageChoice, CloudProviderChoice, Location, CloudStorage as CloudStorageModel, Asset, AnnotationGuide, RequestStatus, RequestAction, RequestTarget, RequestSubresource @@ -631,19 +634,17 @@ def append_backup_chunk(self, request, file_id): def preview(self, request, pk): self._object = self.get_object() # call check_object_permissions as well - first_task = self._object.tasks.select_related('data__video').order_by('-id').first() + first_task: Optional[models.Task] = self._object.tasks.order_by('-id').first() if not first_task: return HttpResponseNotFound('Project image preview not found') - data_getter = DataChunkGetter( + data_getter = _TaskDataGetter( + db_task=first_task, data_type='preview', data_quality='compressed', - data_num=first_task.data.start_frame, - task_dim=first_task.dimension ) - return data_getter(request, first_task.data.start_frame, - first_task.data.stop_frame, first_task.data) + return data_getter() @staticmethod def _get_rq_response(queue, job_id): @@ -663,80 +664,50 @@ def _get_rq_response(queue, job_id): return response -class DataChunkGetter: - def __init__(self, data_type, data_num, data_quality, task_dim): +class _DataGetter(metaclass=ABCMeta): + def __init__( + self, data_type: str, data_num: Optional[Union[str, int]], data_quality: str + ) -> None: possible_data_type_values = ('chunk', 'frame', 'preview', 'context_image') possible_quality_values = ('compressed', 'original') if not data_type or data_type not in possible_data_type_values: raise ValidationError('Data type not specified or has wrong value') elif data_type == 'chunk' or data_type == 'frame' or data_type == 'preview': - if data_num is None: + if data_num is None and data_type != 'preview': raise ValidationError('Number is not specified') elif data_quality not in possible_quality_values: raise ValidationError('Wrong quality value') self.type = data_type self.number = int(data_num) if data_num is not None else None - self.quality = FrameProvider.Quality.COMPRESSED \ - if data_quality == 'compressed' else FrameProvider.Quality.ORIGINAL - - self.dimension = task_dim - - def _check_frame_range(self, frame: int): - frame_range = range(self._start, self._stop + 1, self._db_data.get_frame_step()) - if frame not in frame_range: - raise ValidationError( - f'The frame number should be in the [{self._start}, {self._stop}] range' - ) - - def __call__(self, request, start: int, stop: int, db_data: Optional[Data]): - if not db_data: - raise NotFound(detail='Cannot find requested data') + self.quality = FrameQuality.COMPRESSED \ + if data_quality == 'compressed' else FrameQuality.ORIGINAL - self._start = start - self._stop = stop - self._db_data = db_data + @abstractmethod + def _get_frame_provider(self) -> IFrameProvider: ... - frame_provider = FrameProvider(db_data, self.dimension) + def __call__(self): + frame_provider = self._get_frame_provider() try: if self.type == 'chunk': - start_chunk = frame_provider.get_chunk_number(start) - stop_chunk = frame_provider.get_chunk_number(stop) - # pylint: disable=superfluous-parens - if not (start_chunk <= self.number <= stop_chunk): - raise ValidationError('The chunk number should be in the ' + - f'[{start_chunk}, {stop_chunk}] range') - - # TODO: av.FFmpegError processing - if settings.USE_CACHE and db_data.storage_method == StorageMethodChoice.CACHE: - buff, mime_type = frame_provider.get_chunk(self.number, self.quality) - return HttpResponse(buff.getvalue(), content_type=mime_type) - - # Follow symbol links if the chunk is a link on a real image otherwise - # mimetype detection inside sendfile will work incorrectly. - path = os.path.realpath(frame_provider.get_chunk(self.number, self.quality)) - return sendfile(request, path) + data = frame_provider.get_chunk(self.number, quality=self.quality) + return HttpResponse(data.data.getvalue(), content_type=data.mime) elif self.type == 'frame' or self.type == 'preview': - self._check_frame_range(self.number) - if self.type == 'preview': - cache = MediaCache(self.dimension) - buf, mime = cache.get_local_preview_with_mime(self.number, db_data) + data = frame_provider.get_preview() else: - buf, mime = frame_provider.get_frame(self.number, self.quality) + data = frame_provider.get_frame(self.number, quality=self.quality) - return HttpResponse(buf.getvalue(), content_type=mime) + return HttpResponse(data.data.getvalue(), content_type=data.mime) elif self.type == 'context_image': - self._check_frame_range(self.number) - - cache = MediaCache(self.dimension) - buff, mime = cache.get_frame_context_images(db_data, self.number) - if not buff: + data = frame_provider.get_frame_context_images_chunk(self.number) + if not data: return HttpResponseNotFound() - return HttpResponse(buff, content_type=mime) + + return HttpResponse(data.data, content_type=data.mime) else: return Response(data='unknown data type {}.'.format(self.type), status=status.HTTP_400_BAD_REQUEST) @@ -745,44 +716,78 @@ def __call__(self, request, start: int, stop: int, db_data: Optional[Data]): '\n'.join([str(d) for d in ex.detail]) return Response(data=msg, status=ex.status_code) +class _TaskDataGetter(_DataGetter): + def __init__( + self, + db_task: models.Task, + *, + data_type: str, + data_quality: str, + data_num: Optional[Union[str, int]] = None, + ) -> None: + super().__init__(data_type=data_type, data_num=data_num, data_quality=data_quality) + self._db_task = db_task + + def _get_frame_provider(self) -> TaskFrameProvider: + return TaskFrameProvider(self._db_task) + + +class _JobDataGetter(_DataGetter): + def __init__( + self, + db_job: models.Job, + *, + data_type: str, + data_quality: str, + data_num: Optional[Union[str, int]] = None, + data_index: Optional[Union[str, int]] = None, + ) -> None: + possible_data_type_values = ('chunk', 'frame', 'preview', 'context_image') + possible_quality_values = ('compressed', 'original') + + if not data_type or data_type not in possible_data_type_values: + raise ValidationError('Data type not specified or has wrong value') + elif data_type == 'chunk' or data_type == 'frame' or data_type == 'preview': + if data_type == 'chunk': + if data_num is None and data_index is None: + raise ValidationError('Number or Index is not specified') + if data_num is not None and data_index is not None: + raise ValidationError('Number and Index cannot be used together') + elif data_num is None and data_type != 'preview': + raise ValidationError('Number is not specified') + elif data_quality not in possible_quality_values: + raise ValidationError('Wrong quality value') + + self.type = data_type -class JobDataGetter(DataChunkGetter): - def __init__(self, job: Job, data_type, data_num, data_quality): - super().__init__(data_type, data_num, data_quality, task_dim=job.segment.task.dimension) - self.job = job + self.index = int(data_index) if data_index is not None else None + self.number = int(data_num) if data_num is not None else None - def _check_frame_range(self, frame: int): - frame_range = self.job.segment.frame_set - if frame not in frame_range: - raise ValidationError("The frame number doesn't belong to the job") + self.quality = FrameQuality.COMPRESSED \ + if data_quality == 'compressed' else FrameQuality.ORIGINAL - def __call__(self, request, start, stop, db_data): - if self.type == 'chunk' and self.job.segment.type == SegmentType.SPECIFIC_FRAMES: - frame_provider = FrameProvider(db_data, self.dimension) + self._db_job = db_job - start_chunk = frame_provider.get_chunk_number(start) - stop_chunk = frame_provider.get_chunk_number(stop) - # pylint: disable=superfluous-parens - if not (start_chunk <= self.number <= stop_chunk): - raise ValidationError('The chunk number should be in the ' + - f'[{start_chunk}, {stop_chunk}] range') + def _get_frame_provider(self) -> JobFrameProvider: + return JobFrameProvider(self._db_job) - cache = MediaCache() + def __call__(self): + if self.type == 'chunk': + # Reproduce the task chunk indexing + frame_provider = self._get_frame_provider() - if settings.USE_CACHE and db_data.storage_method == StorageMethodChoice.CACHE: - buf, mime = cache.get_selective_job_chunk_data_with_mime( - chunk_number=self.number, quality=self.quality, job=self.job + if self.index is not None: + data = frame_provider.get_chunk( + self.index, quality=self.quality, is_task_chunk=False ) else: - buf, mime = cache.prepare_selective_job_chunk( - chunk_number=self.number, quality=self.quality, db_job=self.job + data = frame_provider.get_chunk( + self.number, quality=self.quality, is_task_chunk=True ) - return HttpResponse(buf.getvalue(), content_type=mime) - + return HttpResponse(data.data.getvalue(), content_type=data.mime) else: - return super().__call__(request, start, stop, db_data) - + return super().__call__() @extend_schema(tags=['tasks']) @extend_schema_view( @@ -1306,11 +1311,10 @@ def data(self, request, pk): data_num = request.query_params.get('number', None) data_quality = request.query_params.get('quality', 'compressed') - data_getter = DataChunkGetter(data_type, data_num, data_quality, - self._object.dimension) - - return data_getter(request, self._object.data.start_frame, - self._object.data.stop_frame, self._object.data) + data_getter = _TaskDataGetter( + self._object, data_type=data_type, data_num=data_num, data_quality=data_quality + ) + return data_getter() @tus_chunk_action(detail=True, suffix_base="data") def append_data_chunk(self, request, pk, file_id): @@ -1651,15 +1655,12 @@ def preview(self, request, pk): if not self._object.data: return HttpResponseNotFound('Task image preview not found') - data_getter = DataChunkGetter( + data_getter = _TaskDataGetter( + db_task=self._object, data_type='preview', data_quality='compressed', - data_num=self._object.data.start_frame, - task_dim=self._object.dimension ) - - return data_getter(request, self._object.data.start_frame, - self._object.data.stop_frame, self._object.data) + return data_getter() @extend_schema(tags=['jobs']) @@ -2026,8 +2027,14 @@ def get_export_callback(self, save_images: bool) -> Callable: OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.STR, enum=['compressed', 'original'], description="Specifies the quality level of the requested data"), - OpenApiParameter('number', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT, - description="A unique number value identifying chunk or frame"), + OpenApiParameter('number', + location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT, + description="A unique number value identifying chunk or frame. " + "The numbers are the same as for the task. " + "Deprecated for chunks in favor of 'index'"), + OpenApiParameter('index', + location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT, + description="A unique number value identifying chunk, starts from 0 for each job"), ], responses={ '200': OpenApiResponse(OpenApiTypes.BINARY, description='Data of a specific type'), @@ -2039,12 +2046,15 @@ def data(self, request, pk): db_job = self.get_object() # call check_object_permissions as well data_type = request.query_params.get('type', None) data_num = request.query_params.get('number', None) + data_index = request.query_params.get('index', None) data_quality = request.query_params.get('quality', 'compressed') - data_getter = JobDataGetter(db_job, data_type, data_num, data_quality) - - return data_getter(request, db_job.segment.start_frame, - db_job.segment.stop_frame, db_job.segment.task.data) + data_getter = _JobDataGetter( + db_job, + data_type=data_type, data_quality=data_quality, + data_index=data_index, data_num=data_num + ) + return data_getter() @extend_schema(methods=['GET'], summary='Get metainformation for media files in a job', @@ -2137,15 +2147,12 @@ def metadata(self, request, pk): def preview(self, request, pk): self._object = self.get_object() # call check_object_permissions as well - data_getter = DataChunkGetter( + data_getter = _JobDataGetter( + db_job=self._object, data_type='preview', data_quality='compressed', - data_num=self._object.segment.start_frame, - task_dim=self._object.segment.task.dimension ) - - return data_getter(request, self._object.segment.start_frame, - self._object.segment.stop_frame, self._object.segment.task.data) + return data_getter() @extend_schema(tags=['issues']) @@ -2716,13 +2723,13 @@ def preview(self, request, pk): # The idea is try to define real manifest preview only for the storages that have related manifests # because otherwise it can lead to extra calls to a bucket, that are usually not free. if not db_storage.has_at_least_one_manifest: - result = cache.get_cloud_preview_with_mime(db_storage) + result = cache.get_cloud_preview(db_storage) if not result: return HttpResponseNotFound('Cloud storage preview not found') - return HttpResponse(result[0], result[1]) + return HttpResponse(result[0].getvalue(), result[1]) - preview, mime = cache.get_or_set_cloud_preview_with_mime(db_storage) - return HttpResponse(preview, mime) + preview, mime = cache.get_or_set_cloud_preview(db_storage) + return HttpResponse(preview.getvalue(), mime) except CloudStorageModel.DoesNotExist: message = f"Storage {pk} does not exist" slogger.glob.error(message) diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index c86b4eaa61af..e49b93e24f12 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -1,11 +1,10 @@ # Copyright (C) 2021-2022 Intel Corporation -# Copyright (C) 2023 CVAT.ai Corporation +# Copyright (C) 2023-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from collections import OrderedDict from itertools import groupby -from io import BytesIO from typing import Dict, Optional from unittest import mock, skip import json @@ -14,11 +13,11 @@ import requests from django.contrib.auth.models import Group, User from django.http import HttpResponseNotFound, HttpResponseServerError -from PIL import Image from rest_framework import status -from rest_framework.test import APIClient, APITestCase -from cvat.apps.engine.tests.utils import filter_dict, get_paginated_collection +from cvat.apps.engine.tests.utils import ( + ApiTestBase, filter_dict, ForceLogin, generate_image_file, get_paginated_collection +) LAMBDA_ROOT_PATH = '/api/lambda' LAMBDA_FUNCTIONS_PATH = f'{LAMBDA_ROOT_PATH}/functions' @@ -49,34 +48,11 @@ with open(path) as f: functions = json.load(f) - -def generate_image_file(filename, size=(100, 100)): - f = BytesIO() - image = Image.new('RGB', size=size) - image.save(f, 'jpeg') - f.name = filename - f.seek(0) - return f - - -class ForceLogin: - def __init__(self, user, client): - self.user = user - self.client = client - - def __enter__(self): - if self.user: - self.client.force_login(self.user, backend='django.contrib.auth.backends.ModelBackend') - - return self - - def __exit__(self, exception_type, exception_value, traceback): - if self.user: - self.client.logout() - -class _LambdaTestCaseBase(APITestCase): +class _LambdaTestCaseBase(ApiTestBase): def setUp(self): - self.client = APIClient(raise_request_exception=False) + super().setUp() + + self.client = self.client_class(raise_request_exception=False) http_patcher = mock.patch('cvat.apps.lambda_manager.views.LambdaGateway._http', side_effect = self._get_data_from_lambda_manager_http) self.addCleanup(http_patcher.stop) @@ -181,6 +157,11 @@ def _create_task(self, task_spec, data, *, owner=None, org_id=None): data=data, QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") response = self.client.get("/api/tasks/%s" % tid, QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 286b8b4cc985..143537985fd7 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -1,5 +1,5 @@ # Copyright (C) 2022 Intel Corporation -# Copyright (C) 2022-2023 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -32,9 +32,9 @@ from rest_framework.request import Request import cvat.apps.dataset_manager as dm -from cvat.apps.engine.frame_provider import FrameProvider +from cvat.apps.engine.frame_provider import FrameQuality, TaskFrameProvider from cvat.apps.engine.models import ( - Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget, + Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget ) from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField from cvat.apps.engine.serializers import LabeledDataSerializer @@ -489,19 +489,19 @@ def transform_attributes(input_attributes, attr_mapping, db_attributes): def _get_image(self, db_task, frame, quality): if quality is None or quality == "original": - quality = FrameProvider.Quality.ORIGINAL + quality = FrameQuality.ORIGINAL elif quality == "compressed": - quality = FrameProvider.Quality.COMPRESSED + quality = FrameQuality.COMPRESSED else: raise ValidationError( '`{}` lambda function was run '.format(self.id) + 'with wrong arguments (quality={})'.format(quality), code=status.HTTP_400_BAD_REQUEST) - frame_provider = FrameProvider(db_task.data) + frame_provider = TaskFrameProvider(db_task) image = frame_provider.get_frame(frame, quality=quality) - return base64.b64encode(image[0].getvalue()).decode('utf-8') + return base64.b64encode(image.data.getvalue()).decode('utf-8') class LambdaQueue: RESULT_TTL = timedelta(minutes=30) diff --git a/cvat/requirements/base.in b/cvat/requirements/base.in index 50723357d27a..2bc36c18d8e7 100644 --- a/cvat/requirements/base.in +++ b/cvat/requirements/base.in @@ -1,7 +1,13 @@ -r ../../utils/dataset_manifest/requirements.in attrs==21.4.0 + +# This is the last version of av that supports ffmpeg we depend on. +# Changing ffmpeg is undesirable, as there might be video decoding differences +# between versions. +# TODO: try to move to the newer version av==9.2.0 + azure-storage-blob==12.13.0 boto3==1.17.61 clickhouse-connect==0.6.8 diff --git a/cvat/schema.yml b/cvat/schema.yml index ff97755b26c1..779b08fe376f 100644 --- a/cvat/schema.yml +++ b/cvat/schema.yml @@ -2322,11 +2322,18 @@ paths: type: integer description: A unique integer value identifying this job. required: true + - in: query + name: index + schema: + type: integer + description: A unique number value identifying chunk, starts from 0 for each + job - in: query name: number schema: type: integer - description: A unique number value identifying chunk or frame + description: A unique number value identifying chunk or frame. The numbers + are the same as for the task. Deprecated for chunks in favor of 'index' - in: query name: quality schema: @@ -8074,6 +8081,10 @@ components: allOf: - $ref: '#/components/schemas/ChunkType' readOnly: true + data_original_chunk_type: + allOf: + - $ref: '#/components/schemas/ChunkType' + readOnly: true created_date: type: string format: date-time diff --git a/dev/format_python_code.sh b/dev/format_python_code.sh index 5b455a296f4d..7eff923abb8a 100755 --- a/dev/format_python_code.sh +++ b/dev/format_python_code.sh @@ -25,6 +25,9 @@ for paths in \ "cvat/apps/analytics_report" \ "cvat/apps/engine/lazy_list.py" \ "cvat/apps/engine/background.py" \ + "cvat/apps/engine/frame_provider.py" \ + "cvat/apps/engine/cache.py" \ + "cvat/apps/engine/default_settings.py" \ ; do ${BLACK} -- ${paths} ${ISORT} -- ${paths} diff --git a/docker-compose.yml b/docker-compose.yml index 051bd0bfd8cf..569e163e9fe5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,7 @@ x-backend-env: &backend-env CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk CVAT_REDIS_ONDISK_PORT: 6666 CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_ALLOW_STATIC_CACHE: '${CVAT_ALLOW_STATIC_CACHE:-no}' DJANGO_LOG_SERVER_HOST: vector DJANGO_LOG_SERVER_PORT: 80 no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} diff --git a/helm-chart/test.values.yaml b/helm-chart/test.values.yaml index 5a5fa8fe6bab..73edaa815d70 100644 --- a/helm-chart/test.values.yaml +++ b/helm-chart/test.values.yaml @@ -27,6 +27,12 @@ cvat: frontend: imagePullPolicy: Never +redis: + master: + # The "flushall" command, which we use in tests, is disabled in helm by default + # https://artifacthub.io/packages/helm/bitnami/redis#redis-master-configuration-parameters + disableCommands: [] + keydb: resources: requests: diff --git a/tests/python/rest_api/test_jobs.py b/tests/python/rest_api/test_jobs.py index 4fbea276e0a7..a6cd225a5d52 100644 --- a/tests/python/rest_api/test_jobs.py +++ b/tests/python/rest_api/test_jobs.py @@ -11,7 +11,7 @@ from copy import deepcopy from http import HTTPStatus from io import BytesIO -from itertools import product +from itertools import groupby, product from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -361,7 +361,7 @@ def _test_destroy_job_fails(self, user, job_id, *, expected_status: int, **kwarg assert response.status == expected_status return response - @pytest.mark.usefixtures("restore_cvat_data") + @pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.parametrize("job_type, allow", (("ground_truth", True), ("annotation", False))) def test_destroy_job(self, admin_user, jobs, job_type, allow): job = next(j for j in jobs if j["type"] == job_type) @@ -603,12 +603,8 @@ def test_get_gt_job_in_org_task( self._test_get_job_403(user["username"], job["id"]) -@pytest.mark.usefixtures( - # if the db is restored per test, there are conflicts with the server data cache - # if we don't clean the db, the gt jobs created will be reused, and their - # ids won't conflict - "restore_db_per_class" -) +@pytest.mark.usefixtures("restore_db_per_class") +@pytest.mark.usefixtures("restore_redis_ondisk_per_class") class TestGetGtJobData: def _delete_gt_job(self, user, gt_job_id): with make_api_client(user) as api_client: @@ -636,12 +632,11 @@ def test_can_get_gt_job_meta(self, admin_user, tasks, jobs, task_mode, request): :job_frame_count ] gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids) + request.addfinalizer(lambda: self._delete_gt_job(user, gt_job.id)) with make_api_client(user) as api_client: (gt_job_meta, _) = api_client.jobs_api.retrieve_data_meta(gt_job.id) - request.addfinalizer(lambda: self._delete_gt_job(user, gt_job.id)) - # These values are relative to the resulting task frames, unlike meta values assert 0 == gt_job.start_frame assert task_meta.size - 1 == gt_job.stop_frame @@ -691,12 +686,11 @@ def test_can_get_gt_job_meta_with_complex_frame_setup(self, admin_user, request) task_frame_ids = range(start_frame, stop_frame, frame_step) job_frame_ids = list(task_frame_ids[::3]) gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids) + request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) with make_api_client(admin_user) as api_client: (gt_job_meta, _) = api_client.jobs_api.retrieve_data_meta(gt_job.id) - request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) - # These values are relative to the resulting task frames, unlike meta values assert 0 == gt_job.start_frame assert len(task_frame_ids) - 1 == gt_job.stop_frame @@ -717,7 +711,10 @@ def test_can_get_gt_job_meta_with_complex_frame_setup(self, admin_user, request) @pytest.mark.parametrize("task_mode", ["annotation", "interpolation"]) @pytest.mark.parametrize("quality", ["compressed", "original"]) - def test_can_get_gt_job_chunk(self, admin_user, tasks, jobs, task_mode, quality, request): + @pytest.mark.parametrize("indexing", ["absolute", "relative"]) + def test_can_get_gt_job_chunk( + self, admin_user, tasks, jobs, task_mode, quality, request, indexing + ): user = admin_user job_frame_count = 4 task = next( @@ -734,41 +731,49 @@ def test_can_get_gt_job_chunk(self, admin_user, tasks, jobs, task_mode, quality, (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) frame_step = int(task_meta.frame_filter.split("=")[-1]) if task_meta.frame_filter else 1 - job_frame_ids = list(range(task_meta.start_frame, task_meta.stop_frame, frame_step))[ - :job_frame_count - ] + task_frame_ids = range(task_meta.start_frame, task_meta.stop_frame + 1, frame_step) + rng = np.random.Generator(np.random.MT19937(42)) + job_frame_ids = sorted(rng.choice(task_frame_ids, job_frame_count, replace=False).tolist()) + gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids) + request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) - with make_api_client(admin_user) as api_client: - (chunk_file, response) = api_client.jobs_api.retrieve_data( - gt_job.id, number=0, quality=quality, type="chunk" - ) - assert response.status == HTTPStatus.OK + if indexing == "absolute": + chunk_iter = groupby(task_frame_ids, key=lambda f: f // task_meta.chunk_size) + else: + chunk_iter = groupby(job_frame_ids, key=lambda f: f // task_meta.chunk_size) - request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) + for chunk_id, chunk_frames in chunk_iter: + chunk_frames = list(chunk_frames) - frame_range = range( - task_meta.start_frame, min(task_meta.stop_frame + 1, task_meta.chunk_size), frame_step - ) - included_frames = job_frame_ids + if indexing == "absolute": + kwargs = {"number": chunk_id} + else: + kwargs = {"index": chunk_id} - # The frame count is the same as in the whole range - # with placeholders in the frames outside the job. - # This is required by the UI implementation - with zipfile.ZipFile(chunk_file) as chunk: - assert set(chunk.namelist()) == set("{:06d}.jpeg".format(i) for i in frame_range) + with make_api_client(admin_user) as api_client: + (chunk_file, response) = api_client.jobs_api.retrieve_data( + gt_job.id, **kwargs, quality=quality, type="chunk" + ) + assert response.status == HTTPStatus.OK + + # The frame count is the same as in the whole range + # with placeholders in the frames outside the job. + # This is required by the UI implementation + with zipfile.ZipFile(chunk_file) as chunk: + assert set(chunk.namelist()) == set( + f"{i:06d}.jpeg" for i in range(len(chunk_frames)) + ) - for file_info in chunk.filelist: - with chunk.open(file_info) as image_file: - image = Image.open(image_file) - image_data = np.array(image) + for file_info in chunk.filelist: + with chunk.open(file_info) as image_file: + image = Image.open(image_file) - if int(os.path.splitext(file_info.filename)[0]) not in included_frames: - assert image.size == (1, 1) - assert np.all(image_data == 0), image_data - else: - assert image.size > (1, 1) - assert np.any(image_data != 0) + chunk_frame_id = int(os.path.splitext(file_info.filename)[0]) + if chunk_frames[chunk_frame_id] not in job_frame_ids: + assert image.size == (1, 1) + else: + assert image.size > (1, 1) def _create_gt_job(self, user, task_id, frames): with make_api_client(user) as api_client: @@ -813,6 +818,7 @@ def test_can_get_gt_job_frame(self, admin_user, tasks, jobs, task_mode, quality, :job_frame_count ] gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids) + request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) frame_range = range( task_meta.start_frame, min(task_meta.stop_frame + 1, task_meta.chunk_size), frame_step @@ -830,15 +836,13 @@ def test_can_get_gt_job_frame(self, admin_user, tasks, jobs, task_mode, quality, _check_status=False, ) assert response.status == HTTPStatus.BAD_REQUEST - assert b"The frame number doesn't belong to the job" in response.data + assert b"Incorrect requested frame number" in response.data (_, response) = api_client.jobs_api.retrieve_data( gt_job.id, number=included_frames[0], quality=quality, type="frame" ) assert response.status == HTTPStatus.OK - request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) - @pytest.mark.usefixtures("restore_db_per_class") class TestListJobs: diff --git a/tests/python/rest_api/test_queues.py b/tests/python/rest_api/test_queues.py index f801e661e426..4ce314b865b2 100644 --- a/tests/python/rest_api/test_queues.py +++ b/tests/python/rest_api/test_queues.py @@ -18,7 +18,7 @@ @pytest.mark.usefixtures("restore_db_per_function") -@pytest.mark.usefixtures("restore_cvat_data") +@pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.usefixtures("restore_redis_inmem_per_function") class TestRQQueueWorking: _USER_1 = "admin1" diff --git a/tests/python/rest_api/test_resource_import_export.py b/tests/python/rest_api/test_resource_import_export.py index 833661fcfab8..39f4be22a011 100644 --- a/tests/python/rest_api/test_resource_import_export.py +++ b/tests/python/rest_api/test_resource_import_export.py @@ -177,7 +177,7 @@ def test_user_cannot_export_to_cloud_storage_with_specific_location_without_acce @pytest.mark.usefixtures("restore_db_per_function") -@pytest.mark.usefixtures("restore_cvat_data") +@pytest.mark.usefixtures("restore_cvat_data_per_function") class TestImportResourceFromS3(_S3ResourceTest): @pytest.mark.usefixtures("restore_redis_inmem_per_function") @pytest.mark.parametrize("cloud_storage_id", [3]) diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index e849244361fc..eda54b8ddd0c 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -6,10 +6,14 @@ import io import itertools import json +import math import os import os.path as osp import zipfile +from abc import ABCMeta, abstractmethod +from contextlib import closing from copy import deepcopy +from enum import Enum from functools import partial from http import HTTPStatus from itertools import chain, product @@ -18,8 +22,10 @@ from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory from time import sleep, time -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Tuple, Union +import attrs +import numpy as np import pytest from cvat_sdk import Client, Config, exceptions from cvat_sdk.api_client import models @@ -30,6 +36,7 @@ from cvat_sdk.core.uploading import Uploader from deepdiff import DeepDiff from PIL import Image +from pytest_cases import fixture_ref, parametrize import shared.utils.s3 as s3 from shared.fixtures.init import docker_exec_cvat, kube_exec_cvat @@ -48,6 +55,7 @@ generate_image_files, generate_manifest, generate_video_file, + read_video_file, ) from .utils import ( @@ -903,7 +911,7 @@ def test_uses_subset_name( @pytest.mark.usefixtures("restore_db_per_function") -@pytest.mark.usefixtures("restore_cvat_data") +@pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_per_function") class TestPostTaskData: _USERNAME = "admin1" @@ -2028,6 +2036,525 @@ def test_create_task_with_cloud_storage_directories_and_default_bucket_prefix( assert task.size == expected_task_size +class _SourceDataType(str, Enum): + images = "images" + video = "video" + + +class _TaskSpec(models.ITaskWriteRequest, models.IDataRequest, metaclass=ABCMeta): + size: int + frame_step: int + source_data_type: _SourceDataType + + @abstractmethod + def read_frame(self, i: int) -> Image.Image: ... + + +@attrs.define +class _TaskSpecBase(_TaskSpec): + _params: Union[Dict, models.TaskWriteRequest] + _data_params: Union[Dict, models.DataRequest] + size: int = attrs.field(kw_only=True) + + @property + def frame_step(self) -> int: + v = getattr(self, "frame_filter", "step=1") + return int(v.split("=")[-1]) + + def __getattr__(self, k: str) -> Any: + notfound = object() + + for params in [self._params, self._data_params]: + if isinstance(params, dict): + v = params.get(k, notfound) + else: + v = getattr(params, k, notfound) + + if v is not notfound: + return v + + raise AttributeError(k) + + +@attrs.define +class _ImagesTaskSpec(_TaskSpecBase): + source_data_type: ClassVar[_SourceDataType] = _SourceDataType.images + + _get_frame: Callable[[int], bytes] = attrs.field(kw_only=True) + + def read_frame(self, i: int) -> Image.Image: + return Image.open(io.BytesIO(self._get_frame(i))) + + +@attrs.define +class _VideoTaskSpec(_TaskSpecBase): + source_data_type: ClassVar[_SourceDataType] = _SourceDataType.video + + _get_video_file: Callable[[], io.IOBase] = attrs.field(kw_only=True) + + def read_frame(self, i: int) -> Image.Image: + with closing(read_video_file(self._get_video_file())) as reader: + for _ in range(i + 1): + frame = next(reader) + + return frame + + +@pytest.mark.usefixtures("restore_db_per_class") +@pytest.mark.usefixtures("restore_redis_ondisk_per_class") +@pytest.mark.usefixtures("restore_cvat_data_per_function") +class TestTaskData: + _USERNAME = "admin1" + + def _uploaded_images_task_fxt_base( + self, + request: pytest.FixtureRequest, + *, + frame_count: int = 10, + segment_size: Optional[int] = None, + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + task_params = { + "name": request.node.name, + "labels": [{"name": "a"}], + } + if segment_size: + task_params["segment_size"] = segment_size + + image_files = generate_image_files(frame_count) + images_data = [f.getvalue() for f in image_files] + data_params = { + "image_quality": 70, + "client_files": image_files, + } + + def get_frame(i: int) -> bytes: + return images_data[i] + + task_id, _ = create_task(self._USERNAME, spec=task_params, data=data_params) + yield _ImagesTaskSpec( + models.TaskWriteRequest._from_openapi_data(**task_params), + models.DataRequest._from_openapi_data(**data_params), + get_frame=get_frame, + size=len(images_data), + ), task_id + + @pytest.fixture(scope="class") + def fxt_uploaded_images_task( + self, request: pytest.FixtureRequest + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + yield from self._uploaded_images_task_fxt_base(request=request) + + @pytest.fixture(scope="class") + def fxt_uploaded_images_task_with_segments( + self, request: pytest.FixtureRequest + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + yield from self._uploaded_images_task_fxt_base(request=request, segment_size=4) + + def _uploaded_video_task_fxt_base( + self, + request: pytest.FixtureRequest, + *, + frame_count: int = 10, + segment_size: Optional[int] = None, + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + task_params = { + "name": request.node.name, + "labels": [{"name": "a"}], + } + if segment_size: + task_params["segment_size"] = segment_size + + video_file = generate_video_file(frame_count) + video_data = video_file.getvalue() + data_params = { + "image_quality": 70, + "client_files": [video_file], + } + + def get_video_file() -> io.BytesIO: + return io.BytesIO(video_data) + + task_id, _ = create_task(self._USERNAME, spec=task_params, data=data_params) + yield _VideoTaskSpec( + models.TaskWriteRequest._from_openapi_data(**task_params), + models.DataRequest._from_openapi_data(**data_params), + get_video_file=get_video_file, + size=frame_count, + ), task_id + + @pytest.fixture(scope="class") + def fxt_uploaded_video_task( + self, + request: pytest.FixtureRequest, + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + yield from self._uploaded_video_task_fxt_base(request=request) + + @pytest.fixture(scope="class") + def fxt_uploaded_video_task_with_segments( + self, request: pytest.FixtureRequest + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + yield from self._uploaded_video_task_fxt_base(request=request, segment_size=4) + + def _compute_segment_params(self, task_spec: _TaskSpec) -> List[Tuple[int, int]]: + segment_params = [] + segment_size = getattr(task_spec, "segment_size", 0) or task_spec.size + start_frame = getattr(task_spec, "start_frame", 0) + end_frame = (getattr(task_spec, "stop_frame", None) or (task_spec.size - 1)) + 1 + overlap = min( + ( + getattr(task_spec, "overlap", None) or 0 + if task_spec.source_data_type == _SourceDataType.images + else 5 + ), + segment_size // 2, + ) + segment_start = start_frame + while segment_start < end_frame: + if start_frame < segment_start: + segment_start -= overlap * task_spec.frame_step + + segment_end = segment_start + task_spec.frame_step * segment_size + + segment_params.append((segment_start, min(segment_end, end_frame) - 1)) + segment_start = segment_end + + return segment_params + + @staticmethod + def _compare_images( + expected: Image.Image, actual: Image.Image, *, must_be_identical: bool = True + ): + expected_pixels = np.array(expected) + chunk_frame_pixels = np.array(actual) + assert expected_pixels.shape == chunk_frame_pixels.shape + + if not must_be_identical: + # video chunks can have slightly changed colors, due to codec specifics + # compressed images can also be distorted + assert np.allclose(chunk_frame_pixels, expected_pixels, atol=2) + else: + assert np.array_equal(chunk_frame_pixels, expected_pixels) + + _default_task_cases = [ + fixture_ref("fxt_uploaded_images_task"), + fixture_ref("fxt_uploaded_images_task_with_segments"), + fixture_ref("fxt_uploaded_video_task"), + fixture_ref("fxt_uploaded_video_task_with_segments"), + ] + + @parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_task_meta(self, task_spec: _TaskSpec, task_id: int): + with make_api_client(self._USERNAME) as api_client: + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) + + assert task_meta.size == task_spec.size + assert task_meta.start_frame == getattr(task_spec, "start_frame", 0) + assert task_meta.stop_frame == getattr(task_spec, "stop_frame", None) or task_spec.size + assert task_meta.frame_filter == getattr(task_spec, "frame_filter", "") + + task_frame_set = set( + range(task_meta.start_frame, task_meta.stop_frame + 1, task_spec.frame_step) + ) + assert len(task_frame_set) == task_meta.size + + if getattr(task_spec, "chunk_size", None): + assert task_meta.chunk_size == task_spec.chunk_size + + if task_spec.source_data_type == _SourceDataType.video: + assert len(task_meta.frames) == 1 + else: + assert len(task_meta.frames) == task_meta.size + + @parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_task_frames(self, task_spec: _TaskSpec, task_id: int): + with make_api_client(self._USERNAME) as api_client: + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) + + for quality, abs_frame_id in product( + ["original", "compressed"], + range(task_meta.start_frame, task_meta.stop_frame + 1, task_spec.frame_step), + ): + rel_frame_id = ( + abs_frame_id - getattr(task_spec, "start_frame", 0) // task_spec.frame_step + ) + (_, response) = api_client.tasks_api.retrieve_data( + task_id, + type="frame", + quality=quality, + number=rel_frame_id, + _parse_response=False, + ) + + if task_spec.source_data_type == _SourceDataType.video: + frame_size = (task_meta.frames[0].width, task_meta.frames[0].height) + else: + frame_size = ( + task_meta.frames[rel_frame_id].width, + task_meta.frames[rel_frame_id].height, + ) + + frame = Image.open(io.BytesIO(response.data)) + assert frame_size == frame.size + + self._compare_images( + task_spec.read_frame(abs_frame_id), + frame, + must_be_identical=( + task_spec.source_data_type == _SourceDataType.images + and quality == "original" + ), + ) + + @parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_task_chunks(self, task_spec: _TaskSpec, task_id: int): + with make_api_client(self._USERNAME) as api_client: + (task, _) = api_client.tasks_api.retrieve(task_id) + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) + + if task_spec.source_data_type == _SourceDataType.images: + assert task.data_original_chunk_type == "imageset" + assert task.data_compressed_chunk_type == "imageset" + elif task_spec.source_data_type == _SourceDataType.video: + assert task.data_original_chunk_type == "video" + + if getattr(task_spec, "use_zip_chunks", False): + assert task.data_compressed_chunk_type == "imageset" + else: + assert task.data_compressed_chunk_type == "video" + else: + assert False + + chunk_count = math.ceil(task_meta.size / task_meta.chunk_size) + for quality, chunk_id in product(["original", "compressed"], range(chunk_count)): + expected_chunk_frame_ids = range( + chunk_id * task_meta.chunk_size, + min((chunk_id + 1) * task_meta.chunk_size, task_meta.size), + ) + + (_, response) = api_client.tasks_api.retrieve_data( + task_id, type="chunk", quality=quality, number=chunk_id, _parse_response=False + ) + + chunk_file = io.BytesIO(response.data) + if zipfile.is_zipfile(chunk_file): + with zipfile.ZipFile(chunk_file, "r") as chunk_archive: + chunk_images = { + int(os.path.splitext(name)[0]): np.array( + Image.open(io.BytesIO(chunk_archive.read(name))) + ) + for name in chunk_archive.namelist() + } + chunk_images = dict(sorted(chunk_images.items(), key=lambda e: e[0])) + else: + chunk_images = dict(enumerate(read_video_file(chunk_file))) + + assert sorted(chunk_images.keys()) == list(range(len(expected_chunk_frame_ids))) + + for chunk_frame, abs_frame_id in zip(chunk_images, expected_chunk_frame_ids): + self._compare_images( + task_spec.read_frame(abs_frame_id), + chunk_images[chunk_frame], + must_be_identical=( + task_spec.source_data_type == _SourceDataType.images + and quality == "original" + ), + ) + + @parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_job_meta(self, task_spec: _TaskSpec, task_id: int): + segment_params = self._compute_segment_params(task_spec) + with make_api_client(self._USERNAME) as api_client: + jobs = sorted( + get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id), + key=lambda j: j.start_frame, + ) + assert len(jobs) == len(segment_params) + + for (segment_start, segment_end), job in zip(segment_params, jobs): + (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id) + + assert (job_meta.start_frame, job_meta.stop_frame) == (segment_start, segment_end) + assert job_meta.frame_filter == getattr(task_spec, "frame_filter", "") + + segment_size = segment_end - segment_start + 1 + assert job_meta.size == segment_size + + task_frame_set = set( + range(job_meta.start_frame, job_meta.stop_frame + 1, task_spec.frame_step) + ) + assert len(task_frame_set) == job_meta.size + + if getattr(task_spec, "chunk_size", None): + assert job_meta.chunk_size == task_spec.chunk_size + + if task_spec.source_data_type == _SourceDataType.video: + assert len(job_meta.frames) == 1 + else: + assert len(job_meta.frames) == job_meta.size + + @parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_job_frames(self, task_spec: _TaskSpec, task_id: int): + with make_api_client(self._USERNAME) as api_client: + jobs = sorted( + get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id), + key=lambda j: j.start_frame, + ) + for job in jobs: + (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id) + + for quality, (frame_pos, abs_frame_id) in product( + ["original", "compressed"], + enumerate(range(job_meta.start_frame, job_meta.stop_frame)), + ): + rel_frame_id = ( + abs_frame_id - getattr(task_spec, "start_frame", 0) // task_spec.frame_step + ) + (_, response) = api_client.jobs_api.retrieve_data( + job.id, + type="frame", + quality=quality, + number=rel_frame_id, + _parse_response=False, + ) + + if task_spec.source_data_type == _SourceDataType.video: + frame_size = (job_meta.frames[0].width, job_meta.frames[0].height) + else: + frame_size = ( + job_meta.frames[frame_pos].width, + job_meta.frames[frame_pos].height, + ) + + frame = Image.open(io.BytesIO(response.data)) + assert frame_size == frame.size + + self._compare_images( + task_spec.read_frame(abs_frame_id), + frame, + must_be_identical=( + task_spec.source_data_type == _SourceDataType.images + and quality == "original" + ), + ) + + @parametrize("task_spec, task_id", _default_task_cases) + @parametrize("indexing", ["absolute", "relative"]) + def test_can_get_job_chunks(self, task_spec: _TaskSpec, task_id: int, indexing: str): + with make_api_client(self._USERNAME) as api_client: + jobs = sorted( + get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id), + key=lambda j: j.start_frame, + ) + + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) + + for job in jobs: + (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id) + + if task_spec.source_data_type == _SourceDataType.images: + assert job.data_original_chunk_type == "imageset" + assert job.data_compressed_chunk_type == "imageset" + elif task_spec.source_data_type == _SourceDataType.video: + assert job.data_original_chunk_type == "video" + + if getattr(task_spec, "use_zip_chunks", False): + assert job.data_compressed_chunk_type == "imageset" + else: + assert job.data_compressed_chunk_type == "video" + else: + assert False + + if indexing == "absolute": + chunk_count = math.ceil(task_meta.size / job_meta.chunk_size) + + def get_task_chunk_abs_frame_ids(chunk_id: int) -> Sequence[int]: + return range( + task_meta.start_frame + + chunk_id * task_meta.chunk_size * task_spec.frame_step, + task_meta.start_frame + + min((chunk_id + 1) * task_meta.chunk_size, task_meta.size) + * task_spec.frame_step, + ) + + def get_job_frame_ids() -> Sequence[int]: + return range( + job_meta.start_frame, job_meta.stop_frame + 1, task_spec.frame_step + ) + + def get_expected_chunk_abs_frame_ids(chunk_id: int): + return sorted( + set(get_task_chunk_abs_frame_ids(chunk_id)) & set(get_job_frame_ids()) + ) + + job_chunk_ids = ( + task_chunk_id + for task_chunk_id in range(chunk_count) + if get_expected_chunk_abs_frame_ids(task_chunk_id) + ) + else: + chunk_count = math.ceil(job_meta.size / job_meta.chunk_size) + job_chunk_ids = range(chunk_count) + + def get_expected_chunk_abs_frame_ids(chunk_id: int): + return sorted( + frame + for frame in range( + job_meta.start_frame + + chunk_id * job_meta.chunk_size * task_spec.frame_step, + job_meta.start_frame + + min((chunk_id + 1) * job_meta.chunk_size, job_meta.size) + * task_spec.frame_step, + ) + if not job_meta.included_frames or frame in job_meta.included_frames + ) + + for quality, chunk_id in product(["original", "compressed"], job_chunk_ids): + expected_chunk_abs_frame_ids = get_expected_chunk_abs_frame_ids(chunk_id) + + kwargs = {} + if indexing == "absolute": + kwargs["number"] = chunk_id + elif indexing == "relative": + kwargs["index"] = chunk_id + else: + assert False + + (_, response) = api_client.jobs_api.retrieve_data( + job.id, + type="chunk", + quality=quality, + **kwargs, + _parse_response=False, + ) + + chunk_file = io.BytesIO(response.data) + if zipfile.is_zipfile(chunk_file): + with zipfile.ZipFile(chunk_file, "r") as chunk_archive: + chunk_images = { + int(os.path.splitext(name)[0]): np.array( + Image.open(io.BytesIO(chunk_archive.read(name))) + ) + for name in chunk_archive.namelist() + } + chunk_images = dict(sorted(chunk_images.items(), key=lambda e: e[0])) + else: + chunk_images = dict(enumerate(read_video_file(chunk_file))) + + assert sorted(chunk_images.keys()) == list(range(job_meta.size)) + + for chunk_frame, abs_frame_id in zip( + chunk_images, expected_chunk_abs_frame_ids + ): + self._compare_images( + task_spec.read_frame(abs_frame_id), + chunk_images[chunk_frame], + must_be_identical=( + task_spec.source_data_type == _SourceDataType.images + and quality == "original" + ), + ) + + @pytest.mark.usefixtures("restore_db_per_function") class TestPatchTaskLabel: def _get_task_labels(self, pid, user, **kwargs) -> List[models.Label]: @@ -2229,7 +2756,7 @@ def test_admin_can_add_skeleton(self, tasks, admin_user): @pytest.mark.usefixtures("restore_db_per_function") -@pytest.mark.usefixtures("restore_cvat_data") +@pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_per_function") class TestWorkWithTask: _USERNAME = "admin1" @@ -2286,7 +2813,13 @@ def _make_client(self) -> Client: return Client(BASE_URL, config=Config(status_check_period=0.01)) @pytest.fixture(autouse=True) - def setup(self, restore_db_per_function, restore_cvat_data, tmp_path: Path, admin_user: str): + def setup( + self, + restore_db_per_function, + restore_cvat_data_per_function, + tmp_path: Path, + admin_user: str, + ): self.tmp_dir = tmp_path self.client = self._make_client() diff --git a/tests/python/sdk/test_auto_annotation.py b/tests/python/sdk/test_auto_annotation.py index 142c4354c4d1..e7ac8418b69a 100644 --- a/tests/python/sdk/test_auto_annotation.py +++ b/tests/python/sdk/test_auto_annotation.py @@ -29,6 +29,7 @@ def _common_setup( tmp_path: Path, fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], + restore_redis_ondisk_per_function, ): logger = fxt_logger[0] client = fxt_login[0] diff --git a/tests/python/sdk/test_datasets.py b/tests/python/sdk/test_datasets.py index d5fbc0957eb7..542ad9a1e80c 100644 --- a/tests/python/sdk/test_datasets.py +++ b/tests/python/sdk/test_datasets.py @@ -23,6 +23,7 @@ def _common_setup( tmp_path: Path, fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], + restore_redis_ondisk_per_function, ): logger = fxt_logger[0] client = fxt_login[0] diff --git a/tests/python/sdk/test_jobs.py b/tests/python/sdk/test_jobs.py index ef46fcb8cf0e..3202e2957ff0 100644 --- a/tests/python/sdk/test_jobs.py +++ b/tests/python/sdk/test_jobs.py @@ -29,6 +29,7 @@ def setup( fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], fxt_stdout: io.StringIO, + restore_redis_ondisk_per_function, ): self.tmp_path = tmp_path logger, self.logger_stream = fxt_logger diff --git a/tests/python/sdk/test_projects.py b/tests/python/sdk/test_projects.py index 43d6257c03c6..b03df660d87a 100644 --- a/tests/python/sdk/test_projects.py +++ b/tests/python/sdk/test_projects.py @@ -32,6 +32,7 @@ def setup( fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], fxt_stdout: io.StringIO, + restore_redis_ondisk_per_function, ): self.tmp_path = tmp_path logger, self.logger_stream = fxt_logger diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 722cb37ab003..2bcbd122abff 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -36,6 +36,7 @@ def _common_setup( tmp_path: Path, fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], + restore_redis_ondisk_per_function, ): logger = fxt_logger[0] client = fxt_login[0] diff --git a/tests/python/sdk/test_tasks.py b/tests/python/sdk/test_tasks.py index 0dc5c0694e9c..54e0823d3311 100644 --- a/tests/python/sdk/test_tasks.py +++ b/tests/python/sdk/test_tasks.py @@ -33,6 +33,7 @@ def setup( fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], fxt_stdout: io.StringIO, + restore_redis_ondisk_per_function, ): self.tmp_path = tmp_path logger, self.logger_stream = fxt_logger diff --git a/tests/python/shared/assets/jobs.json b/tests/python/shared/assets/jobs.json index d4add795c783..415fb67d44cd 100644 --- a/tests/python/shared/assets/jobs.json +++ b/tests/python/shared/assets/jobs.json @@ -10,6 +10,7 @@ "created_date": "2024-07-15T15:34:53.594000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 1, "guide_id": null, @@ -51,6 +52,7 @@ "created_date": "2024-07-15T15:33:10.549000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 1, "guide_id": null, @@ -92,6 +94,7 @@ "created_date": "2024-03-21T20:50:05.838000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 3, "guide_id": null, @@ -125,6 +128,7 @@ "created_date": "2024-03-21T20:50:05.815000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 1, "guide_id": null, @@ -158,6 +162,7 @@ "created_date": "2024-03-21T20:50:05.811000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -191,6 +196,7 @@ "created_date": "2024-03-21T20:50:05.805000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -224,6 +230,7 @@ "created_date": "2023-05-26T16:11:23.946000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 3, "guide_id": null, @@ -257,6 +264,7 @@ "created_date": "2023-05-26T16:11:23.880000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 11, "guide_id": null, @@ -290,6 +298,7 @@ "created_date": "2023-03-27T19:08:07.649000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 4, "guide_id": null, @@ -331,6 +340,7 @@ "created_date": "2023-03-27T19:08:07.649000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 6, "guide_id": null, @@ -372,6 +382,7 @@ "created_date": "2023-03-10T11:57:31.614000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 2, "guide_id": null, @@ -413,6 +424,7 @@ "created_date": "2023-03-10T11:56:33.757000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 2, "guide_id": null, @@ -454,6 +466,7 @@ "created_date": "2023-03-01T15:36:26.668000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 2, "guide_id": null, @@ -495,6 +508,7 @@ "created_date": "2023-02-10T14:05:25.947000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -528,6 +542,7 @@ "created_date": "2022-12-01T12:53:10.425000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "video", "dimension": "2d", "frame_count": 25, "guide_id": null, @@ -569,6 +584,7 @@ "created_date": "2022-09-22T14:22:25.820000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 8, "guide_id": null, @@ -610,6 +626,7 @@ "created_date": "2022-06-08T08:33:06.505000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -649,6 +666,7 @@ "created_date": "2022-03-05T10:32:19.149000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 11, "guide_id": null, @@ -690,6 +708,7 @@ "created_date": "2022-03-05T09:33:10.420000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -723,6 +742,7 @@ "created_date": "2022-03-05T09:33:10.420000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -756,6 +776,7 @@ "created_date": "2022-03-05T09:33:10.420000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -795,6 +816,7 @@ "created_date": "2022-03-05T09:33:10.420000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -834,6 +856,7 @@ "created_date": "2022-03-05T08:30:48.612000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 14, "guide_id": null, @@ -867,6 +890,7 @@ "created_date": "2022-02-21T10:31:52.429000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 11, "guide_id": null, @@ -900,6 +924,7 @@ "created_date": "2022-02-16T06:26:54.631000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "3d", "frame_count": 1, "guide_id": null, @@ -939,6 +964,7 @@ "created_date": "2022-02-16T06:25:48.168000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "video", "dimension": "2d", "frame_count": 25, "guide_id": null, @@ -978,6 +1004,7 @@ "created_date": "2021-12-14T18:50:29.458000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 23, "guide_id": null, diff --git a/tests/python/shared/fixtures/init.py b/tests/python/shared/fixtures/init.py index 8e9d334f7a47..4a17454617d0 100644 --- a/tests/python/shared/fixtures/init.py +++ b/tests/python/shared/fixtures/init.py @@ -96,12 +96,20 @@ def pytest_addoption(parser): def _run(command, capture_output=True): _command = command.split() if isinstance(command, str) else command try: + logger.debug(f"Executing a command: {_command}") + stdout, stderr = "", "" if capture_output: proc = run(_command, check=True, stdout=PIPE, stderr=PIPE) # nosec stdout, stderr = proc.stdout.decode(), proc.stderr.decode() else: proc = run(_command) # nosec + + if stdout: + logger.debug(f"Output (stdout): {stdout}") + if stderr: + logger.debug(f"Output (stderr): {stderr}") + return stdout, stderr except CalledProcessError as exc: message = f"Command failed: {' '.join(map(shlex.quote, _command))}." @@ -232,20 +240,20 @@ def kube_restore_clickhouse_db(): def docker_restore_redis_inmem(): - docker_exec_redis_inmem(["redis-cli", "flushall"]) + docker_exec_redis_inmem(["redis-cli", "-e", "flushall"]) def kube_restore_redis_inmem(): - kube_exec_redis_inmem(["redis-cli", "flushall"]) + kube_exec_redis_inmem(["sh", "-c", 'redis-cli -e -a "${REDIS_PASSWORD}" flushall']) def docker_restore_redis_ondisk(): - docker_exec_redis_ondisk(["redis-cli", "-p", "6666", "flushall"]) + docker_exec_redis_ondisk(["redis-cli", "-e", "-p", "6666", "flushall"]) def kube_restore_redis_ondisk(): kube_exec_redis_ondisk( - ["redis-cli", "-p", "6666", "-a", "${CVAT_REDIS_ONDISK_PASSWORD}", "flushall"] + ["sh", "-c", 'redis-cli -e -p 6666 -a "${CVAT_REDIS_ONDISK_PASSWORD}" flushall'] ) @@ -551,7 +559,7 @@ def restore_db_per_class(request): @pytest.fixture(scope="function") -def restore_cvat_data(request): +def restore_cvat_data_per_function(request): platform = request.config.getoption("--platform") if platform == "local": docker_restore_data_volumes() @@ -592,6 +600,15 @@ def restore_redis_inmem_per_function(request): kube_restore_redis_inmem() +@pytest.fixture(scope="class") +def restore_redis_inmem_per_class(request): + platform = request.config.getoption("--platform") + if platform == "local": + docker_restore_redis_inmem() + else: + kube_restore_redis_inmem() + + @pytest.fixture(scope="function") def restore_redis_ondisk_per_function(request): platform = request.config.getoption("--platform") @@ -599,3 +616,12 @@ def restore_redis_ondisk_per_function(request): docker_restore_redis_ondisk() else: kube_restore_redis_ondisk() + + +@pytest.fixture(scope="class") +def restore_redis_ondisk_per_class(request): + platform = request.config.getoption("--platform") + if platform == "local": + docker_restore_redis_ondisk() + else: + kube_restore_redis_ondisk() diff --git a/tests/python/shared/utils/helpers.py b/tests/python/shared/utils/helpers.py index f336cb3f9111..ac5948182d78 100644 --- a/tests/python/shared/utils/helpers.py +++ b/tests/python/shared/utils/helpers.py @@ -1,10 +1,11 @@ -# Copyright (C) 2022 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT import subprocess +from contextlib import closing from io import BytesIO -from typing import List, Optional +from typing import Generator, List, Optional import av import av.video.reformatter @@ -13,7 +14,7 @@ from shared.fixtures.init import get_server_image_tag -def generate_image_file(filename="image.png", size=(50, 50), color=(0, 0, 0)): +def generate_image_file(filename="image.png", size=(100, 50), color=(0, 0, 0)): f = BytesIO() f.name = filename image = Image.new("RGB", size=size, color=color) @@ -40,7 +41,7 @@ def generate_image_files( return images -def generate_video_file(num_frames: int, size=(50, 50)) -> BytesIO: +def generate_video_file(num_frames: int, size=(100, 50)) -> BytesIO: f = BytesIO() f.name = "video.avi" @@ -60,6 +61,19 @@ def generate_video_file(num_frames: int, size=(50, 50)) -> BytesIO: return f +def read_video_file(file: BytesIO) -> Generator[Image.Image, None, None]: + file.seek(0) + + with av.open(file) as container: + video_stream = container.streams.video[0] + + with closing(video_stream.codec_context): # pyav has a memory leak in stream.close() + with closing(container.demux(video_stream)) as demux_iter: + for packet in demux_iter: + for frame in packet.decode(): + yield frame.to_image() + + def generate_manifest(path: str) -> None: command = [ "docker",