419 lines
12 KiB
TypeScript
419 lines
12 KiB
TypeScript
import { Injectable } from '@nestjs/common';
|
|
import {
|
|
BaseMessage,
|
|
HumanMessage,
|
|
ToolMessage,
|
|
} from '@langchain/core/messages';
|
|
import { END, START, StateGraph } from '@langchain/langgraph';
|
|
import { ChatOpenAI, DallEAPIWrapper } from '@langchain/openai';
|
|
import { TavilySearchResults } from '@langchain/community/tools/tavily_search';
|
|
import { ToolNode } from '@langchain/langgraph/prebuilt';
|
|
import { ChatPromptTemplate } from '@langchain/core/prompts';
|
|
import dayjs from 'dayjs';
|
|
import { PostsService } from '@gitroom/nestjs-libraries/database/prisma/posts/posts.service';
|
|
import { z } from 'zod';
|
|
import { MediaService } from '@gitroom/nestjs-libraries/database/prisma/media/media.service';
|
|
import { UploadFactory } from '@gitroom/nestjs-libraries/upload/upload.factory';
|
|
import { GeneratorDto } from '@gitroom/nestjs-libraries/dtos/generator/generator.dto';
|
|
|
|
// Web search is only wired in when a Tavily key is configured; without it the
// agent is bound to an empty tool list.
const tools = process.env.TAVILY_API_KEY
  ? [new TavilySearchResults({ maxResults: 3 })]
  : [];

// Graph node that executes the tool calls requested by the agent's message.
const toolNode = new ToolNode(tools);
|
|
|
|
/**
 * Chat model shared by every LLM step of the graph.
 * The placeholder key is used when OPENAI_API_KEY is unset — presumably so the
 * module can load without the variable; NOTE(review): confirm requests fail loudly.
 */
const model = new ChatOpenAI({
  model: 'gpt-4o-2024-08-06',
  temperature: 0.7,
  apiKey: process.env.OPENAI_API_KEY || 'sk-proj-',
});
|
|
|
|
/**
 * DALL·E 3 wrapper used by the picture-generation step.
 * Uses the same OPENAI_API_KEY (or placeholder) as the chat model.
 */
const dalle = new DallEAPIWrapper({
  model: 'dall-e-3',
  apiKey: process.env.OPENAI_API_KEY || 'sk-proj-',
});
|
|
|
|
/** Values carried between the nodes of the post-generation graph. */
interface WorkflowChannelsState {
  /** Conversation history; the graph's `messages` channel concatenates updates. */
  messages: BaseMessage[];
  /** Organization the content is generated for (used for media and scheduling). */
  orgId: string;
  /** Declared but neither set in `start` nor read by the nodes in this file. */
  question: string;
  /** Requested output shape: single post vs. thread, short vs. long. */
  format: 'one_short' | 'one_long' | 'thread_short' | 'thread_long';
  /** Voice of the post; also picks 1st- vs 3rd-person wording in prompts. */
  tone: 'personal' | 'company';
  /** Whether pictures should be generated for the post(s). */
  isPicture?: boolean;
  /** Research collected by the save-research step. */
  fresearch?: string;
  /** Category chosen by the find-category step. */
  category?: string;
  /** Topic chosen by the find-topic step. */
  topic?: string;
  /** Existing posts whose hooks inspire the generated hook. */
  popularPosts?: Array<{ content: string; hook: string }>;
  /** Generated hook text. */
  hook?: string;
  /** Generated post items, optionally with picture prompt and image. */
  content?: Array<{
    content: string;
    website?: string;
    prompt?: string;
    image?: string;
  }>;
  /** Value returned by PostsService.findFreeDateTime in the post-time step. */
  date?: string;
}
|
|
|
|
/** Structured-output schema: the model returns a single category string. */
const category = z.object({
  category: z.string().describe('The category for the post'),
});
|
|
|
|
/** Structured-output schema: the model returns a single topic string. */
const topic = z.object({
  topic: z.string().describe('The topic for the post'),
});
|
|
|
|
/** Structured-output schema: the model returns the hook line(s) of the post. */
const hook = z.object({
  hook: z
    .string()
    .describe(
      `Hook for the new post, don't take it from "the request of the user"`
    ),
});
|
|
|
|
/**
 * Builds the structured-output schema for post content.
 *
 * @param isPicture - when true, each item also requires an image `prompt`.
 * @param format - single-post formats yield one item object; thread formats
 *   yield an array of at least two items.
 * @returns a zod object schema with a `content` field of the chosen shape.
 */
