feat: credits

This commit is contained in:
Nevo David 2024-07-31 22:13:56 +07:00
parent 1b235652ac
commit 2352696830
9 changed files with 145 additions and 21 deletions

View File

@ -38,7 +38,8 @@ const authenticatedController = [
BillingController,
NotificationsController,
MarketplaceController,
MessagesController
MessagesController,
CopilotController,
];
@Module({
imports: [
@ -59,7 +60,7 @@ const authenticatedController = [
]
: []),
],
controllers: [StripeController, AuthController, CopilotController, ...authenticatedController],
controllers: [StripeController, AuthController, ...authenticatedController],
providers: [
AuthService,
StripeService,

View File

@ -1,22 +1,37 @@
import { Controller, Post, Req, Res } from '@nestjs/common';
import { Controller, Get, Post, Req, Res } from '@nestjs/common';
import {
CopilotRuntime,
OpenAIAdapter,
copilotRuntimeNestEndpoint,
} from '@copilotkit/runtime';
import { GetOrgFromRequest } from '@gitroom/nestjs-libraries/user/org.from.request';
import { Organization } from '@prisma/client';
import { SubscriptionService } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.service';
@Controller('/copilot')
export class CopilotController {
constructor(private _subscriptionService: SubscriptionService) {}
@Post('/chat')
chat(@Req() req: Request, @Res() res: Response) {
const copilotRuntimeHandler = copilotRuntimeNestEndpoint({
endpoint: '/copilot/chat',
runtime: new CopilotRuntime(),
// @ts-ignore
serviceAdapter: new OpenAIAdapter({ model: req?.body?.variables?.data?.metadata?.requestType === 'TextareaCompletion' ? 'gpt-4o-mini' : 'gpt-4o' }),
serviceAdapter: new OpenAIAdapter({
model:
// @ts-ignore
req?.body?.variables?.data?.metadata?.requestType ===
'TextareaCompletion'
? 'gpt-4o-mini'
: 'gpt-4o',
}),
});
// @ts-ignore
return copilotRuntimeHandler(req, res);
}
}
@Get('/credits')
calculateCredits(@GetOrgFromRequest() organization: Organization) {
return this._subscriptionService.checkCredits(organization);
}
}

View File

@ -9,11 +9,15 @@ import { ApiTags } from '@nestjs/swagger';
import handleR2Upload from '@gitroom/nestjs-libraries/upload/r2.uploader';
import { FileInterceptor } from '@nestjs/platform-express';
import { CustomFileValidationPipe } from '@gitroom/nestjs-libraries/upload/custom.upload.validation';
import { SubscriptionService } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.service';
@ApiTags('Media')
@Controller('/media')
export class MediaController {
constructor(private _mediaService: MediaService) {}
constructor(
private _mediaService: MediaService,
private _subscriptionService: SubscriptionService
) {}
@Post('/generate-image')
async generateImage(
@ -21,7 +25,12 @@ export class MediaController {
@Req() req: Request,
@Body('prompt') prompt: string
) {
return {output: 'data:image/png;base64,' + await this._mediaService.generateImage(prompt)};
const total = await this._subscriptionService.checkCredits(org);
if (total.credits <= 0) {
return false;
}
return {output: 'data:image/png;base64,' + await this._mediaService.generateImage(prompt, org)};
}
@Post('/upload-simple')

View File

@ -1,23 +1,44 @@
import React from 'react';
import React, { useCallback } from 'react';
import { observer } from 'mobx-react-lite';
import { InputGroup, Button } from '@blueprintjs/core';
import { InputGroup } from '@blueprintjs/core';
import { Clean } from '@blueprintjs/icons';
import { SectionTab } from 'polotno/side-panel';
import { getKey } from 'polotno/utils/validate-key';
import { getImageSize } from 'polotno/utils/image';
import { ImagesGrid } from 'polotno/side-panel/images-grid';
import { getAPI } from 'polotno/utils/api';
import { useFetch } from '@gitroom/helpers/utils/custom.fetch';
import useSWR from 'swr';
import { Button } from '@gitroom/react/form/button';
import { useToaster } from '@gitroom/react/toaster/toaster';
const GenerateTab = observer(({ store }: any) => {
const inputRef = React.useRef<any>(null);
const [image, setImage] = React.useState(null);
const [loading, setLoading] = React.useState(false);
const fetch = useFetch();
const toast = useToaster();
const loadCredits = useCallback(async () => {
return (
await fetch(`/copilot/credits`, {
method: 'GET',
})
).json();
}, []);
const {data, mutate} = useSWR('copilot-credits', loadCredits);
const handleGenerate = async () => {
if (data?.credits <= 0) {
window.open('/billing', '_blank');
return ;
}
if (!inputRef.current.value) {
toast.show('Please type your prompt', 'warning');
return ;
}
setLoading(true);
setImage(null);
@ -33,14 +54,15 @@ const GenerateTab = observer(({ store }: any) => {
alert('Something went wrong, please try again later...');
return;
}
const data = await req.json();
setImage(data.output);
mutate();
const newData = await req.json();
setImage(newData.output);
};
return (
<>
<div style={{ height: '40px', paddingTop: '5px' }}>
Generate image with AI
Generate image with AI {data?.credits ? `(${data?.credits} left)` : ``}
</div>
<InputGroup
placeholder="Type your image generation prompt here..."
@ -56,11 +78,10 @@ const GenerateTab = observer(({ store }: any) => {
/>
<Button
onClick={handleGenerate}
intent="primary"
loading={loading}
loading={loading} innerClassName="invert"
style={{ marginBottom: '40px' }}
>
Generate
{data?.credits <= 0 ? 'Click to purchase more credits' : 'Generate'}
</Button>
{image && (
<ImagesGrid

View File

@ -1,16 +1,21 @@
import {Injectable} from "@nestjs/common";
import {MediaRepository} from "@gitroom/nestjs-libraries/database/prisma/media/media.repository";
import { OpenaiService } from '@gitroom/nestjs-libraries/openai/openai.service';
import { SubscriptionService } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.service';
import { Organization } from '@prisma/client';
@Injectable()
export class MediaService {
constructor(
private _mediaRepository: MediaRepository,
private _openAi: OpenaiService
private _openAi: OpenaiService,
private _subscriptionService: SubscriptionService
){}
generateImage(prompt: string) {
return this._openAi.generateImage(prompt);
async generateImage(prompt: string, org: Organization) {
const image = await this._openAi.generateImage(prompt);
await this._subscriptionService.useCredit(org);
return image;
}
saveFile(org: string, fileName: string, filePath: string) {

View File

@ -107,6 +107,7 @@ export class OrganizationRepository {
subscriptionTier: true,
totalChannels: true,
isLifetime: true,
createdAt: true,
},
},
},

View File

@ -28,6 +28,7 @@ model Organization {
notifications Notifications[]
buyerOrganization MessagesGroup[]
usedCodes UsedCodes[]
credits Credits[]
}
model User {
@ -168,6 +169,18 @@ model Media {
@@index([organizationId])
}
model Credits {
id String @id @default(uuid())
organization Organization @relation(fields: [organizationId], references: [id])
organizationId String
credits Int
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
@@index([organizationId])
@@index([createdAt])
}
model Subscription {
id String @id @default(cuid())
organizationId String @unique

View File

@ -1,5 +1,7 @@
import { Injectable } from '@nestjs/common';
import { PrismaRepository } from '@gitroom/nestjs-libraries/database/prisma/prisma.service';
import dayjs from 'dayjs';
import { Organization } from '@prisma/client';
@Injectable()
export class SubscriptionRepository {
@ -7,6 +9,7 @@ export class SubscriptionRepository {
private readonly _subscription: PrismaRepository<'subscription'>,
private readonly _organization: PrismaRepository<'organization'>,
private readonly _user: PrismaRepository<'user'>,
private readonly _credits: PrismaRepository<'credits'>,
private _usedCodes: PrismaRepository<'usedCodes'>
) {}
@ -173,4 +176,30 @@ export class SubscriptionRepository {
},
});
}
async getCreditsFrom(organizationId: string, from: dayjs.Dayjs) {
const load = await this._credits.model.credits.groupBy({
by: ['organizationId'],
where: {
organizationId,
createdAt: {
gte: from.toDate(),
},
},
_sum: {
credits: true,
},
});
return load?.[0]?._sum?.credits || 0;
}
useCredit(org: Organization) {
return this._credits.model.credits.create({
data: {
organizationId: org.id,
credits: 1,
},
});
}
}

View File

@ -3,6 +3,8 @@ import { pricing } from '@gitroom/nestjs-libraries/database/prisma/subscriptions
import { SubscriptionRepository } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.repository';
import { IntegrationService } from '@gitroom/nestjs-libraries/database/prisma/integrations/integration.service';
import { OrganizationService } from '@gitroom/nestjs-libraries/database/prisma/organizations/organization.service';
import { Organization } from '@prisma/client';
import dayjs from 'dayjs';
@Injectable()
export class SubscriptionService {
@ -18,6 +20,10 @@ export class SubscriptionService {
);
}
useCredit(organization: Organization) {
return this._subscriptionRepository.useCredit(organization);
}
getCode(code: string) {
return this._subscriptionRepository.getCode(code);
}
@ -152,4 +158,28 @@ export class SubscriptionService {
async getSubscription(organizationId: string) {
return this._subscriptionRepository.getSubscription(organizationId);
}
async checkCredits(organization: Organization) {
// @ts-ignore
const type = organization?.subscription?.subscriptionTier || 'FREE';
if (type === 'FREE') {
return {credits: 0};
}
// @ts-ignore
let date = dayjs(organization.subscription.createdAt);
while (date.isBefore(dayjs())) {
date = date.add(1, 'month');
}
const checkFromMonth = date.subtract(1, 'month');
const imageGenerationCount = pricing[type].image_generation_count;
const totalUse = await this._subscriptionRepository.getCreditsFrom(organization.id, checkFromMonth);
return {
credits: imageGenerationCount - totalUse,
}
}
}