fix(server): select asset face columns explicitly (#5564)

* select columns explicitly

* updated sql

* formatting
This commit is contained in:
Mert 2023-12-08 13:43:35 -05:00 committed by GitHub
parent 2f4ee622ab
commit 2553c54b26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 11 deletions

View File

@ -11,13 +11,19 @@ import { asVector, isValidInteger } from '../infra.utils';
@Injectable()
export class SmartInfoRepository implements ISmartInfoRepository {
private logger = new Logger(SmartInfoRepository.name);
private faceColumns: string[];
constructor(
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
) {}
) {
this.faceColumns = this.assetFaceRepository.manager.connection
.getMetadata(AssetFaceEntity)
.ownColumns.map((column) => column.propertyName)
.filter((propertyName) => propertyName !== 'embedding');
}
async init(modelName: string): Promise<void> {
const { dimSize } = getCLIPModelInfo(modelName);
@ -79,13 +85,15 @@ export class SmartInfoRepository implements ISmartInfoRepository {
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
const cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.addSelect('1 + (faces.embedding <=> :embedding)', 'distance')
.select('1 + (faces.embedding <=> :embedding)', 'distance')
.innerJoin('faces.asset', 'asset')
.where('asset.ownerId = :ownerId')
.orderBy(`faces.embedding <=> :embedding`)
.setParameters({ ownerId, embedding: asVector(embedding) })
.limit(numResults);
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col}`, col));
results = await manager
.createQueryBuilder()
.select('res.*')

View File

@ -81,15 +81,15 @@ SET
WITH
"cte" AS (
SELECT
"faces"."id" AS "faces_id",
"faces"."assetId" AS "faces_assetId",
"faces"."personId" AS "faces_personId",
"faces"."imageWidth" AS "faces_imageWidth",
"faces"."imageHeight" AS "faces_imageHeight",
"faces"."boundingBoxX1" AS "faces_boundingBoxX1",
"faces"."boundingBoxY1" AS "faces_boundingBoxY1",
"faces"."boundingBoxX2" AS "faces_boundingBoxX2",
"faces"."boundingBoxY2" AS "faces_boundingBoxY2",
"faces"."id" AS "id",
"faces"."assetId" AS "assetId",
"faces"."personId" AS "personId",
"faces"."imageWidth" AS "imageWidth",
"faces"."imageHeight" AS "imageHeight",
"faces"."boundingBoxX1" AS "boundingBoxX1",
"faces"."boundingBoxY1" AS "boundingBoxY1",
"faces"."boundingBoxX2" AS "boundingBoxX2",
"faces"."boundingBoxY2" AS "boundingBoxY2",
1 + ("faces"."embedding" <= > $1) AS "distance"
FROM
"asset_faces" "faces"