const contentZod = (
  isPicture: boolean,
  format: 'one_short' | 'one_long' | 'thread_short' | 'thread_long'
) => {
  const item = z.object({
    content: z.string().describe('Content for the new post'),
    website: z
      .string()
      .optional()
      .describe(
        "Website for the new post if exists, If one of the post present a brand, website link must be to the root domain of the brand or don't include it, website url should contain the brand name"
      ),
    // The picture prompt field only exists when pictures were requested.
    ...(isPicture
      ? {
          prompt: z
            .string()
            .describe(
              "Prompt to generate a picture for this post later, make sure it doesn't contain brand names and make it very descriptive in terms of style"
            ),
        }
      : {}),
  });

  const isSinglePost = format === 'one_short' || format === 'one_long';

  return z.object({
    content: isSinglePost
      ? item
      : z.array(item).min(2).describe(`Content for the new post`),
  });
};
|
|
|
|
/**
 * Builds and runs the LangGraph workflow that turns a research request into
 * social media post content (optionally with generated pictures) and a posting date.
 *
 * Node order (see `start`): agent → research → save-research → find-category →
 * find-topic → find-popular-posts → generate-hook → generate-content →
 * generate-content-fix → [generate-picture → upload-pictures] → post-time.
 */
@Injectable()
export class AgentGraphService {
  private storage = UploadFactory.createStorage();

  constructor(
    private _postsService: PostsService,
    private _mediaService: MediaService
  ) {}

  /**
   * Creates a fresh state graph. `messages` concatenates every update; the
   * other channels are configured with `null` (last written value wins).
   */
  static state = () =>
    new StateGraph<WorkflowChannelsState>({
      channels: {
        messages: {
          reducer: (currentState, updateValue) =>
            currentState.concat(updateValue),
          default: () => [],
        },
        fresearch: null,
        format: null,
        tone: null,
        question: null,
        orgId: null,
        hook: null,
        content: null,
        date: null,
        category: null,
        popularPosts: null,
        topic: null,
        isPicture: null,
      },
    });

  /** Flattens a message's content (plain string or structured parts) to text. */
  private static messageText(message: BaseMessage): string {
    return typeof message.content === 'string'
      ? message.content
      : JSON.stringify(message.content);
  }

  /**
   * Asks the tool-bound model to research the latest user message.
   * The returned AI message may contain tool calls for the research node.
   */
  async startCall(state: WorkflowChannelsState) {
    const runTools = model.bindTools(tools);
    const response = await ChatPromptTemplate.fromTemplate(
      `
Today is ${dayjs().format()}, You are an assistant that gets a social media post or requests for a social media post.
You research should be on the most possible recent data.
You concat the text of the request together with an internet research based on the text.
{text}
`
    )
      .pipe(runTools)
      .invoke({
        text: state.messages[state.messages.length - 1].content,
      });

    return { messages: [response] };
  }

  /**
   * Stores the research as plain text in `fresearch`.
   * Uses the tool results when any tool ran; otherwise falls back to the last
   * message (the agent's own reply).
   */
  async saveResearch(state: WorkflowChannelsState) {
    const toolResults = state.messages.filter(
      (m): m is ToolMessage => m instanceof ToolMessage
    );

    if (toolResults.length > 0) {
      // Prompts interpolate `fresearch` as text, so flatten message objects.
      return {
        fresearch: toolResults
          .map((m) => AgentGraphService.messageText(m))
          .join('\n'),
      };
    }

    // No tool output (e.g. TAVILY_API_KEY unset, or the model did not search).
    // find-category / find-topic only see `fresearch`, so never leave it empty
    // when the agent produced a reply.
    const lastMessage = state.messages[state.messages.length - 1];
    return {
      fresearch: lastMessage ? AgentGraphService.messageText(lastMessage) : '',
    };
  }

