diff --git a/Makefile b/Makefile index 3dcdf0806f..460f01eb35 100644 --- a/Makefile +++ b/Makefile @@ -26,4 +26,4 @@ prod-scale: docker-compose -f ./docker/docker-compose.yml up --build -V --scale immich-server=3 --scale immich-microservices=3 --remove-orphans api: - cd ./server && npm run api:generate \ No newline at end of file + cd ./server && npm run api:generate diff --git a/server/apps/immich/src/api-v1/album/album-repository.ts b/server/apps/immich/src/api-v1/album/album-repository.ts index b19729c35c..569b7b4295 100644 --- a/server/apps/immich/src/api-v1/album/album-repository.ts +++ b/server/apps/immich/src/api-v1/album/album-repository.ts @@ -25,6 +25,7 @@ export interface IAlbumRepository { updateAlbum(album: AlbumEntity, updateAlbumDto: UpdateAlbumDto): Promise; getListByAssetId(userId: string, assetId: string): Promise; getCountByUserId(userId: string): Promise; + getSharedWithUserAlbumCount(userId: string, assetId: string): Promise; } export const ALBUM_REPOSITORY = 'ALBUM_REPOSITORY'; @@ -283,4 +284,17 @@ export class AlbumRepository implements IAlbumRepository { return this.albumRepository.save(album); } + + async getSharedWithUserAlbumCount(userId: string, assetId: string): Promise { + const result = await this + .userAlbumRepository + .createQueryBuilder('usa') + .select('count(aa)', 'count') + .innerJoin('asset_album', 'aa', 'aa.albumId = usa.albumId') + .where('aa.assetId = :assetId', { assetId }) + .andWhere('usa.sharedUserId = :userId', { userId }) + .getRawOne(); + + return result.count; + } } diff --git a/server/apps/immich/src/api-v1/album/album.service.spec.ts b/server/apps/immich/src/api-v1/album/album.service.spec.ts index 30153795f2..ca1db8f62e 100644 --- a/server/apps/immich/src/api-v1/album/album.service.spec.ts +++ b/server/apps/immich/src/api-v1/album/album.service.spec.ts @@ -123,6 +123,7 @@ describe('Album service', () => { updateAlbum: jest.fn(), getListByAssetId: jest.fn(), getCountByUserId: jest.fn(), + getSharedWithUserAlbumCount: jest.fn(), }; assetRepositoryMock = { @@ -142,6 +143,7 @@ describe('Album service', () => { getAssetWithNoThumbnail: jest.fn(), getAssetWithNoSmartInfo: jest.fn(), getExistingAssets: jest.fn(), + countByIdAndUser: jest.fn(), }; downloadServiceMock = { diff --git a/server/apps/immich/src/api-v1/asset/asset-repository.ts b/server/apps/immich/src/api-v1/asset/asset-repository.ts index 6d0aa2e41d..2aa5fc6418 100644 --- a/server/apps/immich/src/api-v1/asset/asset-repository.ts +++ b/server/apps/immich/src/api-v1/asset/asset-repository.ts @@ -43,6 +43,7 @@ export interface IAssetRepository { userId: string, checkDuplicateAssetDto: CheckExistingAssetsDto, ): Promise; + countByIdAndUser(assetId: string, userId: string): Promise; } export const ASSET_REPOSITORY = 'ASSET_REPOSITORY'; @@ -343,4 +344,13 @@ export class AssetRepository implements IAssetRepository { }); return new CheckExistingAssetsResponseDto(existingAssets.map((a) => a.deviceAssetId)); } + + async countByIdAndUser(assetId: string, userId: string): Promise { + return await this.assetRepository.count({ + where: { + id: assetId, + userId + } + }); + } } diff --git a/server/apps/immich/src/api-v1/asset/asset.controller.ts b/server/apps/immich/src/api-v1/asset/asset.controller.ts index 94624e0e48..83befba8b3 100644 --- a/server/apps/immich/src/api-v1/asset/asset.controller.ts +++ b/server/apps/immich/src/api-v1/asset/asset.controller.ts @@ -21,7 +21,7 @@ import { FileFieldsInterceptor } from '@nestjs/platform-express'; import { assetUploadOption } from '../../config/asset-upload.config'; import { AuthUserDto, GetAuthUser } from '../../decorators/auth-user.decorator'; import { ServeFileDto } from './dto/serve-file.dto'; -import { Response as Res} from 'express'; +import { Response as Res } from 'express'; import { BackgroundTaskService } from '../../modules/background-task/background-task.service'; import { DeleteAssetDto } from './dto/delete-asset.dto'; import { SearchAssetDto } from './dto/search-asset.dto'; @@ -86,10 +86,12 @@ export class AssetController { @Get('/download/:assetId') async downloadFile( + @GetAuthUser() authUser: AuthUserDto, @Response({ passthrough: true }) res: Res, @Query(new ValidationPipe({ transform: true })) query: ServeFileDto, @Param('assetId') assetId: string, ): Promise { + await this.assetService.checkAssetsAccess(authUser, [assetId]); return this.assetService.downloadFile(query, assetId, res); } @@ -110,22 +112,26 @@ export class AssetController { @Get('/file/:assetId') @Header('Cache-Control', 'max-age=3600') async serveFile( + @GetAuthUser() authUser: AuthUserDto, @Headers() headers: Record, @Response({ passthrough: true }) res: Res, @Query(new ValidationPipe({ transform: true })) query: ServeFileDto, @Param('assetId') assetId: string, ): Promise { + await this.assetService.checkAssetsAccess(authUser, [assetId]); return this.assetService.serveFile(assetId, query, res, headers); } @Get('/thumbnail/:assetId') @Header('Cache-Control', 'max-age=3600') async getAssetThumbnail( + @GetAuthUser() authUser: AuthUserDto, @Headers() headers: Record, @Response({ passthrough: true }) res: Res, @Param('assetId') assetId: string, @Query(new ValidationPipe({ transform: true })) query: GetAssetThumbnailDto, ): Promise { + await this.assetService.checkAssetsAccess(authUser, [assetId]); return this.assetService.getAssetThumbnail(assetId, query, res, headers); } @@ -203,7 +209,8 @@ export class AssetController { @GetAuthUser() authUser: AuthUserDto, @Param('assetId') assetId: string, ): Promise { - return await this.assetService.getAssetById(authUser, assetId); + await this.assetService.checkAssetsAccess(authUser, [assetId]); + return await this.assetService.getAssetById(assetId); } /** @@ -215,7 +222,8 @@ export class AssetController { @Param('assetId') assetId: string, @Body() dto: UpdateAssetDto, ): Promise { - return await this.assetService.updateAssetById(authUser, assetId, dto); + await this.assetService.checkAssetsAccess(authUser, [assetId], true); + return await this.assetService.updateAssetById(assetId, dto); } @Delete('/') @@ -223,17 +231,19 @@ export class AssetController { @GetAuthUser() authUser: AuthUserDto, @Body(ValidationPipe) assetIds: DeleteAssetDto, ): Promise { + await this.assetService.checkAssetsAccess(authUser, assetIds.ids, true); + const deleteAssetList: AssetResponseDto[] = []; for (const id of assetIds.ids) { - const assets = await this.assetService.getAssetById(authUser, id); + const assets = await this.assetService.getAssetById(id); if (!assets) { continue; } deleteAssetList.push(assets); if (assets.livePhotoVideoId) { - const livePhotoVideo = await this.assetService.getAssetById(authUser, assets.livePhotoVideoId); + const livePhotoVideo = await this.assetService.getAssetById(assets.livePhotoVideoId); if (livePhotoVideo) { deleteAssetList.push(livePhotoVideo); assetIds.ids = [...assetIds.ids, livePhotoVideo.id]; @@ -241,7 +251,7 @@ export class AssetController { } } - const result = await this.assetService.deleteAssetById(authUser, assetIds); + const result = await this.assetService.deleteAssetById(assetIds); result.forEach((res) => { deleteAssetList.filter((a) => a.id == res.id && res.status == DeleteAssetStatusEnum.SUCCESS); diff --git a/server/apps/immich/src/api-v1/asset/asset.module.ts b/server/apps/immich/src/api-v1/asset/asset.module.ts index f33d7a415b..d6d3b98196 100644 --- a/server/apps/immich/src/api-v1/asset/asset.module.ts +++ b/server/apps/immich/src/api-v1/asset/asset.module.ts @@ -10,13 +10,18 @@ import { CommunicationModule } from '../communication/communication.module'; import { QueueNameEnum } from '@app/job/constants/queue-name.constant'; import { AssetRepository, ASSET_REPOSITORY } from './asset-repository'; import { DownloadModule } from '../../modules/download/download.module'; +import { ALBUM_REPOSITORY, AlbumRepository } from '../album/album-repository'; +import { AlbumEntity } from '@app/database/entities/album.entity'; +import { UserAlbumEntity } from '@app/database/entities/user-album.entity'; +import { UserEntity } from '@app/database/entities/user.entity'; +import { AssetAlbumEntity } from '@app/database/entities/asset-album.entity'; @Module({ imports: [ CommunicationModule, BackgroundTaskModule, DownloadModule, - TypeOrmModule.forFeature([AssetEntity]), + TypeOrmModule.forFeature([AssetEntity, AlbumEntity, UserAlbumEntity, UserEntity, AssetAlbumEntity]), BullModule.registerQueue({ name: QueueNameEnum.ASSET_UPLOADED, defaultJobOptions: { @@ -42,6 +47,10 @@ import { DownloadModule } from '../../modules/download/download.module'; provide: ASSET_REPOSITORY, useClass: AssetRepository, }, + { + provide: ALBUM_REPOSITORY, + useClass: AlbumRepository, + }, ], exports: [AssetService], }) diff --git a/server/apps/immich/src/api-v1/asset/asset.service.spec.ts b/server/apps/immich/src/api-v1/asset/asset.service.spec.ts index 4c61c4e997..7f19e0684e 100644 --- a/server/apps/immich/src/api-v1/asset/asset.service.spec.ts +++ b/server/apps/immich/src/api-v1/asset/asset.service.spec.ts @@ -11,11 +11,13 @@ import { DownloadService } from '../../modules/download/download.service'; import { BackgroundTaskService } from '../../modules/background-task/background-task.service'; import { IAssetUploadedJob, IVideoTranscodeJob } from '@app/job'; import { Queue } from 'bull'; +import { IAlbumRepository } from "../album/album-repository"; describe('AssetService', () => { let sui: AssetService; let a: Repository; // TO BE DELETED AFTER FINISHED REFACTORING let assetRepositoryMock: jest.Mocked; + let albumRepositoryMock: jest.Mocked; let downloadServiceMock: jest.Mocked>; let backgroundTaskServiceMock: jest.Mocked; let assetUploadedQueueMock: jest.Mocked>; @@ -122,6 +124,7 @@ describe('AssetService', () => { getAssetWithNoThumbnail: jest.fn(), getAssetWithNoSmartInfo: jest.fn(), getExistingAssets: jest.fn(), + countByIdAndUser: jest.fn(), }; downloadServiceMock = { @@ -130,6 +133,7 @@ describe('AssetService', () => { sui = new AssetService( assetRepositoryMock, + albumRepositoryMock, a, backgroundTaskServiceMock, assetUploadedQueueMock, diff --git a/server/apps/immich/src/api-v1/asset/asset.service.ts b/server/apps/immich/src/api-v1/asset/asset.service.ts index df71d58837..6c9d72a0d5 100644 --- a/server/apps/immich/src/api-v1/asset/asset.service.ts +++ b/server/apps/immich/src/api-v1/asset/asset.service.ts @@ -54,6 +54,7 @@ import { InjectQueue } from '@nestjs/bull'; import { Queue } from 'bull'; import { DownloadService } from '../../modules/download/download.service'; import { DownloadDto } from './dto/download-library.dto'; +import { ALBUM_REPOSITORY, IAlbumRepository } from '../album/album-repository'; const fileInfo = promisify(stat); @@ -63,6 +64,9 @@ export class AssetService { @Inject(ASSET_REPOSITORY) private _assetRepository: IAssetRepository, + @Inject(ALBUM_REPOSITORY) + private _albumRepository: IAlbumRepository, + @InjectRepository(AssetEntity) private assetRepository: Repository, @@ -221,22 +225,18 @@ export class AssetService { return assets.map((asset) => mapAsset(asset)); } - public async getAssetById(authUser: AuthUserDto, assetId: string): Promise { + public async getAssetById(assetId: string): Promise { const asset = await this._assetRepository.getById(assetId); return mapAsset(asset); } - public async updateAssetById(authUser: AuthUserDto, assetId: string, dto: UpdateAssetDto): Promise { + public async updateAssetById(assetId: string, dto: UpdateAssetDto): Promise { const asset = await this._assetRepository.getById(assetId); if (!asset) { throw new BadRequestException('Asset not found'); } - if (authUser.id !== asset.userId) { - throw new ForbiddenException('Not the owner'); - } - const updatedAsset = await this._assetRepository.update(asset, dto); return mapAsset(updatedAsset); @@ -496,14 +496,13 @@ export class AssetService { } } - public async deleteAssetById(authUser: AuthUserDto, assetIds: DeleteAssetDto): Promise { + public async deleteAssetById(assetIds: DeleteAssetDto): Promise { const result: DeleteAssetResponseDto[] = []; const target = assetIds.ids; for (const assetId of target) { const res = await this.assetRepository.delete({ id: assetId, - userId: authUser.id, }); if (res.affected) { @@ -642,6 +641,26 @@ export class AssetService { getAssetCountByUserId(authUser: AuthUserDto): Promise { return this._assetRepository.getAssetCountByUserId(authUser.id); } + + async checkAssetsAccess(authUser: AuthUserDto, assetIds: string[], mustBeOwner = false) { + for (const assetId of assetIds) { + // Step 1: Check if user owns asset + if ((await this._assetRepository.countByIdAndUser(assetId, authUser.id)) == 1) { + continue; + } + + // Avoid additional checks if ownership is required + if (!mustBeOwner) { + // Step 2: Check if asset is part of an album shared with me + if ((await this._albumRepository.getSharedWithUserAlbumCount(authUser.id, assetId)) > 0) { + continue; + } + + //TODO: Step 3: Check if asset is part of a public album + } + throw new ForbiddenException(); + } + } } async function processETag(path: string, res: Res, headers: Record): Promise {