diff --git a/libraries/nestjs-libraries/src/database/prisma/database.module.ts b/libraries/nestjs-libraries/src/database/prisma/database.module.ts index 1bf52972..17198f05 100644 --- a/libraries/nestjs-libraries/src/database/prisma/database.module.ts +++ b/libraries/nestjs-libraries/src/database/prisma/database.module.ts @@ -1,5 +1,5 @@ import { Global, Module } from '@nestjs/common'; -import { PrismaRepository, PrismaService } from './prisma.service'; +import { PrismaRepository, PrismaService, PrismaTransaction } from './prisma.service'; import { OrganizationRepository } from '@gitroom/nestjs-libraries/database/prisma/organizations/organization.repository'; import { OrganizationService } from '@gitroom/nestjs-libraries/database/prisma/organizations/organization.service'; import { UsersService } from '@gitroom/nestjs-libraries/database/prisma/users/users.service'; @@ -49,6 +49,7 @@ import { FalService } from '@gitroom/nestjs-libraries/openai/fal.service'; providers: [ PrismaService, PrismaRepository, + PrismaTransaction, UsersService, UsersRepository, OrganizationService, diff --git a/libraries/nestjs-libraries/src/database/prisma/media/media.service.ts b/libraries/nestjs-libraries/src/database/prisma/media/media.service.ts index dc0d37e2..73623dbd 100644 --- a/libraries/nestjs-libraries/src/database/prisma/media/media.service.ts +++ b/libraries/nestjs-libraries/src/database/prisma/media/media.service.ts @@ -37,16 +37,17 @@ export class MediaService { org: Organization, generatePromptFirst?: boolean ) { - if (generatePromptFirst) { - prompt = await this._openAi.generatePromptForPicture(prompt); - console.log('Prompt:', prompt); - } - const image = await this._openAi.generateImage( - prompt, - !!generatePromptFirst + return await this._subscriptionService.useCredit( + org, + 'ai_images', + async () => { + if (generatePromptFirst) { + prompt = await this._openAi.generatePromptForPicture(prompt); + console.log('Prompt:', prompt); + } + return this._openAi.generateImage(prompt, !!generatePromptFirst); + } ); - await this._subscriptionService.useCredit(org); - return image; } saveFile(org: string, fileName: string, filePath: string) { @@ -99,17 +100,21 @@ export class MediaService { throw new HttpException('This video is not available in trial mode', 406); } - const loadedData = await video.instance.processAndValidate( - body.output, - body.customParams + await video.instance.processAndValidate(body.customParams); + + return await this._subscriptionService.useCredit( + org, + 'ai_videos', + async () => { + const loadedData = await video.instance.process( + body.output, + body.customParams + ); + + const file = await this.storage.uploadSimple(loadedData); + return this.saveFile(org.id, file.split('/').pop(), file); + } ); - - const file = await this.storage.uploadSimple(loadedData); - const save = await this.saveFile(org.id, file.split('/').pop(), file); - - await this._subscriptionService.useCredit(org, 'ai_videos'); - - return save; } async videoFunction(identifier: string, functionName: string, body: any) { diff --git a/libraries/nestjs-libraries/src/database/prisma/prisma.service.ts b/libraries/nestjs-libraries/src/database/prisma/prisma.service.ts index 7e1755d1..eeb4ebc3 100644 --- a/libraries/nestjs-libraries/src/database/prisma/prisma.service.ts +++ b/libraries/nestjs-libraries/src/database/prisma/prisma.service.ts @@ -25,3 +25,12 @@ export class PrismaRepository { this.model = this._prismaService; } } + + +@Injectable() +export class PrismaTransaction { + public model: Pick; + constructor(private _prismaService: PrismaService) { + this.model = this._prismaService; + } +} diff --git a/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.repository.ts b/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.repository.ts index 98dbcd0e..6bfd2ec5 100644 --- a/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.repository.ts +++ b/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.repository.ts @@ -1,5 +1,8 @@ import { Injectable } from '@nestjs/common'; -import { PrismaRepository } from '@gitroom/nestjs-libraries/database/prisma/prisma.service'; +import { + PrismaRepository, + PrismaTransaction, +} from '@gitroom/nestjs-libraries/database/prisma/prisma.service'; import dayjs from 'dayjs'; import { Organization } from '@prisma/client'; @@ -203,7 +206,11 @@ export class SubscriptionRepository { }); } - async getCreditsFrom(organizationId: string, from: dayjs.Dayjs, type = 'ai_images') { + async getCreditsFrom( + organizationId: string, + from: dayjs.Dayjs, + type = 'ai_images' + ) { const load = await this._credits.model.credits.groupBy({ by: ['organizationId'], where: { @@ -221,14 +228,29 @@ export class SubscriptionRepository { return load?.[0]?._sum?.credits || 0; } - useCredit(org: Organization, type = 'ai_images') { - return this._credits.model.credits.create({ + async useCredit( + org: Organization, + type = 'ai_images', + func: () => Promise + ) { + const data = await this._credits.model.credits.create({ data: { organizationId: org.id, credits: 1, type, }, }); + + try { + return await func(); + } catch (err) { + await this._credits.model.credits.delete({ + where: { + id: data.id, + }, + }); + throw err; + } } setCustomerId(orgId: string, customerId: string) { diff --git a/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.service.ts b/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.service.ts index e829d177..9d57b522 100644 --- a/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.service.ts +++ b/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.service.ts @@ -21,8 +21,8 @@ export class SubscriptionService { ); } - useCredit(organization: Organization, type = 'ai_images') { - return this._subscriptionRepository.useCredit(organization, type); + useCredit(organization: Organization, type = 'ai_images', func: () => Promise) : Promise { + return this._subscriptionRepository.useCredit(organization, type, func); } getCode(code: string) { diff --git a/libraries/nestjs-libraries/src/videos/video.interface.ts b/libraries/nestjs-libraries/src/videos/video.interface.ts index e78fd41b..c509cfbf 100644 --- a/libraries/nestjs-libraries/src/videos/video.interface.ts +++ b/libraries/nestjs-libraries/src/videos/video.interface.ts @@ -6,7 +6,6 @@ export abstract class VideoAbstract { dto: Type; async processAndValidate( - output: 'vertical' | 'horizontal', customParams?: T ) { const validationPipe = new ValidationPipe({ @@ -17,15 +16,13 @@ export abstract class VideoAbstract { }, }); - const transformed = await validationPipe.transform(customParams, { + await validationPipe.transform(customParams, { type: 'body', metatype: this.dto, }); - - return this.process(output, transformed); } - protected abstract process( + abstract process( output: 'vertical' | 'horizontal', customParams?: T ): Promise;