from itertools import islice
from literal_value_generator import (dump_to_sql_statement, dump_to_csv,
                                     dump_to_oracle_insert_statements)
import random
from migrate.changeset.constraint import ForeignKeyConstraint
from datetime import datetime
import time
from copy import deepcopy
import pickle
import sqlalchemy
import logging
# from clean import cleaners
from sqlalchemy.sql import select
from sqlalchemy.schema import CreateTable, Column
from sqlalchemy.sql.schema import Table, Index
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy import create_engine, MetaData, func, and_
from sqlalchemy.engine import reflection
from sqlalchemy.inspection import inspect
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.types import (Text, Numeric, BigInteger, Integer, DateTime,
                              Date, TIMESTAMP, String, BINARY, LargeBinary)
from sqlalchemy.dialects.postgresql import BYTEA
import inspect as ins
import re
import csv
from schema_transformer import SchemaTransformer
from etlalchemy_exceptions import DBApiNotFound
import os
# Parse the conn_string to find relevant info for each db engine #


class ETLAlchemySource():
    """
    An instance of 'ETLAlchemySource' represents a single source DB. That DB
    can be sent to multiple 'ETLAlchemyTargets' via calls to
    ETLAlchemySource.migrate(). See the examples on GitHub for more info.
    """
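    # A minimal usage sketch. The 'ETLAlchemyTarget' calls below are
    # illustrative assumptions (that class is not defined in this module);
    # see the GitHub examples referenced above for the real API:
    #
    #     from etlalchemy import ETLAlchemySource, ETLAlchemyTarget
    #     src = ETLAlchemySource("mysql+pymysql://user:pass@host/source_db")
    #     tgt = ETLAlchemyTarget("postgresql://user:pass@host/target_db",
    #                            drop_database=True)
    #     tgt.addSource(src)
    #     tgt.migrate()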
    def __init__(self,
                 conn_string,
                 global_ignored_col_suffixes=None,
                 global_renamed_col_suffixes=None,
                 column_schema_transformation_file=None,
                 table_schema_transformation_file=None,
                 included_tables=None,
                 excluded_tables=None,
                 skip_table_if_empty=False,
                 skip_column_if_empty=False,
                 compress_varchar=False,
                 log_file=None):
        # Avoid shared mutable default arguments
        if global_ignored_col_suffixes is None:
            global_ignored_col_suffixes = []
        if global_renamed_col_suffixes is None:
            global_renamed_col_suffixes = {}
        # TODO: Store unique columns in here, and ADD the unique constraints
        # after data has been migrated, rather than before
        self.unique_columns = []
self.compress_varchar = compress_varchar
self.logger = logging.getLogger("ETLAlchemySource")
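        # Don't propagate records to the root logger (avoids duplicate log lines)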
self.logger.propagate = False
for h in list(self.logger.handlers):
# Clean up any old loggers...(useful during testing w/ multiple
# log_files)
self.logger.removeHandler(h)
handler = logging.StreamHandler()
if log_file is not None:
handler = logging.FileHandler(log_file)
formatter = logging.Formatter('%(name)s (%(levelname)s) - %(message)s')
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
# Load the json dict of cleaners...
# {'table': [cleaner1, cleaner2,...etc],
# 'table2': [cleaner1,...cleanerN]}
self.included_tables = included_tables
self.excluded_tables = excluded_tables
# Set this to 'False' if you are using either of the
# following MSSQL Environments:
# 1.) AWS SQL Server
# ---> The 'bulkadmin' role required for BULK INSERT permissions
# is not available in AWS
# (see https://forums.aws.amazon.com/thread.jspa?threadID=122351)
# 2.) Azure SQL
# ---> The 'BULK INSERT' feature is disabled in the Microsoft Azure
# cloud.
# ** Otherwise, setting this to 'True' will vastly improve run-time...
self.enable_mssql_bulk_insert = False
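        # (Illustrative) When neither environment above applies, callers can
        # opt in on the instance after construction:
        #     source.enable_mssql_bulk_insert = True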
self.current_ordered_table_columns = []
self.cleaners = {}
self.schema_transformer = SchemaTransformer(
column_transform_file=column_schema_transformation_file,
table_transform_file=table_schema_transformation_file,
global_renamed_col_suffixes=global_renamed_col_suffixes)
self.tgt_insp = None
self.src_insp = None
self.dst_engine = None
self.constraints = {}
self.indexes = {}
self.fks = {}
self.engine = None
self.connection = None
self.orm = None
self.database_url = conn_string
self.total_rows = 0
self.column_count = 0
self.table_count = 0
self.empty_table_count = 0
self.empty_tables = []
self.deleted_table_count = 0
self.deleted_column_count = 0
self.deleted_columns = []
self.null_column_count = 0
self.null_columns = []
self.referential_integrity_violations = 0
self.unique_constraint_violations = []
self.unique_constraint_violation_count = 0
self.skip_column_if_empty = skip_column_if_empty
self.skip_table_if_empty = skip_table_if_empty
self.total_indexes = 0
self.index_count = 0
self.skipped_index_count = 0
self.total_fks = 0
self.fk_count = 0
self.skipped_fk_count = 0
# Config
self.check_referential_integrity = False
self.riv_arr = []
self.start = datetime.now()
self.global_ignored_col_suffixes = global_ignored_col_suffixes
        self.times = {}  # Per-table timing info, keyed by table name

def standardize_column_type(self, column, raw_rows):
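        # Build a copy of the reflected source column and, based on the
        # sampled raw_rows, strip vendor-specific type details (ENUM
        # subclasses, collations, VARCHAR sizing) so the column can be
        # recreated on the target dialect.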
old_column_class = column.type.__class__
column_copy = Column(column.name,
column.type,
nullable=column.nullable,
unique=column.unique,
primary_key=column.primary_key)
if column.unique:
            self.unique_columns.append(column.name)
        ################################
        #   *** STANDARDIZATION ***    #
        ################################
idx = self.current_ordered_table_columns.index(column.name)
##############################
# Duck-typing to remove
# database-vendor specific column types
##############################
        base_classes = [c.__name__.upper()
                        for c in column.type.__class__.__bases__]
self.logger.info("({0}) {1}".format(column.name,
column.type.__class__.__name__))
self.logger.info("Bases: {0}".format(str(base_classes)))
# Assume the column is empty, unless told otherwise
null = True
if "ENUM" in base_classes:
for r in raw_rows:
if r[idx] is not None:
null = False
# Hack for error 'postgresql enum type requires a name'
if self.dst_engine.dialect.name.lower() == "postgresql":
column_copy.type = column.type
column_copy.type.__class__ = column.type.__class__.__bases__[0]
# Name the enumeration 'table_column'
column_copy.type.name = str(column).replace(".", "_")
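                # e.g. a column 'status' on table 'users' becomes enum type 'users_status'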
else:
column_copy.type.__class__ = column.type.__class__.__bases__[0]
elif "STRING" in base_classes\
or "VARCHAR" in base_classes\
or "TEXT" in base_classes:
#########################################
# Get the VARCHAR size of the column...
########################################
varchar_length = column.type.length
##################################
# Strip collation here ...
##################################
column_copy.type.collation = None
max_data_length = 0
for row in raw_rows:
data = row[idx]
if data is not None:
null = False
# Update varchar(size)
if len(data) > max_data_length:
max_data_length = len(data)
# Ignore non-utf8 chars
                    row[idx] = row[idx].decode('utf-8', 'ignore').encode('utf-8')
if max_data_length > 256 or "TEXT" in base_classes:
self.