Skip to content

Commit

Permalink
Merge pull request #3 from airflow-plugins/feature/add-keys
Browse files Browse the repository at this point in the history
Feature/add keys
  • Loading branch information
Ben committed Feb 2, 2018
2 parents d776c51 + 855088f commit c7478fd
Showing 1 changed file with 90 additions and 10 deletions.
100 changes: 90 additions & 10 deletions operators/s3_to_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,33 @@ class S3ToRedshiftOperator(BaseOperator):
with. Only required if using a load_type of
"upsert".
:type incremental_key: string
:param foreign_key: *(optional)* This specifies any foreign_keys
in the table and which corresponding table
and key they reference. This may be either
a dictionary or list of dictionaries (for
multiple foreign keys). The fields that are
required in each dictionary are:
- column_name
- reftable
- ref_column
:type foreign_key: dictionary
:param distkey: *(optional)* The distribution key for the
table. Only one key may be specified.
:type distkey: string
:param sortkey: *(optional)* The sort keys for the table.
If more than one key is specified, set this
as a list.
:type sortkey: string
:param sort_type: *(optional)* The style of distribution
to sort the table. Possible values include:
- compound
- interleaved
Defaults to "compound".
:type sort_type: string
"""

template_fields = ('s3_key',
'origin_schema',
'com')
'origin_schema')

@apply_defaults
def __init__(self,
Expand All @@ -81,7 +103,10 @@ def __init__(self,
load_type='append',
primary_key=None,
incremental_key=None,
timeformat='auto',
foreign_key={},
distkey=None,
sortkey='',
sort_type='COMPOUND',
*args,
**kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -98,14 +123,29 @@ def __init__(self,
self.load_type = load_type
self.primary_key = primary_key
self.incremental_key = incremental_key
self.timeformat = timeformat
self.foreign_key = foreign_key
self.distkey = distkey
self.sortkey = sortkey
self.sort_type = sort_type

if self.load_type.lower() not in ["append", "rebuild", "upsert"]:
raise Exception('Please choose "append", "rebuild", or "upsert".')

if self.schema_location.lower() not in ['s3', 'local']:
raise Exception('Valid Schema Locations are "s3" or "local".')

if not (isinstance(self.sortkey, str) or isinstance(self.sortkey, list)):
raise Exception('Sort Keys must be specified as either a string or list.')

if not (isinstance(self.foreign_key, dict) or isinstance(self.foreign_key, list)):
raise Exception('Foreign Keys must be specified as either a dictionary or a list of dictionaries.')

if ((',' in self.distkey) or not isinstance(self.distkey, str)):
raise Exception('Only one distribution key may be specified.')

if self.sort_type.lower() not in ('compound', 'interleaved'):
raise Exception('Please choose "compound" or "interleaved" for sort type.')

def execute(self, context):
# Append a random string to the end of the staging table to ensure
# no conflicts if multiple processes running concurrently.
Expand Down Expand Up @@ -337,6 +377,8 @@ def create_if_not_exists(self, schema, pg_hook, temp=False):
for item in schema:
k = "{quote}{key}{quote}".format(quote='"', key=item['name'])
field = ' '.join([k, item['type']])
if isinstance(self.sortkey, str) and self.sortkey == item['name']:
field += ' sortkey'
output += field
output += ', '
# Remove last comma and space after schema items loop ends
Expand All @@ -346,12 +388,50 @@ def create_if_not_exists(self, schema, pg_hook, temp=False):
else:
copy_table = self.table
create_schema_query = \
'''CREATE SCHEMA IF NOT EXISTS "{0}";'''.format(
self.redshift_schema)
'''
CREATE SCHEMA IF NOT EXISTS "{0}";
'''.format(self.redshift_schema)

pk = ''
fk = ''
dk = ''
sk = ''

if self.primary_key:
pk = ', primary key("{0}")'.format(self.primary_key)

if self.foreign_key:
if isinstance(self.foreign_key, list):
fk = ', '
for i, e in enumerate(self.foreign_key):
fk += 'foreign key("{0}") references {1}("{2}")'.format(e['column_name'],
e['reftable'],
e['ref_column'])
if i != (len(self.foreign_key) - 1):
fk += ', '""
elif isinstance(self.foreign_key, dict):
fk += ', '
fk += 'foreign key("{0}") references {1}("{2}")'.format(self.foreign_key['column_name'],
self.foreign_key['reftable'],
self.foreign_key['ref_column'])
if self.distkey:
dk = 'distkey({})'.format(self.distkey)

if self.sortkey:
if isinstance(self.sortkey, list):
sk += '{0} sortkey({1})'.format(self.sort_type, ', '.join(["{}".format(e) for e in self.sortkey]))

create_table_query = \
'''CREATE TABLE IF NOT EXISTS "{0}"."{1}" ({2})'''.format(
self.redshift_schema,
copy_table,
output)
'''
CREATE TABLE IF NOT EXISTS "{schema}"."{table}"
({fields}{primary_key}{foreign_key}) {distkey} {sortkey}
'''.format(schema=self.redshift_schema,
table=copy_table,
fields=output,
primary_key=pk,
foreign_key=fk,
distkey=dk,
sortkey=sk)

pg_hook.run(create_schema_query)
pg_hook.run(create_table_query)

0 comments on commit c7478fd

Please sign in to comment.