package oracle.pgx.engine.mllib;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Iterator;
import java.util.List;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.internal.mllib.Pg2vecModelMetadata;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.engine.Server;
import oracle.pgx.engine.Session;
import oracle.pgx.engine.exec.FunctionRequest;
import oracle.pgx.engine.exec.TaskType;
import oracle.pgx.engine.instance.CachedRowTable;
import oracle.pgx.engine.instance.InstanceManager;
import oracle.pgx.loaders.api.StorerException;
import oracle.pgx.mllib.api.Pg2vecModel;
import oracle.pgx.mllib.api.Pg2vecModelProvider;
import oracle.pgx.runtime.GmGraphWithProperties;
import oracle.pgx.runtime.LoaderException;
import oracle.pgx.vfs.VirtualFileManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:oracle/pgx/engine/mllib/Pg2vecServerModel.class */
/**
 * Server-side wrapper around a {@link Pg2vecModel} that is bound to a single {@link Session}.
 *
 * <p>Training, inference and model-data fetch operations are submitted to the engine through
 * {@link Server#enqueue} as {@link FunctionRequest}s tagged with the owning session's id and a
 * {@link TaskType}. Persistence ({@link #loadModel}/{@link #storeModel}) goes through the
 * {@link VirtualFileManager} and the serializer of the discovered {@link Pg2vecModelProvider}.
 *
 * <p>NOTE(review): {@code model} is mutable via {@link #setModel} and is read from inside the
 * enqueued task lambdas, which presumably run on engine worker threads. No synchronization is
 * done here. Confirm callers do not swap the model while tasks are in flight.
 */
public class Pg2vecServerModel implements MlServerModel<Pg2vecModel> {
    private static final Logger LOG = LoggerFactory.getLogger(Pg2vecServerModel.class);
    private static final VirtualFileManager VFM = VirtualFileManager.getInstance();

    /** Provider resolved once at class initialization; see {@link #getPg2VecProvider()}. */
    private static final Pg2vecModelProvider PG2_VEC_MODEL_PROVIDER = getPg2VecProvider();

    private final Session session;
    private final Pg2vecModelMetadata modelMetadata;
    private final InstanceManager instanceManager;
    private Pg2vecModel model;

    /**
     * Returns the first {@link Pg2vecModelProvider} yielded by
     * {@link MlServerModel#initMlModelProviders(Class)}.
     *
     * @return the first discovered provider
     * @throws IllegalStateException if no provider is available ({@code ML_PROVIDER_NOT_FOUND})
     */
    private static Pg2vecModelProvider getPg2VecProvider() {
        Iterator<?> it = MlServerModel.initMlModelProviders(Pg2vecModelProvider.class).iterator();
        if (it.hasNext()) {
            return (Pg2vecModelProvider) it.next();
        }
        throw new IllegalStateException(ErrorMessages.getMessage("ML_PROVIDER_NOT_FOUND", new Object[]{"Pg2vec"}));
    }

    /**
     * Creates a server model and builds its underlying {@link Pg2vecModel} from the given metadata.
     *
     * @param session the session that owns this model; its id tags every enqueued request
     * @param instanceManager used to resolve graphs referenced by name in fit/infer calls
     * @param pg2vecModelMetadata configuration converted via
     *     {@code Pg2vecModelConfigUtils.fromPg2vecModelMetadata} to create the model
     */
    public Pg2vecServerModel(Session session, InstanceManager instanceManager, Pg2vecModelMetadata pg2vecModelMetadata) {
        this.model = PG2_VEC_MODEL_PROVIDER.getModel(Pg2vecModelConfigUtils.fromPg2vecModelMetadata(pg2vecModelMetadata));
        this.modelMetadata = pg2vecModelMetadata;
        this.instanceManager = instanceManager;
        this.session = session;
    }

    /** Returns the wrapped model instance. */
    public Pg2vecModel getModel() {
        return this.model;
    }

