db_to_parquet.py 3.2 KB

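# Export a database table to a Parquet file in batches, mapping the
# SQLAlchemy column types to numpy/pandas dtypes before writing with fastparquet.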
from database import conn_string
import pandas as pd
import numpy as np
import fastparquet
from sqlalchemy import create_engine, inspect, text

# Copied from pandas with modifications
def __get_dtype(column, sqltype):
    import sqlalchemy.dialects as sqld
    from sqlalchemy.types import Integer, Float, Boolean, DateTime, Date, TIMESTAMP

    if isinstance(sqltype, Float):
        return float
    elif isinstance(sqltype, Integer):
        # Since DataFrame cannot handle nullable int, convert nullable ints to floats
        if column["nullable"]:
            return float
        # TODO: Refine integer size.
        return np.dtype("int64")
    elif isinstance(sqltype, TIMESTAMP):
        # we have a timezone capable type
        if not sqltype.timezone:
            return np.dtype("datetime64[ns]")
        return pd.DatetimeTZDtype
    elif isinstance(sqltype, DateTime):
        # Caution: np.datetime64 is also a subclass of np.number.
        return np.dtype("datetime64[ns]")
    elif isinstance(sqltype, Date):
        # numpy has no date-only dtype; fall back to datetime64[ns]
        return np.dtype("datetime64[ns]")
    elif isinstance(sqltype, Boolean):
        return bool
    elif isinstance(sqltype, sqld.mssql.base.BIT):
        # Handling database provider specific types
        return np.dtype("u1")
    # Catch all type - handle provider specific types in another elif block
    return object
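
# Convert one fetched batch of rows into a typed DataFrame and append it to the Parquet file.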
def __write_parquet(output_path: str, batch_array, column_dict, write_index: bool, compression: str, append: bool):
    # Create the DataFrame to hold the batch array contents
    b_df = pd.DataFrame(batch_array, columns=column_dict)
    # Cast the DataFrame columns to the sqlalchemy column analogues
    b_df = b_df.astype(dtype=column_dict)
    # Write to the parquet file (first write needs append=False)
    fastparquet.write(output_path, b_df, write_index=write_index, compression=compression, append=append)
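
# Stream an entire table to Parquet: reflect the column types, then fetch and write
# rows in batches of batch_size so the full table never has to fit in memory.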
def table_to_parquet(
    output_path: str, table_name: str, con, batch_size: int = 10000, write_index: bool = True, compression: str = None
):
    # Get database schema using sqlalchemy reflection
    db_engine = create_engine(con)
    db_inspect = inspect(db_engine)
    db_tables = db_inspect.get_table_names(schema="import")
    # Get the columns for the parquet file
    column_dict = dict()
    for column in db_inspect.get_columns(table_name, "import"):
        dtype = __get_dtype(column, column["type"])
        column_dict[column["name"]] = dtype
    # Query the table
    with db_engine.connect() as conn:
        result = conn.execute(text(f"SELECT * FROM import.{table_name}"))
        row_batch = result.fetchmany(size=batch_size)
        append = False
        while len(row_batch) > 0:
            __write_parquet(output_path, row_batch, column_dict, write_index, compression, append)
            append = True
            row_batch = result.fetchmany(size=batch_size)
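
# Example invocation: export the import.journal_accountings table from the
# LOCOSOFT database into a Parquet file named "temp".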
if __name__ == "__main__":
    dsn = {
        "user": "sa",
        "password": "Mffu3011#",
        "server": "localhost\\GLOBALCUBE",
        "database": "LOCOSOFT",
        "driver": "mssql",
        "schema": "import",
    }
    conn_str = conn_string(dsn)
    table_to_parquet("temp", "journal_accountings", conn_str)
    # print(timeit.timeit(s))
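
    # A quick way to sanity-check the export (a minimal sketch; assumes fastparquet
    # is installed and the script has already written the "temp" file):
    #   pf = fastparquet.ParquetFile("temp")
    #   print(pf.to_pandas().dtypes)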