package com.efuture.ocp.common.slice.filter;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLCharExpr;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor;
import com.alibaba.druid.sql.dialect.oracle.visitor.OracleOutputVisitor;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.fastjson.JSONObject;
import com.efuture.ocp.common.filter.FtSqlSource;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Invocation;
import org.slf4j.Logger;
import org.springframework.util.StringUtils;

import java.util.List;

/**
 * MyBatis interception step for INSERT statements.
 *
 * <p>When the target table is configured in {@code getCollectionMap()}, this step forces
 * that table's partition column to a partition value. It then replaces the intercepted
 * MappedStatement with one backed by the rewritten SQL.
 */
public class CollectionInsertWrapper extends CollectionUpdateWrapper
{
    /**
     * Rewrites an INSERT so the configured partition column carries the partition value.
     *
     * <p>The partition column is looked up in {@code getCollectionMap()} by the lower-cased
     * table name. The partition value is chosen as follows:
     * <ul>
     *   <li>{@code getOrganizationMap().get(partionKey)} when the map contains {@code partionKey};</li>
     *   <li>otherwise {@code sliceKey}.</li>
     * </ul>
     * If the table is not configured, or the resolved value is null or empty, the statement
     * is re-rendered unchanged.
     *
     * <p>In every case, {@code invocation.getArgs()[0]} is replaced with a copy of
     * {@code mappedStatement} backed by the rendered SQL.
     *
     * @param logger          receives the rewritten SQL (newlines stripped) at debug level
     * @param logJson         receives the rewritten SQL under the key {@code "newSql"}
     * @param invocation      intercepted call whose first argument is replaced
     * @param mappedStatement original mapped statement to copy
     * @param boundSql        original bound SQL, handed to the new {@link FtSqlSource}
     * @param stmt            parsed statement; must be a {@link SQLInsertStatement}
     * @param partionKey      lookup key into {@code getOrganizationMap()}
     * @param sliceKey        fallback partition value
     * @throws IllegalStateException if {@code getDbType()} is neither "mysql" nor "oracle"
     */
    protected void onInsertStatement(Logger logger, JSONObject logJson, Invocation invocation, MappedStatement mappedStatement, BoundSql boundSql, SQLStatement stmt, String partionKey, String sliceKey)
            throws Throwable
    {
        SQLInsertStatement insertStmt = (SQLInsertStatement) stmt;
        String collectionName = insertStmt.getTableName().getSimpleName().toLowerCase();

        if (getCollectionMap().containsKey(collectionName))
        {
            String fieldName = (String) getCollectionMap().get(collectionName);
            Object partitionValue = getOrganizationMap().containsKey(partionKey)
                    ? getOrganizationMap().get(partionKey)
                    : sliceKey;
            // Check the raw value, not a quoted rendering of it: otherwise a missing value
            // would be written into the SQL as the literal 'null'.
            if (!StringUtils.isEmpty(partitionValue))
            {
                injectPartitionColumn(insertStmt, fieldName, String.valueOf(partitionValue));
            }
        }

        String newSql = render(insertStmt);

        SqlSource sqlSource = new FtSqlSource(mappedStatement.getConfiguration(), newSql, boundSql);
        MappedStatement newMs = copyFromMappedStatement(mappedStatement, sqlSource);

        invocation.getArgs()[0] = newMs;

        logJson.put("newSql", newSql);
        if (logger.isDebugEnabled())
        {
            logger.debug("newSQL --->{}", newSql.replaceAll("[\\n]", ""));
        }
    }

    /**
     * Sets column {@code fieldName} to the string literal {@code value} in every VALUES row.
     * If the column is absent from the column list, it is appended.
     */
    private static void injectPartitionColumn(SQLInsertStatement insertStmt, String fieldName, String value)
    {
        List<SQLExpr> columns = insertStmt.getColumns();
        List<SQLInsertStatement.ValuesClause> rows = insertStmt.getValuesList();

        int index = indexOfColumn(columns, fieldName);
        if (index >= 0)
        {
            for (SQLInsertStatement.ValuesClause row : rows)
            {
                // SQLCharExpr is printed as a quoted literal by the output visitor, so the value
                // is not spliced into the SQL verbatim. A separate node is created per row.
                row.getValues().set(index, new SQLCharExpr(value));
            }
        }
        else
        {
            // NOTE(review): for INSERT ... SELECT (no VALUES rows), or an INSERT without an
            // explicit column list, appending here yields a column/value count mismatch.
            // Confirm the intended handling for those forms.
            columns.add(new SQLIdentifierExpr(fieldName));
            for (SQLInsertStatement.ValuesClause row : rows)
            {
                row.getValues().add(new SQLCharExpr(value));
            }
        }
    }

    /**
     * Returns the index of the first column whose name equals {@code fieldName}, or -1.
     * The comparison ignores case and surrounding identifier quotes.
     */
    private static int indexOfColumn(List<SQLExpr> columns, String fieldName)
    {
        for (int i = 0; i < columns.size(); i++)
        {
            if (fieldName.equalsIgnoreCase(unquote(columns.get(i).toString())))
            {
                return i;
            }
        }
        return -1;
    }

    /** Strips one matching pair of surrounding backquotes or double quotes, if present. */
    private static String unquote(String name)
    {
        if (name.length() >= 2)
        {
            char first = name.charAt(0);
            char last = name.charAt(name.length() - 1);
            if ((first == '`' || first == '"') && first == last)
            {
                return name.substring(1, name.length() - 1);
            }
        }
        return name;
    }

    /**
     * Prints {@code insertStmt} with the output visitor matching {@code getDbType()}.
     * The rendered SQL ends with a trailing newline.
     *
     * @throws IllegalStateException if the db type is not "mysql" or "oracle"
     */
    private String render(SQLInsertStatement insertStmt)
    {
        StringBuffer out = new StringBuffer();
        SQLASTOutputVisitor visitor;
        if ("mysql".equals(getDbType()))
        {
            visitor = new MySqlOutputVisitor(out);
        }
        else if ("oracle".equals(getDbType()))
        {
            visitor = new OracleOutputVisitor(out);
        }
        else
        {
            throw new IllegalStateException("未识别的dbType: " + getDbType());
        }
        // accept() dispatches on the node's runtime type. A dialect-specific subclass of
        // SQLInsertStatement is therefore printed by its own visit overload, not the generic one.
        insertStmt.accept(visitor);
        visitor.println();
        return out.toString();
    }
}