|
| 1 | +/* |
| 2 | + * |
| 3 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 4 | + * SPDX-License-Identifier: MIT-0 |
| 5 | + */ |
| 6 | +import { Tracer } from '@aws-lambda-powertools/tracer' |
| 7 | +import { captureLambdaHandler } from '@aws-lambda-powertools/tracer/middleware' |
| 8 | +import { Logger } from '@aws-lambda-powertools/logger' |
| 9 | +import { injectLambdaContext } from '@aws-lambda-powertools/logger/middleware' |
| 10 | +import middy from '@middy/core' |
| 11 | +import { type FeedArticle } from '../../shared/common' |
| 12 | +import { |
| 13 | + DynamoDBClient, |
| 14 | + GetItemCommand, |
| 15 | + GetItemCommandInput |
| 16 | +} from '@aws-sdk/client-dynamodb' |
| 17 | +import axios from 'axios' |
| 18 | +import * as cheerio from 'cheerio' |
| 19 | +import { |
| 20 | + BedrockRuntimeClient, |
| 21 | + InvokeModelCommand, |
| 22 | + InvokeModelCommandInput |
| 23 | +} from '@aws-sdk/client-bedrock-runtime' |
| 24 | + |
| 25 | +const SERVICE_NAME = 'filter-articles-with-bedrock' |
| 26 | + |
| 27 | +const tracer = new Tracer({ serviceName: SERVICE_NAME }) |
| 28 | +const logger = new Logger({ serviceName: SERVICE_NAME }) |
| 29 | + |
| 30 | +const dynamodb = tracer.captureAWSv3Client(new DynamoDBClient()) |
| 31 | +const bedrockRuntimeClient = tracer.captureAWSv3Client( |
| 32 | + new BedrockRuntimeClient() |
| 33 | +) |
| 34 | + |
| 35 | +const DATA_FEED_TABLE = process.env.DATA_FEED_TABLE |
| 36 | +const BEDROCK_MODEL_ID = 'anthropic.claude-3-haiku-20240307-v1:0' |
| 37 | + |
| 38 | +interface FilterArticlesWithBedrockInput { |
| 39 | + dataFeedId: string |
| 40 | + articles: FeedArticle[] |
| 41 | +} |
| 42 | + |
| 43 | +const lambdaHandler = async ( |
| 44 | + event: FilterArticlesWithBedrockInput |
| 45 | +): Promise<FeedArticle[]> => { |
| 46 | + const { dataFeedId, articles } = event |
| 47 | + logger.debug('Filtering articles with Bedrock for Data Feed ID ', dataFeedId) |
| 48 | + logger.debug('Unfiltered new article count = ', { |
| 49 | + articleLength: articles.length |
| 50 | + }) |
| 51 | + const filteredArticles = await filterArticlesWithBedrock(articles, dataFeedId) |
| 52 | + logger.debug('Filtered article count = ' + filteredArticles.length) |
| 53 | + return filteredArticles |
| 54 | +} |
| 55 | + |
| 56 | +const filterArticlesWithBedrock = async ( |
| 57 | + articles: FeedArticle[], |
| 58 | + dataFeedId: string |
| 59 | +): Promise<FeedArticle[]> => { |
| 60 | + const filteredArticles: FeedArticle[] = [] |
| 61 | + const filterPrompt = await getFilterPrompt(dataFeedId) |
| 62 | + if (filterPrompt === null) { |
| 63 | + return articles |
| 64 | + } |
| 65 | + for (const article of articles) { |
| 66 | + logger.debug('Working on article', { article }) |
| 67 | + const siteContent = await getSiteContent(article.url) |
| 68 | + if (siteContent !== null) { |
| 69 | + const isFiltered = await isArticleFilteredWithBedrock( |
| 70 | + siteContent, |
| 71 | + filterPrompt |
| 72 | + ) |
| 73 | + if (!isFiltered) { |
| 74 | + console.debug('Article passed filter: ' + article.title) |
| 75 | + filteredArticles.push(article) |
| 76 | + } else { |
| 77 | + console.debug('Article filtered out: ' + article.title) |
| 78 | + } |
| 79 | + } |
| 80 | + } |
| 81 | + return filteredArticles |
| 82 | +} |
| 83 | + |
| 84 | +const getFilterPrompt = async (dataFeedId: string): Promise<string | null> => { |
| 85 | + // Get the filter prompt from dynamoDB using the dataFeedId |
| 86 | + logger.debug('Getting filter prompt for data feed ', dataFeedId) |
| 87 | + const input: GetItemCommandInput = { |
| 88 | + Key: { |
| 89 | + dataFeedId: { |
| 90 | + S: dataFeedId |
| 91 | + }, |
| 92 | + sk: { |
| 93 | + S: 'dataFeed' |
| 94 | + } |
| 95 | + }, |
| 96 | + TableName: DATA_FEED_TABLE, |
| 97 | + AttributesToGet: ['articleFilterPrompt'] |
| 98 | + } |
| 99 | + const command = new GetItemCommand(input) |
| 100 | + const result = await dynamodb.send(command) |
| 101 | + if ( |
| 102 | + result.Item !== undefined && |
| 103 | + result.Item.articleFilterPrompt?.S !== undefined |
| 104 | + ) { |
| 105 | + logger.debug( |
| 106 | + 'Filter prompt found for data feed ' + result.Item.articleFilterPrompt.S, |
| 107 | + dataFeedId |
| 108 | + ) |
| 109 | + return result.Item.articleFilterPrompt.S |
| 110 | + } else { |
| 111 | + logger.debug('No filter prompt found for data feed ', dataFeedId) |
| 112 | + return null |
| 113 | + } |
| 114 | +} |
| 115 | + |
| 116 | +const isArticleFilteredWithBedrock = async ( |
| 117 | + articleContent: string, |
| 118 | + filterPrompt: string |
| 119 | +): Promise<boolean> => { |
| 120 | + if (filterPrompt === null) { |
| 121 | + return false |
| 122 | + } |
| 123 | + const prompt = |
| 124 | + 'You are an agent responsible for reading articles and determining if the article should be filtered out based on the filter prompt.' + |
| 125 | + "Is the article filtered out based on the filter prompt? Return either 'true' or 'false'." + |
| 126 | + "If the article is filtered out, return 'true', otherwise return 'false'." + |
| 127 | + 'Here is the article content:\n' + |
| 128 | + '<article>' + |
| 129 | + articleContent + |
| 130 | + '</article>\n' + |
| 131 | + 'Here is the filter prompt:\n' + |
| 132 | + '<filter_prompt>' + |
| 133 | + filterPrompt + |
| 134 | + '</filter_prompt>' + |
| 135 | + "Only return 'true' if the article is filtered out based on the filter prompt. Do not return any other content." + |
| 136 | + 'Place the response in a <filter_response> xml tag.' |
| 137 | + |
| 138 | + const input: InvokeModelCommandInput = { |
| 139 | + modelId: BEDROCK_MODEL_ID, |
| 140 | + contentType: 'application/json', |
| 141 | + accept: '*/*', |
| 142 | + body: new TextEncoder().encode( |
| 143 | + JSON.stringify({ |
| 144 | + max_tokens: 1000, |
| 145 | + anthropic_version: 'bedrock-2023-05-31', |
| 146 | + messages: [ |
| 147 | + { |
| 148 | + role: 'user', |
| 149 | + content: [ |
| 150 | + { |
| 151 | + type: 'text', |
| 152 | + text: prompt |
| 153 | + } |
| 154 | + ] |
| 155 | + } |
| 156 | + ] |
| 157 | + }) |
| 158 | + ) |
| 159 | + } |
| 160 | + const command = new InvokeModelCommand(input) |
| 161 | + const response = await bedrockRuntimeClient.send(command) |
| 162 | + const responseText = new TextDecoder().decode(response.body) |
| 163 | + console.debug('Response from Bedrock: ' + responseText) |
| 164 | + const responseObject = JSON.parse(responseText) |
| 165 | + return extractResponseValue(responseObject.content[0].text, 'filter_response') |
| 166 | +} |
| 167 | + |
| 168 | +const getSiteContent = async (url: string): Promise<string | null> => { |
| 169 | + logger.debug(`getSiteContent Called; url = ${url}`) |
| 170 | + tracer.putMetadata('url', url) |
| 171 | + let $: cheerio.Root |
| 172 | + try { |
| 173 | + logger.debug('URL of Provided Site = ' + url) |
| 174 | + const response = await axios.get(url) |
| 175 | + tracer.putAnnotation('url', 'Successfully Crawled') |
| 176 | + const text = response.data as string |
| 177 | + $ = cheerio.load(text) |
| 178 | + // Cutting out elements that aren't needed |
| 179 | + $('footer').remove() |
| 180 | + $('header').remove() |
| 181 | + $('script').remove() |
| 182 | + $('style').remove() |
| 183 | + $('nav').remove() |
| 184 | + } catch (error) { |
| 185 | + logger.error(`Failed to crawl; url = ${url}`) |
| 186 | + logger.error(JSON.stringify(error)) |
| 187 | + tracer.addErrorAsMetadata(error as Error) |
| 188 | + throw error |
| 189 | + } |
| 190 | + let articleText: string = '' |
| 191 | + if ($('article').length > 0) { |
| 192 | + articleText = $('article').text() |
| 193 | + } else { |
| 194 | + articleText = $('body').text() |
| 195 | + } |
| 196 | + if (articleText !== undefined) { |
| 197 | + return articleText |
| 198 | + } else { |
| 199 | + return null |
| 200 | + } |
| 201 | +} |
| 202 | + |
| 203 | +const extractResponseValue = (response: string, xml_tag: string): boolean => { |
| 204 | + const formattedInput = response |
| 205 | + .replace(/(\r\n|\n|\r)/gm, '') |
| 206 | + .replace(/\\n/g, '') |
| 207 | + const open_tag = `<${xml_tag}>` |
| 208 | + const close_tag = `</${xml_tag}>` |
| 209 | + const regex = new RegExp(`(?<=${open_tag})(.*?)(?=${close_tag})`, 'g') |
| 210 | + const match = formattedInput.match(regex) |
| 211 | + const isFiltered = match?.[0].toLocaleLowerCase() === 'true' |
| 212 | + return isFiltered |
| 213 | +} |
| 214 | + |
| 215 | +export const handler = middy() |
| 216 | + .handler(lambdaHandler) |
| 217 | + .use(captureLambdaHandler(tracer, { captureResponse: false })) |
| 218 | + .use(injectLambdaContext(logger)) |
0 commit comments