s3:libsmb: use new simplified smb_signing code for the client side
[abartlet/samba.git/.git] / source3 / libsmb / clitrans.c
index 0266c0307e5a296139cd58b5fcb1f7007a11f0b0..c566972b211102e98efb84da0ae6f248eac2d8fa 100644 (file)
@@ -94,14 +94,12 @@ bool cli_send_trans(struct cli_state *cli, int trans,
                return False;
        }
 
-       /* Note we're in a trans state. Save the sequence
-        * numbers for replies. */
-       client_set_trans_sign_state_on(cli, mid);
+       cli_state_seqnum_persistent(cli, mid);
 
        if (this_ldata < ldata || this_lparam < lparam) {
                /* receive interim response */
                if (!cli_receive_smb(cli) || cli_is_error(cli)) {
-                       client_set_trans_sign_state_off(cli, mid);
+                       cli_state_seqnum_remove(cli, mid);
                        return(False);
                }
 
@@ -137,12 +135,11 @@ bool cli_send_trans(struct cli_state *cli, int trans,
 
                        show_msg(cli->outbuf);
 
-                       client_set_trans_sign_state_off(cli, mid);
                        cli->mid = mid;
                        if (!cli_send_smb(cli)) {
+                               cli_state_seqnum_remove(cli, mid);
                                return False;
                        }
-                       client_set_trans_sign_state_on(cli, mid);
 
                        tot_data += this_ldata;
                        tot_param += this_lparam;
@@ -165,10 +162,14 @@ bool cli_receive_trans(struct cli_state *cli,int trans,
        unsigned int this_data,this_param;
        NTSTATUS status;
        bool ret = False;
+       uint16_t mid;
 
        *data_len = *param_len = 0;
 
+       mid = SVAL(cli->inbuf,smb_mid);
+
        if (!cli_receive_smb(cli)) {
+               cli_state_seqnum_remove(cli, mid);
                return False;
        }
 
@@ -179,6 +180,7 @@ bool cli_receive_trans(struct cli_state *cli,int trans,
                DEBUG(0,("Expected %s response, got command 0x%02x\n",
                         trans==SMBtrans?"SMBtrans":"SMBtrans2",
                         CVAL(cli->inbuf,smb_com)));
+               cli_state_seqnum_remove(cli, mid);
                return False;
        }
 
@@ -331,6 +333,8 @@ bool cli_receive_trans(struct cli_state *cli,int trans,
 
   out:
 
+       cli_state_seqnum_remove(cli, mid);
+
        if (ret) {
                /* Ensure the last 2 bytes of param and data are 2 null
                 * bytes. These are malloc'ed, but not included in any
@@ -344,7 +348,6 @@ bool cli_receive_trans(struct cli_state *cli,int trans,
                }
        }
 
-       client_set_trans_sign_state_off(cli, SVAL(cli->inbuf,smb_mid));
        return ret;
 }
 
@@ -412,14 +415,12 @@ bool cli_send_nt_trans(struct cli_state *cli,
                return False;
        }
 
-       /* Note we're in a trans state. Save the sequence
-        * numbers for replies. */
-       client_set_trans_sign_state_on(cli, mid);
+       cli_state_seqnum_persistent(cli, mid);
 
        if (this_ldata < ldata || this_lparam < lparam) {
                /* receive interim response */
                if (!cli_receive_smb(cli) || cli_is_error(cli)) {
-                       client_set_trans_sign_state_off(cli, mid);
+                       cli_state_seqnum_remove(cli, mid);
                        return(False);
                }
 
@@ -454,12 +455,11 @@ bool cli_send_nt_trans(struct cli_state *cli,
 
                        show_msg(cli->outbuf);
 
-                       client_set_trans_sign_state_off(cli, mid);
                        cli->mid = mid;
                        if (!cli_send_smb(cli)) {
+                               cli_state_seqnum_remove(cli, mid);
                                return False;
                        }
-                       client_set_trans_sign_state_on(cli, mid);
 
                        tot_data += this_ldata;
                        tot_param += this_lparam;
@@ -483,10 +483,14 @@ bool cli_receive_nt_trans(struct cli_state *cli,
        uint8 eclass;
        uint32 ecode;
        bool ret = False;
+       uint16_t mid;
 
        *data_len = *param_len = 0;
 
+       mid = SVAL(cli->inbuf,smb_mid);
+
        if (!cli_receive_smb(cli)) {
+               cli_state_seqnum_remove(cli, mid);
                return False;
        }
 
@@ -496,6 +500,7 @@ bool cli_receive_nt_trans(struct cli_state *cli,
        if (CVAL(cli->inbuf,smb_com) != SMBnttrans) {
                DEBUG(0,("Expected SMBnttrans response, got command 0x%02x\n",
                         CVAL(cli->inbuf,smb_com)));
+               cli_state_seqnum_remove(cli, mid);
                return(False);
        }
 
@@ -669,6 +674,8 @@ bool cli_receive_nt_trans(struct cli_state *cli,
 
   out:
 
+       cli_state_seqnum_remove(cli, mid);
+
        if (ret) {
                /* Ensure the last 2 bytes of param and data are 2 null
                 * bytes. These are malloc'ed, but not included in any
@@ -682,7 +689,6 @@ bool cli_receive_nt_trans(struct cli_state *cli,
                }
        }
 
-       client_set_trans_sign_state_off(cli, SVAL(cli->inbuf,smb_mid));
        return ret;
 }
 
@@ -696,6 +702,7 @@ struct cli_trans_state {
        struct event_context *ev;
        uint8_t cmd;
        uint16_t mid;
+       uint32_t seqnum;
        const char *pipe_name;
        uint16_t fid;
        uint16_t function;
@@ -919,6 +926,7 @@ static struct async_req *cli_ship_trans(TALLOC_CTX *mem_ctx,
                cli_req = talloc_get_type_abort(result->private_data,
                                                struct cli_request);
                state->mid = cli_req->mid;
+               state->seqnum = cli_req->seqnum;
        } else {
                uint16_t num_bytes = talloc_get_size(bytes);
                /*
@@ -939,12 +947,10 @@ static struct async_req *cli_ship_trans(TALLOC_CTX *mem_ctx,
                cli_req->recv_helper.fn = cli_trans_recv_helper;
                cli_req->recv_helper.priv = state;
                cli_req->mid = state->mid;
-               client_set_trans_sign_state_off(state->cli, state->mid);
                cli_chain_uncork(state->cli);
+               state->seqnum = cli_req->seqnum;
        }
 
-       client_set_trans_sign_state_on(state->cli, state->mid);
-
  fail:
        TALLOC_FREE(frame);
        return result;
@@ -953,6 +959,8 @@ static struct async_req *cli_ship_trans(TALLOC_CTX *mem_ctx,
 static void cli_trans_ship_rest(struct async_req *req,
                                struct cli_trans_state *state)
 {
+       struct cli_request *cli_req;
+
        state->secondary_request_ctx = talloc_new(state);
        if (state->secondary_request_ctx == NULL) {
                async_req_nterror(req, NT_STATUS_NO_MEMORY);
@@ -961,14 +969,19 @@ static void cli_trans_ship_rest(struct async_req *req,
 
        while ((state->param_sent < state->num_param)
               || (state->data_sent < state->num_data)) {
-               struct async_req *cli_req;
+               struct async_req *subreq;
 
-               cli_req = cli_ship_trans(state->secondary_request_ctx, state);
-               if (cli_req == NULL) {
+               subreq = cli_ship_trans(state->secondary_request_ctx, state);
+               if (subreq == NULL) {
                        async_req_nterror(req, NT_STATUS_NO_MEMORY);
                        return;
                }
        }
+
+       cli_req = talloc_get_type_abort(req->private_data,
+                                       struct cli_request);
+
+       cli_req->seqnum = state->seqnum;
 }
 
 static NTSTATUS cli_pull_trans(struct async_req *req,
@@ -1174,7 +1187,6 @@ static void cli_trans_recv_helper(struct async_req *req)
 
        if ((state->rparam.total == state->rparam.received)
            && (state->rdata.total == state->rdata.received)) {
-               client_set_trans_sign_state_off(state->cli, state->mid);
                async_req_done(req);
        }
 }