diff --git a/operators/s3_to_redshift.py b/operators/s3_to_redshift.py
index 98f26e0..bc6145e 100644
--- a/operators/s3_to_redshift.py
+++ b/operators/s3_to_redshift.py
@@ -2,12 +2,14 @@
 import random
 import string
 import logging
+
+from airflow.utils.db import provide_session
+from airflow.models import Connection
 from airflow.utils.decorators import apply_defaults
+
 from airflow.models import BaseOperator
 from airflow.hooks.S3_hook import S3Hook
 from airflow.hooks.postgres_hook import PostgresHook
-from airflow.utils.db import provide_session
-from airflow.models import Connection
 
 
 class S3ToRedshiftOperator(BaseOperator):
@@ -46,9 +48,12 @@ class S3ToRedshiftOperator(BaseOperator):
         possible values include "mysql".
     :type origin_datatype: string
     :param load_type: The method of loading into Redshift that
-                      should occur. Options are "append",
-                      "rebuild", and "upsert". Defaults to
-                      "append."
+                      should occur. Options:
+                          - "append"
+                          - "rebuild"
+                          - "truncate"
+                          - "upsert"
+                      Defaults to "append."
     :type load_type: string
     :param primary_key: *(optional)* The primary key for the
                         destination table. Not enforced by redshift
@@ -128,10 +133,10 @@ def __init__(self,
         self.sortkey = sortkey
         self.sort_type = sort_type
 
-        if self.load_type.lower() not in ["append", "rebuild", "upsert"]:
-            raise Exception('Please choose "append", "rebuild", or "upsert".')
+        if self.load_type.lower() not in ("append", "rebuild", "truncate", "upsert"):
+            raise Exception('Please choose "append", "rebuild", "truncate", or "upsert".')
 
-        if self.schema_location.lower() not in ['s3', 'local']:
+        if self.schema_location.lower() not in ('s3', 'local'):
             raise Exception('Valid Schema Locations are "s3" or "local".')
 
         if not (isinstance(self.sortkey, str) or isinstance(self.sortkey, list)):
@@ -152,9 +157,12 @@ def execute(self, context):
         letters = string.ascii_lowercase
         random_string = ''.join(random.choice(letters) for _ in range(7))
         self.temp_suffix = '_tmp_{0}'.format(random_string)
+
         if self.origin_schema:
             schema = self.read_and_format()
+
         pg_hook = PostgresHook(self.redshift_conn_id)
+
         self.create_if_not_exists(schema, pg_hook)
         self.reconcile_schemas(schema, pg_hook)
         self.copy_data(pg_hook, schema)
@@ -221,7 +229,6 @@ def read_and_format(self):
                     if i['type'] == e['avro']:
                         i['type'] = e['redshift']
 
-        print(schema)
         return schema
 
     def reconcile_schemas(self, schema, pg_hook):
@@ -277,7 +284,7 @@ def getS3Conn():
             elif aws_role_arn:
                 creds = ("aws_iam_role={0}"
                          .format(aws_role_arn))
-
+
             return creds
 
         # Delete records from the destination table where the incremental_key
@@ -331,6 +338,11 @@ def getS3Conn():
             FILLTARGET
             '''.format(self.redshift_schema, self.table, self.temp_suffix)
 
+        drop_sql = \
+            '''
+            DROP TABLE IF EXISTS "{0}"."{1}"
+            '''.format(self.redshift_schema, self.table)
+
         drop_temp_sql = \
             '''
             DROP TABLE IF EXISTS "{0}"."{1}{2}"
@@ -366,6 +378,13 @@ def getS3Conn():
                         base_sql)
         if self.load_type == 'append':
             pg_hook.run(load_sql)
+        elif self.load_type == 'rebuild':
+            pg_hook.run(drop_sql)
+            self.create_if_not_exists(schema, pg_hook)
+            pg_hook.run(load_sql)
+        elif self.load_type == 'truncate':
+            pg_hook.run(truncate_sql)
+            pg_hook.run(load_sql)
         elif self.load_type == 'upsert':
             self.create_if_not_exists(schema, pg_hook, temp=True)
             load_temp_sql = \
@@ -378,9 +397,6 @@ def getS3Conn():
             pg_hook.run(delete_confirm_sql)
             pg_hook.run(append_sql, autocommit=True)
             pg_hook.run(drop_temp_sql)
-        elif self.load_type == 'rebuild':
-            pg_hook.run(truncate_sql)
-            pg_hook.run(load_sql)
 
     def create_if_not_exists(self, schema, pg_hook, temp=False):
         output = ''
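
For context, here is a minimal sketch of how the reworked load_type options might be exercised from a DAG. Only redshift_conn_id, redshift_schema, table, origin_schema, schema_location, and the load_type values are taken from the changed code; the S3-side parameter names (s3_conn_id, s3_bucket, s3_key) and all concrete values are assumptions for illustration, not part of this diff.

    # Usage sketch (not part of this diff): wiring the new load_type options
    # into a DAG. Names marked "assumed" do not appear in the diff above.
    from datetime import datetime

    from airflow import DAG

    from operators.s3_to_redshift import S3ToRedshiftOperator

    with DAG(dag_id='example_s3_to_redshift',
             start_date=datetime(2023, 1, 1),
             schedule_interval=None) as dag:

        # "rebuild" now drops and recreates the destination table before the
        # COPY, while "truncate" keeps the previous truncate-then-load behaviour.
        load_events = S3ToRedshiftOperator(
            task_id='load_events',
            redshift_conn_id='redshift_default',   # assumed connection id
            redshift_schema='analytics',           # assumed schema name
            table='events',                        # assumed table name
            origin_schema='events_schema.json',    # hypothetical schema file
            schema_location='s3',
            load_type='rebuild',
            s3_conn_id='aws_default',              # assumed parameter name
            s3_bucket='my-data-bucket',            # assumed parameter name
            s3_key='events/2023-01-01.json',       # assumed parameter name
        )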