Skip to content

Commit

Permalink
add truncate option separate from rebuild option
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamingregory committed Apr 10, 2018
1 parent b275fd4 commit b1a2de5
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions operators/s3_to_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import random
import string
import logging

from airflow.utils.db import provide_session
from airflow.models import Connection
from airflow.utils.decorators import apply_defaults

from airflow.models import BaseOperator
from airflow.hooks.S3_hook import S3Hook
from airflow.hooks.postgres_hook import PostgresHook
from airflow.utils.db import provide_session
from airflow.models import Connection


class S3ToRedshiftOperator(BaseOperator):
Expand Down Expand Up @@ -46,9 +48,12 @@ class S3ToRedshiftOperator(BaseOperator):
possible values include "mysql".
:type origin_datatype: string
:param load_type: The method of loading into Redshift that
should occur. Options are "append",
"rebuild", and "upsert". Defaults to
"append."
should occur. Options:
- "append"
- "rebuild"
- "truncate"
- "upsert"
Defaults to "append."
:type load_type: string
:param primary_key: *(optional)* The primary key for the
destination table. Not enforced by redshift
Expand Down Expand Up @@ -128,10 +133,10 @@ def __init__(self,
self.sortkey = sortkey
self.sort_type = sort_type

if self.load_type.lower() not in ["append", "rebuild", "upsert"]:
if self.load_type.lower() not in ("append", "rebuild", "truncate", "upsert"):
raise Exception('Please choose "append", "rebuild", or "upsert".')

if self.schema_location.lower() not in ['s3', 'local']:
if self.schema_location.lower() not in ('s3', 'local'):
raise Exception('Valid Schema Locations are "s3" or "local".')

if not (isinstance(self.sortkey, str) or isinstance(self.sortkey, list)):
Expand All @@ -152,9 +157,12 @@ def execute(self, context):
letters = string.ascii_lowercase
random_string = ''.join(random.choice(letters) for _ in range(7))
self.temp_suffix = '_tmp_{0}'.format(random_string)

if self.origin_schema:
schema = self.read_and_format()

pg_hook = PostgresHook(self.redshift_conn_id)

self.create_if_not_exists(schema, pg_hook)
self.reconcile_schemas(schema, pg_hook)
self.copy_data(pg_hook, schema)
Expand Down Expand Up @@ -221,7 +229,6 @@ def read_and_format(self):
if i['type'] == e['avro']:
i['type'] = e['redshift']

print(schema)
return schema

def reconcile_schemas(self, schema, pg_hook):
Expand Down Expand Up @@ -277,7 +284,7 @@ def getS3Conn():
elif aws_role_arn:
creds = ("aws_iam_role={0}"
.format(aws_role_arn))

return creds

# Delete records from the destination table where the incremental_key
Expand Down Expand Up @@ -331,6 +338,11 @@ def getS3Conn():
FILLTARGET
'''.format(self.redshift_schema, self.table, self.temp_suffix)

drop_sql = \
'''
DROP TABLE IF EXISTS "{0}"."{1}"
'''.format(self.redshift_schema, self.table)

drop_temp_sql = \
'''
DROP TABLE IF EXISTS "{0}"."{1}{2}"
Expand Down Expand Up @@ -366,6 +378,13 @@ def getS3Conn():
base_sql)
if self.load_type == 'append':
pg_hook.run(load_sql)
elif self.load_type == 'rebuild':
pg_hook.run(drop_sql)
self.create_if_not_exists(schema, pg_hook)
pg_hook.run(load_sql)
elif self.load_type == 'truncate':
pg_hook.run(truncate_sql)
pg_hook.run(load_sql)
elif self.load_type == 'upsert':
self.create_if_not_exists(schema, pg_hook, temp=True)
load_temp_sql = \
Expand All @@ -378,9 +397,6 @@ def getS3Conn():
pg_hook.run(delete_confirm_sql)
pg_hook.run(append_sql, autocommit=True)
pg_hook.run(drop_temp_sql)
elif self.load_type == 'rebuild':
pg_hook.run(truncate_sql)
pg_hook.run(load_sql)

def create_if_not_exists(self, schema, pg_hook, temp=False):
output = ''
Expand Down

0 comments on commit b1a2de5

Please sign in to comment.