/*
 * Decompiled with CFR 0.152.
 */
package net.schmizz.sshj.transport.kex;

import com.hierynomus.sshj.key.KeyAlgorithm;
import com.hierynomus.sshj.userauth.certificate.Certificate;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import javax.crypto.spec.DHParameterSpec;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.signature.Signature;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.digest.Digest;
import net.schmizz.sshj.transport.kex.AbstractDH;
import net.schmizz.sshj.transport.kex.DH;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client side of a Diffie-Hellman Group Exchange key exchange (the message flow of RFC 4419).
 *
 * <p>Flow as implemented here:
 * <ol>
 *   <li>{@link #init} sends {@code KEX_DH_GEX_REQUEST} with the min / preferred / max modulus sizes;</li>
 *   <li>the server's group message (message number 31) carries {@code p} and {@code g}; the modulus size is
 *       checked against the requested bounds, the local {@link DH} is initialized and {@code KEX_DH_GEX_INIT}
 *       is sent with the client's public value;</li>
 *   <li>the server's {@code KEX_DH_GEX_REPLY} supplies the host key, the server public value and a signature;
 *       the shared secret and exchange hash are computed and the signature is verified.</li>
 * </ol>
 * Messages that arrive out of this order are rejected. Subclasses only choose the {@link Digest}.
 */
public abstract class AbstractDHGex
extends AbstractDH {
    /** Smallest modulus size (bits) requested from, and accepted from, the server. */
    private static final int MIN_BITS = 1024;
    /** Preferred modulus size (bits) sent in the request. */
    private static final int PREFERRED_BITS = 2048;
    /** Largest modulus size (bits) requested from, and accepted from, the server. */
    private static final int MAX_BITS = 8192;

    private final Logger log = LoggerFactory.getLogger(this.getClass());

    /**
     * Whether a group message has already been accepted in the current exchange. Used to reject a reply that
     * arrives before the group (the DH state would be uninitialized) and a duplicate group message.
     */
    private boolean groupReceived;

    public AbstractDHGex(Digest digest) {
        super(new DH(), digest);
    }

    /**
     * Starts the exchange: initializes the base state and the digest, then sends {@code KEX_DH_GEX_REQUEST}
     * with {@link #MIN_BITS}, {@link #PREFERRED_BITS} and {@link #MAX_BITS}.
     */
    @Override
    public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException {
        super.init(trans, V_S, V_C, I_S, I_C);
        digest.init();
        groupReceived = false;
        log.debug("Sending {}", Message.KEX_DH_GEX_REQUEST);
        trans.write(new SSHPacket(Message.KEX_DH_GEX_REQUEST)
                .putUInt32(MIN_BITS)
                .putUInt32(PREFERRED_BITS)
                .putUInt32(MAX_BITS));
    }

    /**
     * Handles one server key-exchange message.
     *
     * @return {@code false} after the group message has been processed, {@code true} after the reply has been
     *         processed and its signature verified
     * @throws TransportException if the message is unexpected, arrives out of order, cannot be parsed, or the
     *         signature does not verify
     * @throws GeneralSecurityException if the server's modulus is out of range or a crypto operation fails
     */
    @Override
    public boolean next(Message msg, SSHPacket buffer) throws GeneralSecurityException, TransportException {
        log.debug("Got message {}", msg);
        try {
            switch (msg) {
                case KEXDH_31:
                    // Message 31 is the group message in this exchange; only one is allowed.
                    if (groupReceived) {
                        throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
                                "Unexpected message " + msg + ": group already received");
                    }
                    return parseGexGroup(buffer);
                case KEX_DH_GEX_REPLY:
                    // Without a group the DH state is uninitialized and the reply cannot be processed.
                    if (!groupReceived) {
                        throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
                                "Unexpected message " + msg + ": no group received yet");
                    }
                    return parseGexReply(buffer);
                default:
                    throw new TransportException("Unexpected message " + msg);
            }
        } catch (Buffer.BufferException be) {
            throw new TransportException(be);
        }
    }

    /**
     * Processes {@code KEX_DH_GEX_REPLY}: reads the host key blob {@code K_S}, the server public value {@code f}
     * and the signature, computes the shared secret {@code K}, builds the exchange hash {@code H} and verifies
     * the signature over {@code H} with the host key (or the key embedded in a host certificate).
     *
     * @return always {@code true} when the signature verifies
     */
    private boolean parseGexReply(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException {
        final byte[] K_S = buffer.readBytes();
        final byte[] f = buffer.readBytes();
        final byte[] sig = buffer.readBytes();
        hostKey = new Buffer.PlainBuffer(K_S).readPublicKey();

        dh.computeK(f);
        final BigInteger k = dh.getK();

        // Hash input: K_S || min || n || max || p || g || e || f || K. The size fields must be the exact values
        // sent in init(). initializedBuffer() presumably already holds the version strings and KEXINIT payloads;
        // verify in AbstractDH.
        final Buffer.PlainBuffer buf = initializedBuffer()
                .putString(K_S)
                .putUInt32(MIN_BITS)
                .putUInt32(PREFERRED_BITS)
                .putUInt32(MAX_BITS)
                .putMPInt(((DH) dh).getP())
                .putMPInt(((DH) dh).getG())
                .putBytes(dh.getE())
                .putBytes(f)
                .putMPInt(k);
        digest.update(buf.array(), buf.rpos(), buf.available());
        H = digest.digest();

        final KeyAlgorithm keyAlgorithm = trans.getHostKeyAlgorithm();
        final Signature signature = keyAlgorithm.newSignature();
        // A certificate wraps the actual public key; the signature is made with the wrapped key.
        final PublicKey verificationKey = hostKey instanceof Certificate
                ? (PublicKey) ((Certificate) hostKey).getKey()
                : hostKey;
        signature.initVerify(verificationKey);
        signature.update(H, 0, H.length);
        if (!signature.verify(sig)) {
            throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, "KeyExchange signature verification failed");
        }
        return true;
    }

    /**
     * Processes the server's group message: reads {@code p} and {@code g} and rejects a modulus whose bit length
     * lies outside [{@link #MIN_BITS}, {@link #MAX_BITS}]. It then initializes the local DH state and sends
     * {@code KEX_DH_GEX_INIT} with the client public value {@code e}.
     *
     * @return always {@code false}; the exchange continues with the server reply
     */
    private boolean parseGexGroup(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException {
        final BigInteger p = buffer.readMPInt();
        final BigInteger g = buffer.readMPInt();
        final int bitLength = p.bitLength();
        if (bitLength < MIN_BITS || bitLength > MAX_BITS) {
            throw new GeneralSecurityException("Server generated gex p is out of range (" + bitLength + " bits)");
        }
        log.debug("Received server p bitlength {}", bitLength);
        dh.init(new DHParameterSpec(p, g), trans.getConfig().getRandomFactory());
        log.debug("Sending {}", Message.KEX_DH_GEX_INIT);
        trans.write(new SSHPacket(Message.KEX_DH_GEX_INIT).putBytes(dh.getE()));
        groupReceived = true;
        return false;
    }
}

