s3: Refactor smbd_smb2_request_process_negprot
[metze/samba/wip.git] / source3 / smbd / smb2_negprot.c
index d086566d5cc6adc797fa5a5ea3e9af68cd359c39..5fa1fbbe008bdbf90e5f746e97925ea246e7e9ca 100644 (file)
@@ -25,6 +25,8 @@
 #include "../lib/tsocket/tsocket.h"
 #include "../librpc/ndr/libndr.h"
 
+extern fstring remote_proto;
+
 /*
  * this is the entry point if SMB2 is selected via
  * the SMB negprot and the given dialect.
@@ -80,139 +82,153 @@ void reply_smb20ff(struct smb_request *req, uint16_t choice)
        reply_smb20xx(req, SMB2_DIALECT_REVISION_2FF);
 }
 
-NTSTATUS smbd_smb2_request_process_negprot(struct smbd_smb2_request *req)
+enum protocol_types smbd_smb2_protocol_dialect_match(const uint8_t *indyn,
+                               const int dialect_count,
+                               uint16_t *dialect)
 {
-       NTSTATUS status;
-       const uint8_t *inbody;
-       const uint8_t *indyn = NULL;
-       DATA_BLOB outbody;
-       DATA_BLOB outdyn;
-       DATA_BLOB negprot_spnego_blob;
-       uint16_t security_offset;
-       DATA_BLOB security_buffer;
-       size_t expected_dyn_size = 0;
-       size_t c;
-       uint16_t security_mode;
-       uint16_t dialect_count;
-       uint16_t in_security_mode;
-       uint32_t in_capabilities;
-       DATA_BLOB in_guid_blob;
-       struct GUID in_guid;
-       uint16_t dialect = 0;
-       uint32_t capabilities;
-       DATA_BLOB out_guid_blob;
-       struct GUID out_guid;
+       size_t c = 0;
        enum protocol_types protocol = PROTOCOL_NONE;
-       uint32_t max_limit;
-       uint32_t max_trans = lp_smb2_max_trans();
-       uint32_t max_read = lp_smb2_max_read();
-       uint32_t max_write = lp_smb2_max_write();
-       NTTIME now = timeval_to_nttime(&req->request_time);
-
-       status = smbd_smb2_request_verify_sizes(req, 0x24);
-       if (!NT_STATUS_IS_OK(status)) {
-               return smbd_smb2_request_error(req, status);
-       }
-       inbody = SMBD_SMB2_IN_BODY_PTR(req);
-
-       dialect_count = SVAL(inbody, 0x02);
-
-       in_security_mode = SVAL(inbody, 0x04);
-       in_capabilities = IVAL(inbody, 0x08);
-       in_guid_blob = data_blob_const(inbody + 0x0C, 16);
-
-       if (dialect_count == 0) {
-               return smbd_smb2_request_error(req, NT_STATUS_INVALID_PARAMETER);
-       }
-
-       status = GUID_from_ndr_blob(&in_guid_blob, &in_guid);
-       if (!NT_STATUS_IS_OK(status)) {
-               return smbd_smb2_request_error(req, status);
-       }
-
-       expected_dyn_size = dialect_count * 2;
-       if (SMBD_SMB2_IN_DYN_LEN(req) < expected_dyn_size) {
-               return smbd_smb2_request_error(req, NT_STATUS_INVALID_PARAMETER);
-       }
-       indyn = SMBD_SMB2_IN_DYN_PTR(req);
 
        for (c=0; protocol == PROTOCOL_NONE && c < dialect_count; c++) {
-               if (lp_srv_maxprotocol() < PROTOCOL_SMB3_00) {
+               if (lp_server_max_protocol() < PROTOCOL_SMB3_00) {
                        break;
                }
-               if (lp_srv_minprotocol() > PROTOCOL_SMB3_00) {
+               if (lp_server_min_protocol() > PROTOCOL_SMB3_00) {
                        break;
                }
 
-               dialect = SVAL(indyn, c*2);
-               if (dialect == SMB3_DIALECT_REVISION_300) {
+               *dialect = SVAL(indyn, c*2);
+               if (*dialect == SMB3_DIALECT_REVISION_300) {
                        protocol = PROTOCOL_SMB3_00;
                        break;
                }
        }
 
        for (c=0; protocol == PROTOCOL_NONE && c < dialect_count; c++) {
-               if (lp_srv_maxprotocol() < PROTOCOL_SMB2_24) {
+               if (lp_server_max_protocol() < PROTOCOL_SMB2_24) {
                        break;
                }
-               if (lp_srv_minprotocol() > PROTOCOL_SMB2_24) {
+               if (lp_server_min_protocol() > PROTOCOL_SMB2_24) {
                        break;
                }
 
-               dialect = SVAL(indyn, c*2);
-               if (dialect == SMB2_DIALECT_REVISION_224) {
+               *dialect = SVAL(indyn, c*2);
+               if (*dialect == SMB2_DIALECT_REVISION_224) {
                        protocol = PROTOCOL_SMB2_24;
                        break;
                }
        }
 
        for (c=0; protocol == PROTOCOL_NONE && c < dialect_count; c++) {
-               if (lp_srv_maxprotocol() < PROTOCOL_SMB2_22) {
+               if (lp_server_max_protocol() < PROTOCOL_SMB2_22) {
                        break;
                }
-               if (lp_srv_minprotocol() > PROTOCOL_SMB2_22) {
+               if (lp_server_min_protocol() > PROTOCOL_SMB2_22) {
                        break;
                }
 
-               dialect = SVAL(indyn, c*2);
-               if (dialect == SMB2_DIALECT_REVISION_222) {
+               *dialect = SVAL(indyn, c*2);
+               if (*dialect == SMB2_DIALECT_REVISION_222) {
                        protocol = PROTOCOL_SMB2_22;
                        break;
                }
        }
 
        for (c=0; protocol == PROTOCOL_NONE && c < dialect_count; c++) {
-               if (lp_srv_maxprotocol() < PROTOCOL_SMB2_10) {
+               if (lp_server_max_protocol() < PROTOCOL_SMB2_10) {
                        break;
                }
-               if (lp_srv_minprotocol() > PROTOCOL_SMB2_10) {
+               if (lp_server_min_protocol() > PROTOCOL_SMB2_10) {
                        break;
                }
 
-               dialect = SVAL(indyn, c*2);
-               if (dialect == SMB2_DIALECT_REVISION_210) {
+               *dialect = SVAL(indyn, c*2);
+               if (*dialect == SMB2_DIALECT_REVISION_210) {
                        protocol = PROTOCOL_SMB2_10;
                        break;
                }
        }
 
        for (c=0; protocol == PROTOCOL_NONE && c < dialect_count; c++) {
-               if (lp_srv_maxprotocol() < PROTOCOL_SMB2_02) {
+               if (lp_server_max_protocol() < PROTOCOL_SMB2_02) {
                        break;
                }
-               if (lp_srv_minprotocol() > PROTOCOL_SMB2_02) {
+               if (lp_server_min_protocol() > PROTOCOL_SMB2_02) {
                        break;
                }
 
-               dialect = SVAL(indyn, c*2);
-               if (dialect == SMB2_DIALECT_REVISION_202) {
+               *dialect = SVAL(indyn, c*2);
+               if (*dialect == SMB2_DIALECT_REVISION_202) {
                        protocol = PROTOCOL_SMB2_02;
                        break;
                }
        }
 
+       return protocol;
+}
+
+NTSTATUS smbd_smb2_request_process_negprot(struct smbd_smb2_request *req)
+{
+       NTSTATUS status;
+       const uint8_t *inbody;
+       const uint8_t *indyn = NULL;
+       DATA_BLOB outbody;
+       DATA_BLOB outdyn;
+       DATA_BLOB negprot_spnego_blob;
+       uint16_t security_offset;
+       DATA_BLOB security_buffer;
+       size_t expected_dyn_size = 0;
+       size_t c;
+       uint16_t security_mode;
+       uint16_t dialect_count;
+       uint16_t in_security_mode;
+       uint32_t in_capabilities;
+       DATA_BLOB in_guid_blob;
+       struct GUID in_guid;
+       uint16_t dialect = 0;
+       uint32_t capabilities;
+       DATA_BLOB out_guid_blob;
+       struct GUID out_guid;
+       enum protocol_types protocol = PROTOCOL_NONE;
+       uint32_t max_limit;
+       uint32_t max_trans = lp_smb2_max_trans();
+       uint32_t max_read = lp_smb2_max_read();
+       uint32_t max_write = lp_smb2_max_write();
+       NTTIME now = timeval_to_nttime(&req->request_time);
+
+       status = smbd_smb2_request_verify_sizes(req, 0x24);
+       if (!NT_STATUS_IS_OK(status)) {
+               return smbd_smb2_request_error(req, status);
+       }
+       inbody = SMBD_SMB2_IN_BODY_PTR(req);
+
+       dialect_count = SVAL(inbody, 0x02);
+
+       in_security_mode = SVAL(inbody, 0x04);
+       in_capabilities = IVAL(inbody, 0x08);
+       in_guid_blob = data_blob_const(inbody + 0x0C, 16);
+
+       if (dialect_count == 0) {
+               return smbd_smb2_request_error(req, NT_STATUS_INVALID_PARAMETER);
+       }
+
+       status = GUID_from_ndr_blob(&in_guid_blob, &in_guid);
+       if (!NT_STATUS_IS_OK(status)) {
+               return smbd_smb2_request_error(req, status);
+       }
+
+       expected_dyn_size = dialect_count * 2;
+       if (SMBD_SMB2_IN_DYN_LEN(req) < expected_dyn_size) {
+               return smbd_smb2_request_error(req, NT_STATUS_INVALID_PARAMETER);
+       }
+       indyn = SMBD_SMB2_IN_DYN_PTR(req);
+
+       protocol = smbd_smb2_protocol_dialect_match(indyn,
+                                       dialect_count,
+                                       &dialect);
+
        for (c=0; protocol == PROTOCOL_NONE && c < dialect_count; c++) {
-               if (lp_srv_maxprotocol() < PROTOCOL_SMB2_10) {
+               if (lp_server_max_protocol() < PROTOCOL_SMB2_10) {
                        break;
                }
 
@@ -234,6 +250,12 @@ NTSTATUS smbd_smb2_request_process_negprot(struct smbd_smb2_request *req)
                set_remote_arch(RA_VISTA);
        }
 
+       fstr_sprintf(remote_proto, "SMB%X_%02X",
+                    (dialect >> 8) & 0xFF, dialect & 0xFF);
+
+       reload_services(req->sconn, conn_snum_used, true);
+       DEBUG(3,("Selected protocol %s\n", remote_proto));
+
        /* negprot_spnego() returns a the server guid in the first 16 bytes */
        negprot_spnego_blob = negprot_spnego(req, req->sconn);
        if (negprot_spnego_blob.data == NULL) {
@@ -254,6 +276,12 @@ NTSTATUS smbd_smb2_request_process_negprot(struct smbd_smb2_request *req)
                capabilities |= SMB2_CAP_DFS;
        }
 
+       if ((protocol >= PROTOCOL_SMB2_24) &&
+           (lp_smb_encrypt(-1) != SMB_SIGNING_OFF) &&
+           (in_capabilities & SMB2_CAP_ENCRYPTION)) {
+               capabilities |= SMB2_CAP_ENCRYPTION;
+       }
+
        /*
         * 0x10000 (65536) is the maximum allowed message size
         * for SMB 2.0
@@ -306,7 +334,7 @@ NTSTATUS smbd_smb2_request_process_negprot(struct smbd_smb2_request *req)
                return smbd_smb2_request_error(req, status);
        }
 
-       outbody = data_blob_talloc(req->out.vector, NULL, 0x40);
+       outbody = smbd_smb2_generate_outbody(req, 0x40);
        if (outbody.data == NULL) {
                return smbd_smb2_request_error(req, NT_STATUS_NO_MEMORY);
        }