  /** Classifies the research into one of the existing post categories. */
  async findCategories(state: WorkflowChannelsState) {
    const allCategories = await this._postsService.findAllExistingCategories();
    const structuredOutput = model.withStructuredOutput(category);
    const { category: outputCategory } = await ChatPromptTemplate.fromTemplate(
      `
You are an assistant that gets a text that will be later summarized into a social media post
and classify it to one of the following categories: {categories}
text: {text}
`
    )
      .pipe(structuredOutput)
      .invoke({
        categories: allCategories.map((p) => p.category).join(', '),
        text: state.fresearch,
      });

    return {
      category: outputCategory,
    };
  }

  /**
   * Classifies the research into one of the category's existing topics.
   * Writes `topic: null` when the category has no topics.
   */
  async findTopic(state: WorkflowChannelsState) {
    const allTopics = await this._postsService.findAllExistingTopicsOfCategory(
      state.category!
    );
    if (allTopics.length === 0) {
      return { topic: null };
    }

    const structuredOutput = model.withStructuredOutput(topic);
    const { topic: outputTopic } = await ChatPromptTemplate.fromTemplate(
      `
You are an assistant that gets a text that will be later summarized into a social media post
and classify it to one of the following topics: {topics}
text: {text}
`
    )
      .pipe(structuredOutput)
      .invoke({
        topics: allTopics.map((p) => p.topic).join(', '),
        text: state.fresearch,
      });

    return {
      topic: outputTopic,
    };
  }

  /** Loads popular posts for the chosen category/topic via PostsService. */
  async findPopularPosts(state: WorkflowChannelsState) {
    const popularPosts = await this._postsService.findPopularPosts(
      state.category!,
      state.topic
    );
    return { popularPosts };
  }

  /** Generates the post hook, using popular posts' hooks as inspiration. */
  async generateHook(state: WorkflowChannelsState) {
    const structuredOutput = model.withStructuredOutput(hook);
    const { hook: outputHook } = await ChatPromptTemplate.fromTemplate(
      `
You are an assistant that gets content for a social media post, and generate only the hook.
The hook is the 1-2 sentences of the post that will be used to grab the attention of the reader.
You will be provided existing hooks you should use as inspiration.
- Avoid weird hook that starts with "Discover the secret...", "The best...", "The most...", "The top..."
- Make sure it sounds ${state.tone}
- Use ${state.tone === 'personal' ? '1st' : '3rd'} person mode
- Make sure it's engaging
- Don't be cringy
- Use simple english
- Make sure you add "\n" between the lines
- Don't take the hook from "request of the user"

<!-- BEGIN request of the user -->
{request}
<!-- END request of the user -->

<!-- BEGIN existing hooks -->
{hooks}
<!-- END existing hooks -->

<!-- BEGIN current content -->
{text}
<!-- END current content -->

`
    )
      .pipe(structuredOutput)
      .invoke({
        request: state.messages[0].content,
        // Tolerate a missing popular-posts result instead of crashing.
        hooks: (state.popularPosts ?? []).map((p) => p.hook).join('\n'),
        text: state.fresearch,
      });

    return {
      hook: outputHook,
    };
  }

  /** Generates the post body (one item or a thread) from hook and research. */
  async generateContent(state: WorkflowChannelsState) {
    const structuredOutput = model.withStructuredOutput(
      contentZod(!!state.isPicture, state.format)
    );
    const { content: outputContent } = await ChatPromptTemplate.fromTemplate(
      `
You are an assistant that gets existing hook of a social media, content and generate only the content.
- Don't add any hashtags
- Make sure it sounds ${state.tone}
- Use ${state.tone === 'personal' ? '1st' : '3rd'} person mode
- ${
        state.format === 'one_short' || state.format === 'thread_short'
          ? 'Post should be maximum 200 chars to fit twitter'
          : 'Post should be long'
      }
- ${
        state.format === 'one_short' || state.format === 'one_long'
          ? 'Post should have only 1 item'
          : 'Post should have minimum 2 items'
      }
- Use the hook as inspiration
- Make sure it's engaging
- Don't be cringy
- Use simple english
- The Content should not contain the hook
- Try to put some call to action at the end of the post
- Make sure you add "\n" between the lines
- Add "\n" after every "."

Hook:
{hook}

User request:
{request}

current content information:
{information}
`
    )
      .pipe(structuredOutput)
      .invoke({
        hook: state.hook,
        request: state.messages[0].content,
        information: state.fresearch,
      });

    return {
      content: outputContent,
    };
  }

  /**
   * Normalizes `content` to an array. For single-post formats the structured
   * output is one object (see contentZod), so it is wrapped; a missing value
   * becomes `[]` so later steps never map over `undefined` items.
   */
  async fixArray(state: WorkflowChannelsState) {
    if (state.format === 'one_short' || state.format === 'one_long') {
      type ContentItem = NonNullable<WorkflowChannelsState['content']>[number];
      const value = state.content as ContentItem | ContentItem[] | undefined;
      if (!value) {
        return { content: [] };
      }
      return {
        content: Array.isArray(value) ? value : [value],
      };
    }

    return {};
  }

  /** Generates a DALL·E image for every content item when pictures are requested. */
  async generatePictures(state: WorkflowChannelsState) {
    if (!state.isPicture) {
      return {};
    }

    const newContent = await Promise.all(
      (state.content || []).map(async (p) => {
        const image = await dalle.invoke(p.prompt!);
        return {
          ...p,
          image,
        };
      })
    );

    return {
      content: newContent,
    };
  }

  /** Uploads generated images to storage and records them as org media. */
  async uploadPictures(state: WorkflowChannelsState) {
    const all = await Promise.all(
      (state.content || []).map(async (p) => {
        if (p.image) {
          const upload = await this.storage.uploadSimple(p.image);
          const name = upload.split('/').pop()!;
          const uploadWithId = await this._mediaService.saveFile(
            state.orgId,
            name,
            upload
          );

          return {
            ...p,
            image: uploadWithId,
          };
        }

        return p;
      })
    );

    return { content: all };
  }

  /** Conditional-edge router: picture branch or straight to scheduling. */
  async isGeneratePicture(state: WorkflowChannelsState) {
    if (state.isPicture) {
      return 'generate-picture';
    }

    return 'post-time';
  }

  /** Picks a posting date via PostsService.findFreeDateTime. */
  async postDateTime(state: WorkflowChannelsState) {
    return { date: await this._postsService.findFreeDateTime(state.orgId) };
  }

  /**
   * Builds, compiles and streams the workflow for one generation request.
   *
   * @param orgId - organization the content and media belong to.
   * @param body - generator request (research text, format, tone, picture flag).
   * @returns the LangGraph v2 event stream of the run.
   */
  start(orgId: string, body: GeneratorDto) {
    const state = AgentGraphService.state();
    const workflow = state
      .addNode('agent', this.startCall.bind(this))
      .addNode('research', toolNode)
      .addNode('save-research', this.saveResearch.bind(this))
      .addNode('find-category', this.findCategories.bind(this))
      .addNode('find-topic', this.findTopic.bind(this))
      .addNode('find-popular-posts', this.findPopularPosts.bind(this))
      .addNode('generate-hook', this.generateHook.bind(this))
      .addNode('generate-content', this.generateContent.bind(this))
      .addNode('generate-content-fix', this.fixArray.bind(this))
      .addNode('generate-picture', this.generatePictures.bind(this))
      .addNode('upload-pictures', this.uploadPictures.bind(this))
      .addNode('post-time', this.postDateTime.bind(this))
      .addEdge(START, 'agent')
      .addEdge('agent', 'research')
      .addEdge('research', 'save-research')
      .addEdge('save-research', 'find-category')
      .addEdge('find-category', 'find-topic')
      .addEdge('find-topic', 'find-popular-posts')
      .addEdge('find-popular-posts', 'generate-hook')
      .addEdge('generate-hook', 'generate-content')
      .addEdge('generate-content', 'generate-content-fix')
      .addConditionalEdges(
        'generate-content-fix',
        this.isGeneratePicture.bind(this)
      )
      .addEdge('generate-picture', 'upload-pictures')
      .addEdge('upload-pictures', 'post-time')
      .addEdge('post-time', END);

    const app = workflow.compile();

    return app.streamEvents(
      {
        messages: [new HumanMessage(body.research)],
        isPicture: body.isPicture,
        format: body.format,
        tone: body.tone,
        orgId,
      },
      {
        streamMode: 'values',
        version: 'v2',
      }
    );
  }
}
|