/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.{Dataset, ExpressionColumnNode, SparkSession}

/**
 * Shim of Comet test helpers implemented against the `org.apache.spark.sql.classic` API
 * (`SparkSession`, `Dataset`, `ExpressionColumnNode`), so shared test code can call these
 * helpers without referencing that package directly.
 */
trait ShimCometTestBase {

  /** Concrete session type returned by [[createSparkSessionWithExtensions]]. */
  type SparkSessionType = SparkSession

  /**
   * Builds (or, per `getOrCreate`, reuses) a local Spark session with the Comet session
   * extensions registered.
   *
   * @param conf configuration copied into the session builder
   * @return a session using master `local[1]` with `CometSparkSessionExtensions` installed
   */
  def createSparkSessionWithExtensions(conf: SparkConf): SparkSessionType = {
    SparkSession
      .builder()
      .config(conf)
      // Set after `conf`, so presumably takes precedence over any `spark.master` in it.
      .master("local[1]")
      .withExtensions(new org.apache.comet.CometSparkSessionExtensions)
      .getOrCreate()
  }

  /**
   * Creates a DataFrame for an existing logical plan via the classic `Dataset.ofRows`.
   *
   * @param spark session the DataFrame is bound to
   * @param plan logical plan to wrap
   * @return the resulting DataFrame
   */
  def datasetOfRows(spark: SparkSession, plan: LogicalPlan): DataFrame =
    Dataset.ofRows(spark, plan)

  /**
   * Wraps a Catalyst expression in a `Column` through `ExpressionColumnNode`.
   *
   * @param expr Catalyst expression to wrap
   * @return a column backed by `expr`
   */
  def getColumnFromExpression(expr: Expression): Column =
    new Column(ExpressionColumnNode.apply(expr))

  /**
   * Returns the analyzed logical plan from `df`'s query execution.
   *
   * @param df DataFrame whose plan is extracted
   * @return `df.queryExecution.analyzed`
   */
  def extractLogicalPlan(df: DataFrame): LogicalPlan =
    df.queryExecution.analyzed

  /**
   * Builds a column that applies Catalyst's `MakeDecimal` to `child`.
   *
   * `nullOnOverflow` is fixed to `true` here regardless of any session setting.
   *
   * @param child input expression (presumably the unscaled integral value — confirm
   *              against `MakeDecimal`'s contract)
   * @param precision target decimal precision
   * @param scale target decimal scale
   * @return a column wrapping the `MakeDecimal` expression
   */
  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column =
    getColumnFromExpression(MakeDecimal(child, precision, scale, nullOnOverflow = true))
}