    /** Replaces the wrapped model instance (e.g. with one returned by {@link #loadModel}). */
    public void setModel(Pg2vecModel pg2vecModel) {
        this.model = pg2vecModel;
    }

    /**
     * Enqueues a {@link TaskType#FIT_ML_MODEL} task that trains the model on the given graph.
     *
     * @param graphName graph identifier resolved through {@code MlLibUtils.extractGraphWithProperties}
     *     (presumably the session-visible graph name; confirm)
     * @return a future completing with the value returned by {@link Pg2vecModel#fit}
     *     (presumably the final training loss; confirm against the model implementation)
     */
    @Override
    public PgxFuture<Double> fit(String graphName) {
        return Server.enqueue(new FunctionRequest(this.session.getId(), TaskType.FIT_ML_MODEL, (session, request) -> {
            LOG.info("Initiated: Train Pg2vecModel");
            return Double.valueOf(this.model.fit(extractGraphWithProperties(session, graphName)));
        }));
    }

    /**
     * Enqueues an inference task computing a vector for the graph, materialized as a
     * {@link CachedRowTable} named {@code "graphletVectorFrame"}.
     *
     * @param graphName graph identifier resolved through {@code MlLibUtils.extractGraphWithProperties}
     * @return a future completing with the cached result table
     */
    public PgxFuture<CachedRowTable> inferGraphletVector(String graphName) {
        return Server.enqueue(new FunctionRequest(this.session.getId(), TaskType.INFER_ML_MODEL, (session, request) -> {
            LOG.info("Inferring vector for {}: ", graphName);
            return CachedRowTable.buildCachedRowTable(session, "graphletVectorFrame",
                    this.model.inferGraphletVector(request.getDataStructureFactory(), extractGraphWithProperties(session, graphName)));
        }));
    }

    /**
     * Enqueues an inference task computing vectors for the graphlets contained in the graph,
     * materialized as a {@link CachedRowTable} named {@code "graphletVectorBatchedFrame"}.
     *
     * @param graphName graph identifier resolved through {@code MlLibUtils.extractGraphWithProperties}
     * @return a future completing with the cached result table
     */
    public PgxFuture<CachedRowTable> inferGraphletVectorBatched(String graphName) {
        return Server.enqueue(new FunctionRequest(this.session.getId(), TaskType.INFER_ML_MODEL, (session, request) -> {
            LOG.info("Inferring vectors for graphlets in {}: ", graphName);
            return CachedRowTable.buildCachedRowTable(session, "graphletVectorBatchedFrame",
                    this.model.inferGraphletVectorBatched(request.getDataStructureFactory(), extractGraphWithProperties(session, graphName)));
        }));
    }

    /**
     * Enqueues a task that asks the model for graphlets similar to the given one, materialized as
     * a {@link CachedRowTable} named {@code "similarsGraphletsFrame"}.
     *
     * @param graphletId graphlet identifier passed straight to {@link Pg2vecModel#computeSimilars}
     * @param k count passed straight to the model (presumably the number of results; confirm)
     * @return a future completing with the cached result table
     */
    public PgxFuture<CachedRowTable> computeSimilars(String graphletId, int k) {
        return Server.enqueue(new FunctionRequest(this.session.getId(), TaskType.INFER_ML_MODEL, (session, request) -> {
            LOG.info("Computing similars for {}: ", graphletId);
            return CachedRowTable.buildCachedRowTable(session, "similarsGraphletsFrame",
                    this.model.computeSimilars(request.getDataStructureFactory(), graphletId, k));
        }));
    }

    /**
     * Batched variant of {@link #computeSimilars(String, int)}, materialized as a
     * {@link CachedRowTable} named {@code "similarsGraphletsBatchedFrame"}.
     *
     * <p>NOTE(review): the list is captured by reference and read when the task runs, not when it
     * is enqueued. Callers should not mutate it afterwards.
     *
     * @param graphletIds graphlet identifiers passed straight to the model
     * @param k count passed straight to the model (presumably the number of results per graphlet)
     * @return a future completing with the cached result table
     */
    public PgxFuture<CachedRowTable> computeSimilarsBatched(List<String> graphletIds, int k) {
        return Server.enqueue(new FunctionRequest(this.session.getId(), TaskType.INFER_ML_MODEL, (session, request) -> {
            LOG.info("Computing similars for {}: ", graphletIds);
            return CachedRowTable.buildCachedRowTable(session, "similarsGraphletsBatchedFrame",
                    this.model.computeSimilarsBatched(request.getDataStructureFactory(), graphletIds, k));
        }));
    }

    /**
     * Deserializes a model from the given location using the provider's serializer.
     *
     * <p>This does not replace {@link #getModel()}; callers use {@link #setModel} for that.
     *
     * @param path location opened through {@link VirtualFileManager#getInputStream}
     * @return the deserialized model
     * @throws LoaderException wrapping any {@link IOException} raised while opening, reading or
     *     closing the stream ({@code CANNOT_READ_THE_FILE})
     */
    @Override
    public Pg2vecModel loadModel(String path) throws LoaderException {
        // try-with-resources closes the stream even if deserialization throws. The previous
        // decompiled form closed it only on the success path and leaked it otherwise.
        try (InputStream inputStream = VFM.getInputStream(path)) {
            return PG2_VEC_MODEL_PROVIDER.getSerializer().load(inputStream);
        } catch (IOException e) {
            throw new LoaderException(ErrorMessages.getMessage("CANNOT_READ_THE_FILE", new Object[]{path}), e);
        }
    }

    /**
     * Serializes the current model to the given location using the provider's serializer.
     *
     * @param path location opened through {@code VirtualFileManager.getOutputStream(path, true)}
     *     (the {@code true} flag presumably allows overwriting; confirm)
     * @throws StorerException wrapping any {@link IOException} raised while opening, writing or
     *     closing the stream ({@code CANNOT_WRITE_THE_MODEL})
     */
    @Override
    public void storeModel(String path) throws StorerException {
        // try-with-resources closes the stream even if serialization throws, preventing the
        // handle leak present in the previous decompiled form.
        try (OutputStream outputStream = VFM.getOutputStream(path, true)) {
            PG2_VEC_MODEL_PROVIDER.getSerializer().store(outputStream, this.model);
        } catch (IOException e) {
            throw new StorerException(ErrorMessages.getMessage("CANNOT_WRITE_THE_MODEL", new Object[]{path}), e);
        }
    }

    /** Returns the metadata this model was constructed from. */
    public Pg2vecModelMetadata getModelMetadata() {
        return this.modelMetadata;
    }

    /**
     * Enqueues a {@link TaskType#FETCH_MODEL_DATA} task that returns the model's trained graphlet
     * vectors as a {@link CachedRowTable} named {@code "graphletVectorsFrame"}.
     *
     * @return a future completing with the cached result table
     */
    public PgxFuture<CachedRowTable> getTrainedGraphletVectors() {
        return Server.enqueue(new FunctionRequest(this.session.getId(), TaskType.FETCH_MODEL_DATA, (session, request) -> {
            LOG.info("Fetching graphlet vectors");
            CachedRowTable vectors = CachedRowTable.buildCachedRowTable(session, "graphletVectorsFrame",
                    this.model.getTrainedGraphletVectors(request.getDataStructureFactory()));
            LOG.info("Fetched the graphlet vectors");
            return vectors;
        }));
    }

    /** Resolves {@code graphName} to a graph with properties via {@link MlLibUtils}. */
    private GmGraphWithProperties extractGraphWithProperties(Session session, String graphName) {
        return MlLibUtils.extractGraphWithProperties(this.instanceManager, session, graphName);
    }
}
