import { Injectable } from '@nestjs/common'; import { BaseMessage, HumanMessage, ToolMessage, } from '@langchain/core/messages'; import { END, START, StateGraph } from '@langchain/langgraph'; import { ChatOpenAI, DallEAPIWrapper } from '@langchain/openai'; import { TavilySearchResults } from '@langchain/community/tools/tavily_search'; import { ToolNode } from '@langchain/langgraph/prebuilt'; import { ChatPromptTemplate } from '@langchain/core/prompts'; import dayjs from 'dayjs'; import { PostsService } from '@gitroom/nestjs-libraries/database/prisma/posts/posts.service'; import { z } from 'zod'; import { MediaService } from '@gitroom/nestjs-libraries/database/prisma/media/media.service'; import { UploadFactory } from '@gitroom/nestjs-libraries/upload/upload.factory'; import { GeneratorDto } from '@gitroom/nestjs-libraries/dtos/generator/generator.dto'; const tools = !process.env.TAVILY_API_KEY ? [] : [new TavilySearchResults({ maxResults: 3 })]; const toolNode = new ToolNode(tools); const model = new ChatOpenAI({ apiKey: process.env.OPENAI_API_KEY || 'sk-proj-', model: 'gpt-4o-2024-08-06', temperature: 0.7, }); const dalle = new DallEAPIWrapper({ apiKey: process.env.OPENAI_API_KEY || 'sk-proj-', model: 'dall-e-3', }); interface WorkflowChannelsState { messages: BaseMessage[]; orgId: string; question: string; hook?: string; fresearch?: string; category?: string; topic?: string; date?: string; format: 'one_short' | 'one_long' | 'thread_short' | 'thread_long'; tone: 'personal' | 'company'; content?: { content: string; website?: string; prompt?: string; image?: string; }[]; isPicture?: boolean; popularPosts?: { content: string; hook: string }[]; } const category = z.object({ category: z.string().describe('The category for the post'), }); const topic = z.object({ topic: z.string().describe('The topic for the post'), }); const hook = z.object({ hook: z .string() .describe( 'Hook for the new post, don\'t take it from "the request of the user"' ), }); const contentZod = ( isPicture: boolean, format: 'one_short' | 'one_long' | 'thread_short' | 'thread_long' ) => { const content = z.object({ content: z.string().describe('Content for the new post'), website: z .string() .optional() .describe( "Website for the new post if exists, If one of the post present a brand, website link must be to the root domain of the brand or don't include it, website url should contain the brand name" ), ...(isPicture ? { prompt: z .string() .describe( "Prompt to generate a picture for this post later, make sure it doesn't contain brand names and make it very descriptive in terms of style" ), } : {}), }); return z.object({ content: format === 'one_short' || format === 'one_long' ? content : z.array(content).min(2).describe(`Content for the new post`), }); }; @Injectable() export class AgentGraphService { private storage = UploadFactory.createStorage(); constructor( private _postsService: PostsService, private _mediaService: MediaService ) {} static state = () => new StateGraph({ channels: { messages: { reducer: (currentState, updateValue) => currentState.concat(updateValue), default: () => [], }, fresearch: null, format: null, tone: null, question: null, orgId: null, hook: null, content: null, date: null, category: null, popularPosts: null, topic: null, isPicture: null, }, }); async startCall(state: WorkflowChannelsState) { const runTools = model.bindTools(tools); const response = await ChatPromptTemplate.fromTemplate( ` Today is ${dayjs().format()}, You are an assistant that gets a social media post or requests for a social media post. You research should be on the most possible recent data. You concat the text of the request together with an internet research based on the text. {text} ` ) .pipe(runTools) .invoke({ text: state.messages[state.messages.length - 1].content, }); return { messages: [response] }; } async saveResearch(state: WorkflowChannelsState) { const content = state.messages.filter((f) => f instanceof ToolMessage); return { fresearch: content }; } async findCategories(state: WorkflowChannelsState) { const allCategories = await this._postsService.findAllExistingCategories(); const structuredOutput = model.withStructuredOutput(category); const { category: outputCategory } = await ChatPromptTemplate.fromTemplate( ` You are an assistant that gets a text that will be later summarized into a social media post and classify it to one of the following categories: {categories} text: {text} ` ) .pipe(structuredOutput) .invoke({ categories: allCategories.map((p) => p.category).join(', '), text: state.fresearch, }); return { category: outputCategory, }; } async findTopic(state: WorkflowChannelsState) { const allTopics = await this._postsService.findAllExistingTopicsOfCategory( state?.category! ); if (allTopics.length === 0) { return { topic: null }; } const structuredOutput = model.withStructuredOutput(topic); const { topic: outputTopic } = await ChatPromptTemplate.fromTemplate( ` You are an assistant that gets a text that will be later summarized into a social media post and classify it to one of the following topics: {topics} text: {text} ` ) .pipe(structuredOutput) .invoke({ topics: allTopics.map((p) => p.topic).join(', '), text: state.fresearch, }); return { topic: outputTopic, }; } async findPopularPosts(state: WorkflowChannelsState) { const popularPosts = await this._postsService.findPopularPosts( state.category!, state.topic ); return { popularPosts }; } async generateHook(state: WorkflowChannelsState) { const structuredOutput = model.withStructuredOutput(hook); const { hook: outputHook } = await ChatPromptTemplate.fromTemplate( ` You are an assistant that gets content for a social media post, and generate only the hook. The hook is the 1-2 sentences of the post that will be used to grab the attention of the reader. You will be provided existing hooks you should use as inspiration. - Avoid weird hook that starts with "Discover the secret...", "The best...", "The most...", "The top..." - Make sure it sounds ${state.tone} - Use ${state.tone === 'personal' ? '1st' : '3rd'} person mode - Make sure it's engaging - Don't be cringy - Use simple english - Make sure you add "\n" between the lines - Don't take the hook from "request of the user" {request} {hooks} {text} ` ) .pipe(structuredOutput) .invoke({ request: state.messages[0].content, hooks: state.popularPosts!.map((p) => p.hook).join('\n'), text: state.fresearch, }); return { hook: outputHook, }; } async generateContent(state: WorkflowChannelsState) { const structuredOutput = model.withStructuredOutput( contentZod(!!state.isPicture, state.format) ); const { content: outputContent } = await ChatPromptTemplate.fromTemplate( ` You are an assistant that gets existing hook of a social media, content and generate only the content. - Don't add any hashtags - Make sure it sounds ${state.tone} - Use ${state.tone === 'personal' ? '1st' : '3rd'} person mode - ${ state.format === 'one_short' || state.format === 'thread_short' ? 'Post should be maximum 200 chars to fit twitter' : 'Post should be long' } - ${ state.format === 'one_short' || state.format === 'one_long' ? 'Post should have only 1 item' : 'Post should have minimum 2 items' } - Use the hook as inspiration - Make sure it's engaging - Don't be cringy - Use simple english - The Content should not contain the hook - Try to put some call to action at the end of the post - Make sure you add "\n" between the lines - Add "\n" after every "." Hook: {hook} User request: {request} current content information: {information} ` ) .pipe(structuredOutput) .invoke({ hook: state.hook, request: state.messages[0].content, information: state.fresearch, }); return { content: outputContent, }; } async fixArray(state: WorkflowChannelsState) { if (state.format === 'one_short' || state.format === 'one_long') { return { content: [state.content], }; } return {}; } async generatePictures(state: WorkflowChannelsState) { if (!state.isPicture) { return {}; } const newContent = await Promise.all( (state.content || []).map(async (p) => { const image = await dalle.invoke(p.prompt!); return { ...p, image, }; }) ); return { content: newContent, }; } async uploadPictures(state: WorkflowChannelsState) { const all = await Promise.all( (state.content || []).map(async (p) => { if (p.image) { const upload = await this.storage.uploadSimple(p.image); const name = upload.split('/').pop()!; const uploadWithId = await this._mediaService.saveFile( state.orgId, name, upload ); return { ...p, image: uploadWithId, }; } return p; }) ); return { content: all }; } async isGeneratePicture(state: WorkflowChannelsState) { if (state.isPicture) { return 'generate-picture'; } return 'post-time'; } async postDateTime(state: WorkflowChannelsState) { return { date: await this._postsService.findFreeDateTime(state.orgId) }; } start(orgId: string, body: GeneratorDto) { const state = AgentGraphService.state(); const workflow = state .addNode('agent', this.startCall.bind(this)) .addNode('research', toolNode) .addNode('save-research', this.saveResearch.bind(this)) .addNode('find-category', this.findCategories.bind(this)) .addNode('find-topic', this.findTopic.bind(this)) .addNode('find-popular-posts', this.findPopularPosts.bind(this)) .addNode('generate-hook', this.generateHook.bind(this)) .addNode('generate-content', this.generateContent.bind(this)) .addNode('generate-content-fix', this.fixArray.bind(this)) .addNode('generate-picture', this.generatePictures.bind(this)) .addNode('upload-pictures', this.uploadPictures.bind(this)) .addNode('post-time', this.postDateTime.bind(this)) .addEdge(START, 'agent') .addEdge('agent', 'research') .addEdge('research', 'save-research') .addEdge('save-research', 'find-category') .addEdge('find-category', 'find-topic') .addEdge('find-topic', 'find-popular-posts') .addEdge('find-popular-posts', 'generate-hook') .addEdge('generate-hook', 'generate-content') .addEdge('generate-content', 'generate-content-fix') .addConditionalEdges( 'generate-content-fix', this.isGeneratePicture.bind(this) ) .addEdge('generate-picture', 'upload-pictures') .addEdge('upload-pictures', 'post-time') .addEdge('post-time', END); const app = workflow.compile(); return app.streamEvents( { messages: [new HumanMessage(body.research)], isPicture: body.isPicture, format: body.format, tone: body.tone, orgId, }, { streamMode: 'values', version: 'v2', } ); } }