feat: credits
This commit is contained in:
parent
1b235652ac
commit
2352696830
|
|
@ -38,7 +38,8 @@ const authenticatedController = [
|
|||
BillingController,
|
||||
NotificationsController,
|
||||
MarketplaceController,
|
||||
MessagesController
|
||||
MessagesController,
|
||||
CopilotController,
|
||||
];
|
||||
@Module({
|
||||
imports: [
|
||||
|
|
@ -59,7 +60,7 @@ const authenticatedController = [
|
|||
]
|
||||
: []),
|
||||
],
|
||||
controllers: [StripeController, AuthController, CopilotController, ...authenticatedController],
|
||||
controllers: [StripeController, AuthController, ...authenticatedController],
|
||||
providers: [
|
||||
AuthService,
|
||||
StripeService,
|
||||
|
|
|
|||
|
|
@ -1,22 +1,37 @@
|
|||
import { Controller, Post, Req, Res } from '@nestjs/common';
|
||||
import { Controller, Get, Post, Req, Res } from '@nestjs/common';
|
||||
import {
|
||||
CopilotRuntime,
|
||||
OpenAIAdapter,
|
||||
copilotRuntimeNestEndpoint,
|
||||
} from '@copilotkit/runtime';
|
||||
import { GetOrgFromRequest } from '@gitroom/nestjs-libraries/user/org.from.request';
|
||||
import { Organization } from '@prisma/client';
|
||||
import { SubscriptionService } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.service';
|
||||
|
||||
@Controller('/copilot')
|
||||
export class CopilotController {
|
||||
constructor(private _subscriptionService: SubscriptionService) {}
|
||||
@Post('/chat')
|
||||
chat(@Req() req: Request, @Res() res: Response) {
|
||||
const copilotRuntimeHandler = copilotRuntimeNestEndpoint({
|
||||
endpoint: '/copilot/chat',
|
||||
runtime: new CopilotRuntime(),
|
||||
// @ts-ignore
|
||||
serviceAdapter: new OpenAIAdapter({ model: req?.body?.variables?.data?.metadata?.requestType === 'TextareaCompletion' ? 'gpt-4o-mini' : 'gpt-4o' }),
|
||||
serviceAdapter: new OpenAIAdapter({
|
||||
model:
|
||||
// @ts-ignore
|
||||
req?.body?.variables?.data?.metadata?.requestType ===
|
||||
'TextareaCompletion'
|
||||
? 'gpt-4o-mini'
|
||||
: 'gpt-4o',
|
||||
}),
|
||||
});
|
||||
|
||||
// @ts-ignore
|
||||
return copilotRuntimeHandler(req, res);
|
||||
}
|
||||
}
|
||||
|
||||
@Get('/credits')
|
||||
calculateCredits(@GetOrgFromRequest() organization: Organization) {
|
||||
return this._subscriptionService.checkCredits(organization);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,11 +9,15 @@ import { ApiTags } from '@nestjs/swagger';
|
|||
import handleR2Upload from '@gitroom/nestjs-libraries/upload/r2.uploader';
|
||||
import { FileInterceptor } from '@nestjs/platform-express';
|
||||
import { CustomFileValidationPipe } from '@gitroom/nestjs-libraries/upload/custom.upload.validation';
|
||||
import { SubscriptionService } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.service';
|
||||
|
||||
@ApiTags('Media')
|
||||
@Controller('/media')
|
||||
export class MediaController {
|
||||
constructor(private _mediaService: MediaService) {}
|
||||
constructor(
|
||||
private _mediaService: MediaService,
|
||||
private _subscriptionService: SubscriptionService
|
||||
) {}
|
||||
|
||||
@Post('/generate-image')
|
||||
async generateImage(
|
||||
|
|
@ -21,7 +25,12 @@ export class MediaController {
|
|||
@Req() req: Request,
|
||||
@Body('prompt') prompt: string
|
||||
) {
|
||||
return {output: 'data:image/png;base64,' + await this._mediaService.generateImage(prompt)};
|
||||
const total = await this._subscriptionService.checkCredits(org);
|
||||
if (total.credits <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return {output: 'data:image/png;base64,' + await this._mediaService.generateImage(prompt, org)};
|
||||
}
|
||||
|
||||
@Post('/upload-simple')
|
||||
|
|
|
|||
|
|
@ -1,23 +1,44 @@
|
|||
import React from 'react';
|
||||
import React, { useCallback } from 'react';
|
||||
import { observer } from 'mobx-react-lite';
|
||||
import { InputGroup, Button } from '@blueprintjs/core';
|
||||
import { InputGroup } from '@blueprintjs/core';
|
||||
import { Clean } from '@blueprintjs/icons';
|
||||
|
||||
import { SectionTab } from 'polotno/side-panel';
|
||||
import { getKey } from 'polotno/utils/validate-key';
|
||||
import { getImageSize } from 'polotno/utils/image';
|
||||
|
||||
import { ImagesGrid } from 'polotno/side-panel/images-grid';
|
||||
import { getAPI } from 'polotno/utils/api';
|
||||
import { useFetch } from '@gitroom/helpers/utils/custom.fetch';
|
||||
import useSWR from 'swr';
|
||||
import { Button } from '@gitroom/react/form/button';
|
||||
import { useToaster } from '@gitroom/react/toaster/toaster';
|
||||
|
||||
const GenerateTab = observer(({ store }: any) => {
|
||||
const inputRef = React.useRef<any>(null);
|
||||
const [image, setImage] = React.useState(null);
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
const fetch = useFetch();
|
||||
const toast = useToaster();
|
||||
|
||||
const loadCredits = useCallback(async () => {
|
||||
return (
|
||||
await fetch(`/copilot/credits`, {
|
||||
method: 'GET',
|
||||
})
|
||||
).json();
|
||||
}, []);
|
||||
|
||||
const {data, mutate} = useSWR('copilot-credits', loadCredits);
|
||||
|
||||
const handleGenerate = async () => {
|
||||
if (data?.credits <= 0) {
|
||||
window.open('/billing', '_blank');
|
||||
return ;
|
||||
}
|
||||
|
||||
if (!inputRef.current.value) {
|
||||
toast.show('Please type your prompt', 'warning');
|
||||
return ;
|
||||
}
|
||||
setLoading(true);
|
||||
setImage(null);
|
||||
|
||||
|
|
@ -33,14 +54,15 @@ const GenerateTab = observer(({ store }: any) => {
|
|||
alert('Something went wrong, please try again later...');
|
||||
return;
|
||||
}
|
||||
const data = await req.json();
|
||||
setImage(data.output);
|
||||
mutate();
|
||||
const newData = await req.json();
|
||||
setImage(newData.output);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<div style={{ height: '40px', paddingTop: '5px' }}>
|
||||
Generate image with AI
|
||||
Generate image with AI {data?.credits ? `(${data?.credits} left)` : ``}
|
||||
</div>
|
||||
<InputGroup
|
||||
placeholder="Type your image generation prompt here..."
|
||||
|
|
@ -56,11 +78,10 @@ const GenerateTab = observer(({ store }: any) => {
|
|||
/>
|
||||
<Button
|
||||
onClick={handleGenerate}
|
||||
intent="primary"
|
||||
loading={loading}
|
||||
loading={loading} innerClassName="invert"
|
||||
style={{ marginBottom: '40px' }}
|
||||
>
|
||||
Generate
|
||||
{data?.credits <= 0 ? 'Click to purchase more credits' : 'Generate'}
|
||||
</Button>
|
||||
{image && (
|
||||
<ImagesGrid
|
||||
|
|
|
|||
|
|
@ -1,16 +1,21 @@
|
|||
import {Injectable} from "@nestjs/common";
|
||||
import {MediaRepository} from "@gitroom/nestjs-libraries/database/prisma/media/media.repository";
|
||||
import { OpenaiService } from '@gitroom/nestjs-libraries/openai/openai.service';
|
||||
import { SubscriptionService } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.service';
|
||||
import { Organization } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class MediaService {
|
||||
constructor(
|
||||
private _mediaRepository: MediaRepository,
|
||||
private _openAi: OpenaiService
|
||||
private _openAi: OpenaiService,
|
||||
private _subscriptionService: SubscriptionService
|
||||
){}
|
||||
|
||||
generateImage(prompt: string) {
|
||||
return this._openAi.generateImage(prompt);
|
||||
async generateImage(prompt: string, org: Organization) {
|
||||
const image = await this._openAi.generateImage(prompt);
|
||||
await this._subscriptionService.useCredit(org);
|
||||
return image;
|
||||
}
|
||||
|
||||
saveFile(org: string, fileName: string, filePath: string) {
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ export class OrganizationRepository {
|
|||
subscriptionTier: true,
|
||||
totalChannels: true,
|
||||
isLifetime: true,
|
||||
createdAt: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ model Organization {
|
|||
notifications Notifications[]
|
||||
buyerOrganization MessagesGroup[]
|
||||
usedCodes UsedCodes[]
|
||||
credits Credits[]
|
||||
}
|
||||
|
||||
model User {
|
||||
|
|
@ -168,6 +169,18 @@ model Media {
|
|||
@@index([organizationId])
|
||||
}
|
||||
|
||||
model Credits {
|
||||
id String @id @default(uuid())
|
||||
organization Organization @relation(fields: [organizationId], references: [id])
|
||||
organizationId String
|
||||
credits Int
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@index([organizationId])
|
||||
@@index([createdAt])
|
||||
}
|
||||
|
||||
model Subscription {
|
||||
id String @id @default(cuid())
|
||||
organizationId String @unique
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import { Injectable } from '@nestjs/common';
|
||||
import { PrismaRepository } from '@gitroom/nestjs-libraries/database/prisma/prisma.service';
|
||||
import dayjs from 'dayjs';
|
||||
import { Organization } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class SubscriptionRepository {
|
||||
|
|
@ -7,6 +9,7 @@ export class SubscriptionRepository {
|
|||
private readonly _subscription: PrismaRepository<'subscription'>,
|
||||
private readonly _organization: PrismaRepository<'organization'>,
|
||||
private readonly _user: PrismaRepository<'user'>,
|
||||
private readonly _credits: PrismaRepository<'credits'>,
|
||||
private _usedCodes: PrismaRepository<'usedCodes'>
|
||||
) {}
|
||||
|
||||
|
|
@ -173,4 +176,30 @@ export class SubscriptionRepository {
|
|||
},
|
||||
});
|
||||
}
|
||||
|
||||
async getCreditsFrom(organizationId: string, from: dayjs.Dayjs) {
|
||||
const load = await this._credits.model.credits.groupBy({
|
||||
by: ['organizationId'],
|
||||
where: {
|
||||
organizationId,
|
||||
createdAt: {
|
||||
gte: from.toDate(),
|
||||
},
|
||||
},
|
||||
_sum: {
|
||||
credits: true,
|
||||
},
|
||||
});
|
||||
|
||||
return load?.[0]?._sum?.credits || 0;
|
||||
}
|
||||
|
||||
useCredit(org: Organization) {
|
||||
return this._credits.model.credits.create({
|
||||
data: {
|
||||
organizationId: org.id,
|
||||
credits: 1,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import { pricing } from '@gitroom/nestjs-libraries/database/prisma/subscriptions
|
|||
import { SubscriptionRepository } from '@gitroom/nestjs-libraries/database/prisma/subscriptions/subscription.repository';
|
||||
import { IntegrationService } from '@gitroom/nestjs-libraries/database/prisma/integrations/integration.service';
|
||||
import { OrganizationService } from '@gitroom/nestjs-libraries/database/prisma/organizations/organization.service';
|
||||
import { Organization } from '@prisma/client';
|
||||
import dayjs from 'dayjs';
|
||||
|
||||
@Injectable()
|
||||
export class SubscriptionService {
|
||||
|
|
@ -18,6 +20,10 @@ export class SubscriptionService {
|
|||
);
|
||||
}
|
||||
|
||||
useCredit(organization: Organization) {
|
||||
return this._subscriptionRepository.useCredit(organization);
|
||||
}
|
||||
|
||||
getCode(code: string) {
|
||||
return this._subscriptionRepository.getCode(code);
|
||||
}
|
||||
|
|
@ -152,4 +158,28 @@ export class SubscriptionService {
|
|||
async getSubscription(organizationId: string) {
|
||||
return this._subscriptionRepository.getSubscription(organizationId);
|
||||
}
|
||||
|
||||
async checkCredits(organization: Organization) {
|
||||
// @ts-ignore
|
||||
const type = organization?.subscription?.subscriptionTier || 'FREE';
|
||||
|
||||
if (type === 'FREE') {
|
||||
return {credits: 0};
|
||||
}
|
||||
|
||||
// @ts-ignore
|
||||
let date = dayjs(organization.subscription.createdAt);
|
||||
while (date.isBefore(dayjs())) {
|
||||
date = date.add(1, 'month');
|
||||
}
|
||||
|
||||
const checkFromMonth = date.subtract(1, 'month');
|
||||
const imageGenerationCount = pricing[type].image_generation_count;
|
||||
|
||||
const totalUse = await this._subscriptionRepository.getCreditsFrom(organization.id, checkFromMonth);
|
||||
|
||||
return {
|
||||
credits: imageGenerationCount - totalUse,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue