package oracle.pgx.api.beta.mllib;

import com.google.common.base.Functions;
import com.google.common.collect.Lists;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.beta.frames.PgxFrame;
import oracle.pgx.api.beta.frames.internal.PgxFrameImpl;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.mllib.Pg2vecModelMetadata;
import oracle.pgx.common.ObjectHolder;
import oracle.pgx.common.types.PropertyType;
import oracle.pgx.common.util.ErrorMessages;

/* loaded from: input_file:oracle/pgx/api/beta/mllib/Pg2vecModel.class */
/**
 * Client-side handle to a pg2vec model that lives in a PGX session.
 *
 * <p>Every model operation is delegated to {@link Core}, identified by this session's ID and the
 * model name stored in the model metadata. The hyper-parameter getters read directly from that
 * metadata. Each synchronous method simply blocks on its {@code *Async} counterpart.
 */
public class Pg2vecModel {
    private final PgxSession session;
    private final Core core;
    private final Pg2vecModelMetadata modelMetadata;

    private Pg2vecModel(PgxSession pgxSession, Core core, Pg2vecModelMetadata pg2vecModelMetadata) {
        this.core = core;
        this.session = pgxSession;
        this.modelMetadata = pg2vecModelMetadata;
    }

    /**
     * Asynchronously trains the model on the given graph.
     *
     * <p>The steps run in this order:
     * <ol>
     *   <li>Reject directed graphs.</li>
     *   <li>Create the walk ({@code STRING}) vertex property, failing if it already exists.</li>
     *   <li>Create the graphlet-size ({@code LONG}) vertex property, failing if it already
     *       exists.</li>
     *   <li>Read the type of the graphlet-ID vertex property.</li>
     *   <li>Run training on the server.</li>
     *   <li>Record the returned loss in the model metadata.</li>
     * </ol>
     *
     * @param pgxGraph the graph to train on; must be undirected
     * @return a future that completes once training has finished and the loss has been recorded.
     *     It completes exceptionally with {@link IllegalStateException} if the graph is directed or
     *     a working property already exists, and with {@link IllegalArgumentException} if the
     *     graphlet-ID property is missing.
     */
    public PgxFuture<Void> fitAsync(PgxGraph pgxGraph) {
        if (pgxGraph.isDirected()) {
            return PgxFuture.exceptionallyCompletedFuture(directedGraphError());
        }
        // Every step is chained with thenCompose so that each server round-trip, and any
        // failure it produces, is awaited before the next step starts.
        return createNewVertexPropertyAsync(pgxGraph, PropertyType.STRING,
                        this.modelMetadata.getWalkPropertyName())
                .thenCompose(walkCreated -> createNewVertexPropertyAsync(pgxGraph, PropertyType.LONG,
                        this.modelMetadata.getGraphletSizePropertyName()))
                .thenCompose(sizeCreated -> graphletIdTypeNameAsync(pgxGraph))
                .thenCompose(idTypeName -> this.core.fitPg2vecModel(this.session.getId(),
                        this.modelMetadata.getModelName(), pgxGraph.getName(), idTypeName))
                .thenApply(loss -> {
                    this.modelMetadata.setLoss(loss);
                    return null;
                });
    }

    /**
     * Trains the model on the given graph, blocking until training completes.
     *
     * @param pgxGraph the graph to train on; must be undirected
     * @throws IllegalStateException if the graph is directed (thrown immediately)
     * @throws ExecutionException if the asynchronous training fails
     * @throws InterruptedException if interrupted while waiting
     */
    public void fit(PgxGraph pgxGraph) throws ExecutionException, InterruptedException {
        requireUndirected(pgxGraph);
        fitAsync(pgxGraph).get();
    }

    /**
     * Asynchronously computes the graphlets most similar to one graphlet.
     *
     * @param obj the graphlet ID; it is sent to the server as {@code String.valueOf(obj)}
     * @param i the result-size argument forwarded to the server (presumably the number of similar
     *     graphlets to return; confirm against the server API)
     * @return a future holding the result frame
     */
    public PgxFuture<PgxFrame> computeSimilarsAsync(Object obj, int i) {
        return this.core.computeSimilarsPg2vecModel(this.session.getId(),
                        this.modelMetadata.getModelName(), String.valueOf(obj), i)
                .thenApply(frameMetaData -> new PgxFrameImpl(this.session, this.core, frameMetaData));
    }

    /**
     * Blocking variant of {@link #computeSimilarsAsync(Object, int)}.
     *
     * @param obj the graphlet ID
     * @param i the result-size argument forwarded to the server
     * @return the result frame
     * @throws ExecutionException if the computation fails
     * @throws InterruptedException if interrupted while waiting
     */
    public PgxFrame computeSimilars(Object obj, int i) throws ExecutionException, InterruptedException {
        return computeSimilarsAsync(obj, i).get();
    }

    /**
     * Asynchronously computes similar graphlets for a batch of graphlet IDs.
     *
     * @param list the graphlet IDs; each is converted with {@code toString()} before being sent
     * @param i the result-size argument forwarded to the server
     * @return a future holding the result frame
     */
    public PgxFuture<PgxFrame> computeSimilarsAsync(List<Object> list, int i) {
        return this.core.computeSimilarsBatchedPg2vecModel(this.session.getId(),
                        this.modelMetadata.getModelName(),
                        Lists.transform(list, Functions.toStringFunction()), i)
                .thenApply(frameMetaData -> new PgxFrameImpl(this.session, this.core, frameMetaData));
    }

    /**
     * Blocking variant of {@link #computeSimilarsAsync(List, int)}.
     *
     * @param list the graphlet IDs
     * @param i the result-size argument forwarded to the server
     * @return the result frame
     * @throws ExecutionException if the computation fails
     * @throws InterruptedException if interrupted while waiting
     */
    public PgxFrame computeSimilars(List<Object> list, int i) throws ExecutionException, InterruptedException {
        return computeSimilarsAsync(list, i).get();
    }

    /**
     * Asynchronously infers graphlet vectors for the graphlets of the given graph.
     *
     * <p>The walk and graphlet-size vertex properties are created first, without the existence
     * check that {@link #fitAsync(PgxGraph)} performs. NOTE(review): what happens when they already
     * exist is decided by {@code createVertexPropertyAsync}; confirm.
     *
     * @param pgxGraph the graph whose graphlets are inferred; must be undirected
     * @return a future holding the frame of inferred vectors
     */
    public PgxFuture<PgxFrame> inferGraphletVectorAsync(PgxGraph pgxGraph) {
        return prepareForInferenceAsync(pgxGraph)
                .thenCompose(idTypeName -> this.core.inferGraphletVectorPg2vecModel(this.session.getId(),
                        this.modelMetadata.getModelName(), pgxGraph.getName(), idTypeName))
                .thenApply(frameMetaData -> new PgxFrameImpl(this.session, this.core, frameMetaData));
    }

    /**
     * Blocking variant of {@link #inferGraphletVectorAsync(PgxGraph)}.
     *
     * @param pgxGraph the graph whose graphlets are inferred; must be undirected
     * @return the frame of inferred vectors
     * @throws IllegalStateException if the graph is directed (thrown immediately)
     * @throws ExecutionException if inference fails
     * @throws InterruptedException if interrupted while waiting
     */
    public PgxFrame inferGraphletVector(PgxGraph pgxGraph) throws ExecutionException, InterruptedException {
        requireUndirected(pgxGraph);
        return inferGraphletVectorAsync(pgxGraph).get();
    }

    /**
     * Asynchronously infers graphlet vectors using the server's batched inference entry point.
     * The property setup is the same as in {@link #inferGraphletVectorAsync(PgxGraph)}.
     *
     * @param pgxGraph the graph whose graphlets are inferred; must be undirected
     * @return a future holding the frame of inferred vectors
     */
    public PgxFuture<PgxFrame> inferGraphletVectorBatchedAsync(PgxGraph pgxGraph) {
        return prepareForInferenceAsync(pgxGraph)
                .thenCompose(idTypeName -> this.core.inferGraphletVectorBatchedPg2vecModel(this.session.getId(),
                        this.modelMetadata.getModelName(), pgxGraph.getName(), idTypeName))
                .thenApply(frameMetaData -> new PgxFrameImpl(this.session, this.core, frameMetaData));
    }

    /**
     * Blocking variant of {@link #inferGraphletVectorBatchedAsync(PgxGraph)}.
     *
     * @param pgxGraph the graph whose graphlets are inferred; must be undirected
     * @return the frame of inferred vectors
     * @throws IllegalStateException if the graph is directed (thrown immediately)
     * @throws ExecutionException if inference fails
     * @throws InterruptedException if interrupted while waiting
     */
    public PgxFrame inferGraphletVectorBatched(PgxGraph pgxGraph) throws ExecutionException, InterruptedException {
        requireUndirected(pgxGraph);
        return inferGraphletVectorBatchedAsync(pgxGraph).get();
    }

    /**
     * Asynchronously fetches the graphlet vectors learned during training.
     *
     * @return a future holding the frame of trained vectors
     */
    public PgxFuture<PgxFrame> getTrainedGraphletVectorsAsync() {
        return this.core.getTrainedGraphletVectorsPg2vecModel(this.session.getId(), this.modelMetadata.getModelName())
                .thenApply(frameMetaData -> new PgxFrameImpl(this.session, this.core, frameMetaData));
    }

    /**
     * Blocking variant of {@link #getTrainedGraphletVectorsAsync()}.
     *
     * @return the frame of trained vectors
     * @throws ExecutionException if the request fails
     * @throws InterruptedException if interrupted while waiting
     */
    public PgxFrame getTrainedGraphletVectors() throws ExecutionException, InterruptedException {
        return getTrainedGraphletVectorsAsync().get();
    }

    /**
     * Asynchronously stores the model.
     *
     * @param str the storage destination passed to the server (presumably a file path; confirm)
     * @return a future that completes when the model has been stored
     */
    public PgxFuture<Void> storeAsync(String str) {
        return this.core.storeModel(this.session.getId(), this.modelMetadata.getModelName(), str);
    }

    /**
     * Blocking variant of {@link #storeAsync(String)}.
     *
     * @param str the storage destination passed to the server
     * @throws ExecutionException if storing fails
     * @throws InterruptedException if interrupted while waiting
     */
    public void store(String str) throws ExecutionException, InterruptedException {
        storeAsync(str).get();
    }

    /**
     * Asynchronously destroys the model on the server.
     * NOTE(review): the meaning of the {@code false} flag passed to the server is not visible here;
     * confirm.
     *
     * @return a future that completes when the model has been destroyed
     */
    public PgxFuture<Void> destroyAsync() {
        return this.core.destroyMlModel(this.session.getId(), this.modelMetadata.getModelName(), false);
    }

    /**
     * Blocking variant of {@link #destroyAsync()}.
     *
     * @throws ExecutionException if destruction fails
     * @throws InterruptedException if interrupted while waiting
     */
    public void destroy() throws ExecutionException, InterruptedException {
        destroyAsync().get();
    }

    /** @return the name of the vertex property holding each vertex's graphlet ID */
    public String getGraphLetIdPropertyName() {
        return this.modelMetadata.getGraphLetIdPropertyName();
    }

    /** @return the vertex property names configured for the model */
    public Collection<String> getVertexPropertyNames() {
        return this.modelMetadata.getVertexPropertyNames();
    }

    /** @return the configured minimum word frequency */
    public int getMinWordFrequency() {
        return this.modelMetadata.getMinWordFrequency();
    }

    /** @return the configured number of training epochs */
    public int getNumEpochs() {
        return this.modelMetadata.getNumEpochs();
    }

    /** @return the configured layer size */
    public int getLayerSize() {
        return this.modelMetadata.getLayerSize();
    }

    /** @return the configured learning rate */
    public double getLearningRate() {
        return this.modelMetadata.getLearningRate();
    }

    /** @return the configured minimum learning rate */
    public double getMinLearningRate() {
        return this.modelMetadata.getMinLearningRate();
    }

    /** @return the configured window size */
    public int getWindowSize() {
        return this.modelMetadata.getWindowSize();
    }

    /** @return the configured walk length */
    public int getWalkLength() {
        return this.modelMetadata.getWalkLength();
    }

    /** @return the configured number of walks per vertex */
    public int getWalksPerVertex() {
        return this.modelMetadata.getWalksPerVertex();
    }

    /** @return the name of the {@code LONG} vertex property created for graphlet sizes */
    public String getGraphletSizePropertyName() {
        return this.modelMetadata.getGraphletSizePropertyName();
    }

    /** @return the name of the {@code STRING} vertex property created for walks */
    public String getWalkPropertyName() {
        return this.modelMetadata.getWalkPropertyName();
    }

    /** @return the configured validation fraction */
    public double getValidationFraction() {
        return this.modelMetadata.getValidationFraction();
    }

    /**
     * Returns the loss recorded in the model metadata (set by {@link #fitAsync(PgxGraph)}).
     *
     * @return the recorded loss
     * @throws IllegalStateException if no loss has been recorded yet, for example because the model
     *     has not been fitted
     */
    public double getLoss() {
        Number loss = this.modelMetadata.getLoss();
        if (loss == null) {
            throw new IllegalStateException("no loss is recorded for this Pg2vecModel; has it been fitted?");
        }
        return loss.doubleValue();
    }

    /**
     * Shared asynchronous setup for both inference entry points.
     *
     * <p>Rejects directed graphs, creates the walk and graphlet-size properties, then resolves the
     * graphlet-ID property's type name.
     */
    private PgxFuture<String> prepareForInferenceAsync(PgxGraph pgxGraph) {
        if (pgxGraph.isDirected()) {
            return PgxFuture.exceptionallyCompletedFuture(directedGraphError());
        }
        return pgxGraph.createVertexPropertyAsync(PropertyType.STRING, this.modelMetadata.getWalkPropertyName())
                .thenCompose(walkProperty -> pgxGraph.createVertexPropertyAsync(PropertyType.LONG,
                        this.modelMetadata.getGraphletSizePropertyName()))
                .thenCompose(sizeProperty -> graphletIdTypeNameAsync(pgxGraph));
    }

    /**
     * Resolves the type name of the graphlet-ID vertex property.
     *
     * <p>The returned future fails with {@link IllegalArgumentException} if that property is
     * missing.
     */
    private PgxFuture<String> graphletIdTypeNameAsync(PgxGraph pgxGraph) {
        String propertyName = this.modelMetadata.getGraphLetIdPropertyName();
        return pgxGraph.getVertexPropertyAsync(propertyName).thenApply(property -> {
            if (property == null) {
                throw new IllegalArgumentException("graphlet ID vertex property '" + propertyName
                        + "' does not exist in graph " + pgxGraph.getName());
            }
            return property.getType().toString();
        });
    }

    /** Creates a vertex property, failing with PROPERTY_EXISTS if the name is already taken. */
    private static PgxFuture<Void> createNewVertexPropertyAsync(PgxGraph pgxGraph, PropertyType type,
            String propertyName) {
        return pgxGraph.getVertexPropertyAsync(propertyName).thenCompose(existing -> {
            if (existing != null) {
                return PgxFuture.<Void>exceptionallyCompletedFuture(new IllegalStateException(
                        ErrorMessages.getMessage("PROPERTY_EXISTS", new Object[]{propertyName})));
            }
            return pgxGraph.createVertexPropertyAsync(type, propertyName).thenApply(created -> (Void) null);
        });
    }

    /** Throws the standard "directed input graph" error if the graph is directed. */
    private static void requireUndirected(PgxGraph pgxGraph) {
        if (pgxGraph.isDirected()) {
            throw directedGraphError();
        }
    }

    /** Builds the error used whenever a directed graph is passed to this model. */
    private static IllegalStateException directedGraphError() {
        return new IllegalStateException(ErrorMessages.getMessage("ML_INPUT_GRAPH_DIRECTED", new Object[]{"Pg2vecModel"}));
    }
}
