from pyspark.sql import SparkSession
import sys
from awsglue.utils import getResolvedOptions
# Names of the job arguments this script asks Glue to resolve from sys.argv.
# JOB_NAME is resolved along with the others but is not read below.
required_job_args = [
    'JOB_NAME',
    's3_bucket_arn',
    'input_data_path',
    'namespace',
    's3_table_name',
    'region',
]
args = getResolvedOptions(sys.argv, required_job_args)

# Pull the resolved values into module-level names used by the rest of the job.
region, s3_bucket_arn, input_data_path, namespace, s3_table_name = (
    args[arg_name]
    for arg_name in (
        "region",
        "s3_bucket_arn",
        "input_data_path",
        "namespace",
        "s3_table_name",
    )
)
# Spark settings that register an Iceberg catalog named "s3_rest_catalog" and
# make it the default catalog. The catalog is of type "rest", its URI is built
# from the job's region, its warehouse is the bucket ARN passed to the job, and
# requests are SigV4-signed with the "s3tables" signing name.
iceberg_rest_conf = {
    "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
    "spark.sql.defaultCatalog": "s3_rest_catalog",
    "spark.sql.catalog.s3_rest_catalog": "org.apache.iceberg.spark.SparkCatalog",
    "spark.sql.catalog.s3_rest_catalog.type": "rest",
    "spark.sql.catalog.s3_rest_catalog.uri": f"https://s3tables.{region}.amazonaws.com/iceberg",
    "spark.sql.catalog.s3_rest_catalog.warehouse": s3_bucket_arn,
    "spark.sql.catalog.s3_rest_catalog.rest.sigv4-enabled": "true",
    "spark.sql.catalog.s3_rest_catalog.rest.signing-name": "s3tables",
    "spark.sql.catalog.s3_rest_catalog.rest.signing-region": region,
    "spark.sql.catalog.s3_rest_catalog.io-impl": "org.apache.iceberg.aws.s3.S3FileIO",
    "spark.sql.catalog.s3_rest_catalog.rest-metrics-reporting-enabled": "false",
}

# Apply the settings in declaration order, then create (or reuse) the session.
session_builder = SparkSession.builder.appName("glue-s3-tables-rest")
for conf_key, conf_value in iceberg_rest_conf.items():
    session_builder = session_builder.config(conf_key, conf_value)
spark = session_builder.getOrCreate()

# Load the input parquet data and expose it to SQL as the temp view "tmp_taxi".
dataframe = spark.read.parquet(input_data_path)
dataframe.createOrReplaceTempView("tmp_taxi")

# Destination table, qualified by namespace; it resolves against the session's
# default catalog.
target_table = f"{namespace}.{s3_table_name}"

# DDL for the destination table; a no-op when the table already exists.
# The two datetime columns are declared as strings.
create_table_sql = (
    f"CREATE TABLE IF NOT EXISTS {target_table} ("
    "vendor_id bigint, "
    "tpep_pickup_datetime string, "
    "tpep_dropoff_datetime string, "
    "passenger_count double, "
    "trip_distance double, "
    "rate_code_id double, "
    "store_and_fwd_flag string, "
    "pu_location_id bigint, "
    "do_location_id bigint, "
    "payment_type bigint, "
    "fare_amount double, "
    "extra double, "
    "mta_tax double, "
    "tip_amount double, "
    "tolls_amount double, "
    "improvement_surcharge double, "
    "total_amount double, "
    "congestion_surcharge double, "
    "airport_fee double)"
)
spark.sql(create_table_sql)

# Append every row of the temp view to the table, renaming the mixed-case
# source columns (VendorID, RatecodeID, PULocationID, DOLocationID) to the
# snake_case names used in the table definition.
insert_sql = (
    f"INSERT INTO {target_table} "
    "SELECT VendorID as vendor_id, "
    "tpep_pickup_datetime, "
    "tpep_dropoff_datetime, "
    "passenger_count , "
    "trip_distance , "
    "RatecodeID as rate_code_id, "
    "store_and_fwd_flag, "
    "PULocationID as pu_location_id, "
    "DOLocationID as do_location_id, "
    "payment_type, "
    "fare_amount , "
    "extra , "
    "mta_tax , "
    "tip_amount , "
    "tolls_amount , "
    "improvement_surcharge , "
    "total_amount , "
    "congestion_surcharge , "
    "airport_fee "
    "FROM tmp_taxi"
)
spark.sql(insert_sql)