Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sample code and add claude v3 code generation #220

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
167 changes: 140 additions & 27 deletions 01_Text_generation/01_code_generation_w_bedrock.ipynb
Expand Up @@ -178,7 +178,7 @@
"- The date with the highest revenue\n",
"- Visualize monthly sales using a bar chart\n",
"\n",
"Ensure the code is syntactically correct, bug-free, optimized, not span multiple lines unnessarily, and prefer to use standard libraries. Return only python code without any surrounding text, explanation or context.\n",
"Ensure the code is syntactically correct, bug-free, optimized, not span multiple lines unnessarily, and prefer to use standard libraries that comes with Python. Return only python code without any surrounding text, explanation or context.\n",
"\n",
"Assistant:\n",
"\"\"\""
Expand Down Expand Up @@ -234,7 +234,8 @@
"response = boto3_bedrock.invoke_model(body=body, modelId=modelId, accept=accept, contentType=contentType)\n",
"response_body = json.loads(response.get('body').read())\n",
"\n",
"response_body.get('completion')"
"response_text = response_body.get('completion')\n",
"print(response_text)"
]
},
{
Expand All @@ -252,49 +253,161 @@
"metadata": {},
"outputs": [],
"source": [
"# Sample Generated Python Code ( Generated with Amazon Bedrock in previous step)\n",
"# Sample Generated Python Code (Generated with Amazon Bedrock in previous step)\n",
"\n",
"import csv\n",
"from collections import defaultdict\n",
"import matplotlib.pyplot as plt\n",
"\n",
"revenue = 0\n",
"monthly_revenue = defaultdict(int)\n",
"product_revenue = defaultdict(int)\n",
"max_revenue = 0\n",
"max_revenue_date = ''\n",
"max_revenue_product = ''\n",
"max_revenue_product = None\n",
"max_revenue_date = None\n",
"monthly_sales = defaultdict(int)\n",
"\n",
"with open('sales.csv') as f:\n",
" reader = csv.reader(f)\n",
" next(reader)\n",
" next(reader) # skip header\n",
" for row in reader:\n",
" date = row[0]\n",
" product = row[1]\n",
" price = float(row[2])\n",
" units = int(row[3])\n",
"\n",
" revenue += price * units\n",
" product_revenue[product] += price * units\n",
" monthly_revenue[date[:7]] += price * units\n",
"\n",
" if revenue > max_revenue:\n",
" max_revenue = revenue\n",
" max_revenue_date = date\n",
" date, product, price, units = row\n",
" revenue += float(price) * int(units)\n",
" product_revenue = float(price) * int(units)\n",
" if product_revenue > max_revenue:\n",
" max_revenue = product_revenue\n",
" max_revenue_product = product\n",
" max_revenue_date = date\n",
" month = date.split('-')[1]\n",
" monthly_sales[month] += product_revenue\n",
" \n",
"print(f'Total revenue: {revenue}') \n",
"print(f'Product with highest revenue: {max_revenue_product}')\n",
"print(f'Date with highest revenue: {max_revenue_date}')\n",
"\n",
"months = list(monthly_revenue.keys())\n",
"values = list(monthly_revenue.values())\n",
"\n",
"plt.bar(months, values)\n",
"plt.bar(monthly_sales.keys(), monthly_sales.values())\n",
"plt.xlabel('Month')\n",
"plt.ylabel('Revenue')\n",
"plt.title('Monthly Revenue')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a9464363",
"metadata": {},
"source": [
"Let's then try Anthropic Claude v3 model."
]
},
{
"cell_type": "markdown",
"id": "505708ec",
"metadata": {},
"source": [
"#### Invoke the Anthropic Claude v3 model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5bf7a55",
"metadata": {},
"outputs": [],
"source": [
"session = boto3.session.Session()\n",
"region = session.region_name\n",
"bedrock_client = boto3.client('bedrock-runtime', region_name = region)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4337ec93",
"metadata": {},
"outputs": [],
"source": [
"messages=[{ \"role\":'user', \"content\":[{'type':'text','text': prompt_data}]}]\n",
"sonnet_payload = json.dumps({\n",
" \"anthropic_version\": \"bedrock-2023-05-31\",\n",
" \"max_tokens\": 512,\n",
" \"messages\": messages,\n",
" \"temperature\": 0.5,\n",
" \"top_p\": 1\n",
" })\n",
"\n",
"modelId = 'anthropic.claude-3-sonnet-20240229-v1:0' # change this to use a different version from the model provider\n",
"accept = 'application/json'\n",
"contentType = 'application/json'\n",
"response = bedrock_client.invoke_model(body=sonnet_payload, modelId=modelId, accept=accept, contentType=contentType)\n",
"response_body = json.loads(response.get('body').read())\n",
"response_text = response_body.get('content')[0]['text']\n",
"\n",
"print(response_text)"
]
},
{
"cell_type": "markdown",
"id": "0885f663",
"metadata": {},
"source": [
"#### (Optional) Execute the Bedrock generated code for validation. Go to text editor to copy the generated code as printed output can be trucncated. Replace the code in below cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "465755c3",
"metadata": {},
"outputs": [],
"source": [
"# Sample Generated Python Code (Generated with Amazon Bedrock in previous step)\n",
"\n",
"import csv\n",
"from collections import defaultdict\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Read data from CSV file\n",
"data = []\n",
"with open('sales.csv', 'r') as file:\n",
" reader = csv.DictReader(file)\n",
" for row in reader:\n",
" data.append(row)\n",
"\n",
"# Calculate total revenue\n",
"total_revenue = sum(float(row['price']) * float(row['units_sold']) for row in data)\n",
"\n",
"# Find product with highest revenue\n",
"product_revenue = defaultdict(float)\n",
"for row in data:\n",
" product_revenue[row['product_id']] += float(row['price']) * float(row['units_sold'])\n",
"highest_revenue_product = max(product_revenue, key=product_revenue.get)\n",
"\n",
"# Find date with highest revenue\n",
"date_revenue = defaultdict(float)\n",
"for row in data:\n",
" date_revenue[row['date']] += float(row['price']) * float(row['units_sold'])\n",
"highest_revenue_date = max(date_revenue, key=date_revenue.get)\n",
"\n",
"# Visualize monthly sales\n",
"monthly_sales = defaultdict(float)\n",
"for row in data:\n",
" year, month, _ = row['date'].split('-')\n",
" monthly_sales[f\"{year}-{month}\"] += float(row['units_sold'])\n",
"\n",
"months = sorted(monthly_sales.keys())\n",
"units_sold = [monthly_sales[month] for month in months]\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"plt.bar(months, units_sold)\n",
"plt.xlabel('Month')\n",
"plt.ylabel('Units Sold')\n",
"plt.title('Monthly Sales')\n",
"plt.xticks(rotation=90)\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print('Total Revenue:', revenue)\n",
"print('Product with max revenue:', max_revenue_product)\n",
"print('Date with max revenue:', max_revenue_date)"
"print(f\"Total revenue: ${total_revenue:.2f}\")\n",
"print(f\"Product with highest revenue: {highest_revenue_product}\")\n",
"print(f\"Date with highest revenue: {highest_revenue_date}\")"
]
},
{
Expand Down