From 2352696830f984c23a02605f3b7cbda6a3011343 Mon Sep 17 00:00:00 2001 From: Nevo David Date: Wed, 31 Jul 2024 22:13:56 +0700 Subject: [PATCH] feat: credits --- apps/backend/src/api/api.module.ts | 5 ++- .../src/api/routes/copilot.controller.ts | 23 +++++++++-- .../src/api/routes/media.controller.ts | 13 +++++- .../polonto/polonto.picture.generation.tsx | 41 ++++++++++++++----- .../database/prisma/media/media.service.ts | 11 +++-- .../organizations/organization.repository.ts | 1 + .../src/database/prisma/schema.prisma | 13 ++++++ .../subscriptions/subscription.repository.ts | 29 +++++++++++++ .../subscriptions/subscription.service.ts | 30 ++++++++++++++ 9 files changed, 145 insertions(+), 21 deletions(-) diff --git a/apps/backend/src/api/api.module.ts b/apps/backend/src/api/api.module.ts index bc6a5bd7..9bd5b14c 100644 --- a/apps/backend/src/api/api.module.ts +++ b/apps/backend/src/api/api.module.ts @@ -38,7 +38,8 @@ const authenticatedController = [ BillingController, NotificationsController, MarketplaceController, - MessagesController + MessagesController, + CopilotController, ]; @Module({ imports: [ @@ -59,7 +60,7 @@ const authenticatedController = [ ] : []), ], - controllers: [StripeController, AuthController, CopilotController, ...authenticatedController], + controllers: [StripeController, AuthController, ...authenticatedController], providers: [ AuthService, StripeService, diff --git a/apps/backend/src/api/routes/copilot.controller.ts b/apps/backend/src/api/routes/copilot.controller.ts index 38284a60..2f46b48a 100644 --- a/apps/backend/src/api/routes/copilot.controller.ts +++ b/apps/backend/src/api/routes/copilot.controller.ts @@ -1,22 +1,37 @@ -import { Controller, Post, Req, Res } from '@nestjs/common'; +import { Controller, Get, Post, Req, Res } from '@nestjs/common'; import { CopilotRuntime, OpenAIAdapter, copilotRuntimeNestEndpoint, } from '@copilotkit/runtime'; +import { GetOrgFromRequest } from '@gitroom/nestjs-libraries/user/org.from.request'; +import { Organization } from '@prisma/client'; +import { SubscriptionService } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.service'; @Controller('/copilot') export class CopilotController { + constructor(private _subscriptionService: SubscriptionService) {} @Post('/chat') chat(@Req() req: Request, @Res() res: Response) { const copilotRuntimeHandler = copilotRuntimeNestEndpoint({ endpoint: '/copilot/chat', runtime: new CopilotRuntime(), - // @ts-ignore - serviceAdapter: new OpenAIAdapter({ model: req?.body?.variables?.data?.metadata?.requestType === 'TextareaCompletion' ? 'gpt-4o-mini' : 'gpt-4o' }), + serviceAdapter: new OpenAIAdapter({ + model: + // @ts-ignore + req?.body?.variables?.data?.metadata?.requestType === + 'TextareaCompletion' + ? 'gpt-4o-mini' + : 'gpt-4o', + }), }); // @ts-ignore return copilotRuntimeHandler(req, res); } -} \ No newline at end of file + + @Get('/credits') + calculateCredits(@GetOrgFromRequest() organization: Organization) { + return this._subscriptionService.checkCredits(organization); + } +} diff --git a/apps/backend/src/api/routes/media.controller.ts b/apps/backend/src/api/routes/media.controller.ts index ec2a1729..fc379d96 100644 --- a/apps/backend/src/api/routes/media.controller.ts +++ b/apps/backend/src/api/routes/media.controller.ts @@ -9,11 +9,15 @@ import { ApiTags } from '@nestjs/swagger'; import handleR2Upload from '@gitroom/nestjs-libraries/upload/r2.uploader'; import { FileInterceptor } from '@nestjs/platform-express'; import { CustomFileValidationPipe } from '@gitroom/nestjs-libraries/upload/custom.upload.validation'; +import { SubscriptionService } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.service'; @ApiTags('Media') @Controller('/media') export class MediaController { - constructor(private _mediaService: MediaService) {} + constructor( + private _mediaService: MediaService, + private _subscriptionService: SubscriptionService + ) {} @Post('/generate-image') async generateImage( @@ -21,7 +25,12 @@ export class MediaController { @Req() req: Request, @Body('prompt') prompt: string ) { - return {output: 'data:image/png;base64,' + await this._mediaService.generateImage(prompt)}; + const total = await this._subscriptionService.checkCredits(org); + if (total.credits <= 0) { + return false; + } + + return {output: 'data:image/png;base64,' + await this._mediaService.generateImage(prompt, org)}; } @Post('/upload-simple') diff --git a/apps/frontend/src/components/launches/polonto/polonto.picture.generation.tsx b/apps/frontend/src/components/launches/polonto/polonto.picture.generation.tsx index 3384f00e..61ec8ef8 100644 --- a/apps/frontend/src/components/launches/polonto/polonto.picture.generation.tsx +++ b/apps/frontend/src/components/launches/polonto/polonto.picture.generation.tsx @@ -1,23 +1,44 @@ -import React from 'react'; +import React, { useCallback } from 'react'; import { observer } from 'mobx-react-lite'; -import { InputGroup, Button } from '@blueprintjs/core'; +import { InputGroup } from '@blueprintjs/core'; import { Clean } from '@blueprintjs/icons'; import { SectionTab } from 'polotno/side-panel'; -import { getKey } from 'polotno/utils/validate-key'; import { getImageSize } from 'polotno/utils/image'; import { ImagesGrid } from 'polotno/side-panel/images-grid'; -import { getAPI } from 'polotno/utils/api'; import { useFetch } from '@gitroom/helpers/utils/custom.fetch'; +import useSWR from 'swr'; +import { Button } from '@gitroom/react/form/button'; +import { useToaster } from '@gitroom/react/toaster/toaster'; const GenerateTab = observer(({ store }: any) => { const inputRef = React.useRef(null); const [image, setImage] = React.useState(null); const [loading, setLoading] = React.useState(false); const fetch = useFetch(); + const toast = useToaster(); + + const loadCredits = useCallback(async () => { + return ( + await fetch(`/copilot/credits`, { + method: 'GET', + }) + ).json(); + }, []); + + const {data, mutate} = useSWR('copilot-credits', loadCredits); const handleGenerate = async () => { + if (data?.credits <= 0) { + window.open('/billing', '_blank'); + return ; + } + + if (!inputRef.current.value) { + toast.show('Please type your prompt', 'warning'); + return ; + } setLoading(true); setImage(null); @@ -33,14 +54,15 @@ const GenerateTab = observer(({ store }: any) => { alert('Something went wrong, please try again later...'); return; } - const data = await req.json(); - setImage(data.output); + mutate(); + const newData = await req.json(); + setImage(newData.output); }; return ( <>
- Generate image with AI + Generate image with AI {data?.credits ? `(${data?.credits} left)` : ``}
{ /> {image && ( , private readonly _organization: PrismaRepository<'organization'>, private readonly _user: PrismaRepository<'user'>, + private readonly _credits: PrismaRepository<'credits'>, private _usedCodes: PrismaRepository<'usedCodes'> ) {} @@ -173,4 +176,30 @@ export class SubscriptionRepository { }, }); } + + async getCreditsFrom(organizationId: string, from: dayjs.Dayjs) { + const load = await this._credits.model.credits.groupBy({ + by: ['organizationId'], + where: { + organizationId, + createdAt: { + gte: from.toDate(), + }, + }, + _sum: { + credits: true, + }, + }); + + return load?.[0]?._sum?.credits || 0; + } + + useCredit(org: Organization) { + return this._credits.model.credits.create({ + data: { + organizationId: org.id, + credits: 1, + }, + }); + } } diff --git a/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.service.ts b/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.service.ts index 6dc55b7d..839268cb 100644 --- a/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.service.ts +++ b/libraries/nestjs-libraries/src/database/prisma/subscriptions/subscription.service.ts @@ -3,6 +3,8 @@ import { pricing } from '@gitroom/nestjs-libraries/database/prisma/subscriptions import { SubscriptionRepository } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.repository'; import { IntegrationService } from '@gitroom/nestjs-libraries/database/prisma/integrations/integration.service'; import { OrganizationService } from '@gitroom/nestjs-libraries/database/prisma/organizations/organization.service'; +import { Organization } from '@prisma/client'; +import dayjs from 'dayjs'; @Injectable() export class SubscriptionService { @@ -18,6 +20,10 @@ export class SubscriptionService { ); } + useCredit(organization: Organization) { + return this._subscriptionRepository.useCredit(organization); + } + getCode(code: string) { return this._subscriptionRepository.getCode(code); } @@ -152,4 +158,28 @@ export class SubscriptionService { async getSubscription(organizationId: string) { return this._subscriptionRepository.getSubscription(organizationId); } + + async checkCredits(organization: Organization) { + // @ts-ignore + const type = organization?.subscription?.subscriptionTier || 'FREE'; + + if (type === 'FREE') { + return {credits: 0}; + } + + // @ts-ignore + let date = dayjs(organization.subscription.createdAt); + while (date.isBefore(dayjs())) { + date = date.add(1, 'month'); + } + + const checkFromMonth = date.subtract(1, 'month'); + const imageGenerationCount = pricing[type].image_generation_count; + + const totalUse = await this._subscriptionRepository.getCreditsFrom(organization.id, checkFromMonth); + + return { + credits: imageGenerationCount - totalUse, + } + } }