# Copyright 2022 The XFL Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import pytest
import json
from multiprocess.pool import ApplyResult
import pandas as pd
import numpy as np
import service.fed_config
from service.fed_config import FedConfig
from algorithm.core.paillier_acceleration import embed
from algorithm.core.tree.xgboost_loss import get_xgb_loss_inst
from common.communication.gRPC.python.channel import BroadcastChannel, DualChannel
from common.communication.gRPC.python.commu import Commu
from common.crypto.paillier.paillier import Paillier
from algorithm.core.tree.tree_structure import Node
from algorithm.framework.vertical.xgboost.label_trainer import VerticalXgboostLabelTrainer
from algorithm.framework.vertical.xgboost.trainer import VerticalXgboostTrainer
from algorithm.framework.vertical.xgboost.decision_tree_label_trainer import VerticalDecisionTreeLabelTrainer
from algorithm.framework.vertical.xgboost.decision_tree_trainer import VerticalDecisionTreeTrainer
@pytest.fixture(scope='module', autouse=True)
def prepare_data(tmp_factory):
    """Write the guest/host train and validation CSV files used by the tests.

    A random 200-row frame is generated once and split two ways:

    * by columns: the guest files hold the label ``y`` plus ``x0``, ``x1``,
      ``x2``; the host files hold only ``x3`` and ``x4``;
    * by rows: the first 120 rows go to ``train_*.csv`` and the last 80 rows
      to ``test_*.csv`` (the tests load the latter as their validation set).

    Every file keeps the DataFrame row index as an ``id`` column so the two
    parties' files can be aligned sample by sample.

    Args:
        tmp_factory: directory fixture defined outside this file that exposes
            ``join(name)``; presumably a ``py.path.local`` -- confirm in conftest.
    """
    n_rows, n_train, n_val = 200, 120, 80
    df = pd.DataFrame({
        "x0": np.random.random(n_rows),
        "x1": np.random.randint(0, 10, n_rows),
        # The previous ``np.random.uniform(200) * 2.0`` passed 200 as ``low``
        # (high=1.0, size=None) and therefore produced ONE scalar that was
        # broadcast to every row, making x2 a constant feature.
        "x2": np.random.random(n_rows) * 2.0,
        "x3": np.random.random(n_rows) * 3.0,
        "x4": np.random.randint(0, 10, n_rows),
        'y': np.round(np.random.random(n_rows)),
    })

    guest_columns = ['y', 'x0', 'x1', 'x2']  # party holding the label
    host_columns = ['x3', 'x4']  # feature-only party
    outputs = (
        ("train_guest.csv", guest_columns, df.head(n_train)),
        ("test_guest.csv", guest_columns, df.tail(n_val)),
        ("train_host.csv", host_columns, df.head(n_train)),
        ("test_host.csv", host_columns, df.tail(n_val)),
    )
    for file_name, columns, rows in outputs:
        # Keep the row index as "id" so guest and host rows line up by sample.
        rows[columns].to_csv(
            tmp_factory.join(file_name), index=True, index_label='id'
        )
# Run the tests as participant "node-1" in a federation whose trainers are
# "node-1" and "node-2" (node_id is assigned first, then trainer_ids).
Commu.node_id, Commu.trainer_ids = "node-1", ["node-1", "node-2"]
class TestVerticalXGBoost:
@pytest.mark.parametrize('feature_index', [(1), (0)])
def test_decision_tree_trainer(self, mocker, tmp_factory, feature_index):
with open("python/algorithm/config/vertical_xgboost/trainer.json") as f:
conf = json.load(f)
conf["input"]["trainset"][0]["path"] = str(tmp_factory)
conf["input"]["trainset"][0]["name"] = "train_host.csv"
conf["input"]["valset"][0]["path"] = str(tmp_factory)
conf["input"]["valset"][0]["name"] = "test_host.csv"
del conf["input"]["testset"]
conf["output"]["path"] = str(tmp_factory)
# if conf["train_info"]["train_params"]["downsampling"]["row"]["run_goss"]:
# conf["train_info"]["train_params"]["downsampling"]["row"]["top_rate"] = 0.5
# conf["train_info"]["train_params"]["downsampling"]["row"]["other_rate"] = 0.5
conf["train_info"]["train_params"]["category"]["cat_features"]["col_index"] = "1"
conf["train_info"]["train_params"]["advanced"]["col_batch"] = 1
conf["train_info"]["train_params"]["advanced"]["row_batch"] = 1
# mocker channels in VerticalXgboostTrainer.__init__
mocker.patch.object(
DualChannel, "__init__", return_value=None
)
mocker.patch.object(
BroadcastChannel, "send", return_value=None
)
mocker.patch.object(
DualChannel, "send", return_value=None
)
def mock_func(*args, **kwargs):
    """
    Side effect for the patched ``BroadcastChannel.recv`` used while the
    ``VerticalXgboostTrainer`` is being constructed.

    The reply depends on ``mock_broadcast_recv.call_count`` (the patch object
    is bound after this function is defined, so the name is resolved at call
    time):

    1. call count 1 -> the training ``config`` dict built below;
    2. call count 2 -> a freshly generated Paillier context (built from the
       config's ``key_bit_size`` and ``djn_on``), converted to its public part
       and serialized;
    3. any later call -> ``None``.

    Args:
        *args: ignored.
        **kwargs: ignored.
    Returns:
        The config dict, the serialized public Paillier context, or None.
    """
    # Rebuilt on every call; only the first call returns it directly.
    config = {
        "train_info": {
            "interaction_params": {
                "save_frequency": -1,
                "echo_training_metrics": True,
                "write_training_prediction": True,
                "write_validation_prediction": True
            },
            "train_params": {
                "lossfunc": {
                    "BCEWithLogitsLoss": {}
                },
                "num_trees": 10,
                "num_bins": 16,
                "downsampling": {
                    "row": {
                        "run_goss": True
                    }
                },
                "encryption": {
                    "paillier": {
                        "key_bit_size": 2048,
                        "precision": 7,
                        "djn_on": True,
                        "parallelize_on": True
                    }
                },
                "batch_size_val": 40960
            }
        }
    }
    # The branches expect the current call to already be counted (first call == 1).
    if mock_broadcast_recv.call_count == 1:
        return config
    elif mock_broadcast_recv.call_count == 2:
        encryption = config["train_info"]["train_params"]["encryption"]
        # NOTE(review): if "paillier" were absent, ``encryption`` would stay the
        # outer dict and the key lookups below would raise KeyError; the config
        # above always contains it.
        if "paillier" in encryption:
            encryption = encryption["paillier"]
        private_context = Paillier.context(
            encryption["key_bit_size"], encryption["djn_on"])
        # Only the public half of the key is handed to the trainer.
        return private_context.to_public().serialize()
    else:
        return None
mock_broadcast_recv = mocker.patch.object(
BroadcastChannel, "recv", side_effect=mock_func
)
mocker.patch.object(
service.fed_config.FedConfig, "get_label_trainer", return_value=["node-1"]
)
mocker.patch.object(
service.fed_config.FedConfig, "get_trainer", return_value=["node-2"]
)
xgb_trainer = VerticalXgboostTrainer(conf)
sampled_features, feature_id_mapping = xgb_trainer.col_sample()
cat_columns_after_sampling = list(filter(
lambda x: feature_id_mapping[x] in xgb_trainer.cat_columns, list(feature_id_mapping.keys())))
split_points_after_sampling = [
xgb_trainer.split_points[feature_id_mapping[k]] for k in feature_id_mapping.keys()]
sample_index = [2, 4, 6, 7, 8, 10]
def mock_grad_hess(*args, **kwargs):
    """
    Side effect for the patched ``BroadcastChannel.recv`` that returns
    serialized, encrypted gradient/hessian data.

    Draws random ``grad`` and ``hess`` vectors of length ``len(sample_index)``,
    combines them with ``embed`` (interval ``1 << 128``, precision 64),
    encrypts the result with a Paillier context created from the trainer's
    ``encryption_param`` (``key_bit_size``, ``djn_on``) and returns
    ``Paillier.serialize(..., compression=False)`` of the ciphertexts.

    NOTE(review): a new key pair is generated on every call, which is slow for
    large key sizes and is unrelated to the public context returned by the
    earlier broadcast mock -- confirm the trainer does not require them to match.

    Args:
        *args: ignored.
        **kwargs: ignored.
    Returns:
        The serialized encrypted grad/hess payload.
    """
    private_context = Paillier.context(xgb_trainer.xgb_config.encryption_param.key_bit_size,
                                       xgb_trainer.xgb_config.encryption_param.djn_on)
    # One gradient and one hessian value per sampled row.
    grad = np.random.random(len(sample_index))
    hess = np.random.random(len(sample_index))
    grad_hess = embed([grad, hess], interval=(1 << 128), precision=64)
    enc_grad_hess = Paillier.encrypt(private_context,
                                     data=grad_hess,
                                     # presumably 0 because embed() already applied
                                     # the fixed-point scaling -- confirm in embed
                                     precision=0,  # must be 0
                                     obfuscation=True,
                                     # large value, presumably "use all cores" -- confirm
                                     num_cores=999)
    return Paillier.serialize(enc_grad_hess, compression=False)
mocker.patch.object(
BroadcastChannel, "recv", side_effect=mock_grad_hess
)
decision_tree =
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
XFL-master.zip (661个子文件)
make.bat 804B
ca-bundle.crt 203KB
ca.crt 2KB
.gitignore 1KB
trainer_config_node-1.json 4KB
trainer_config_node-1.json 4KB
trainer_config_node-1.json 4KB
trainer_config_node-1.json 4KB
label_trainer.json 4KB
trainer_config_node-1.json 4KB
trainer_config_node-1.json 4KB
trainer_config_node-3.json 4KB
trainer_config_node-2.json 4KB
trainer_config_node-2.json 4KB
label_trainer.json 3KB
assist_trainer.json 3KB
trainer.json 3KB
trainer.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-3.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-3.json 3KB
trainer.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-3.json 3KB
trainer_config_node-2.json 3KB
assist_trainer.json 3KB
trainer_config_assist_trainer.json 3KB
trainer_config_assist_trainer.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-3.json 3KB
trainer_config_assist_trainer.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-3.json 3KB
trainer_config_node-2.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-2.json 3KB
trainer_config_assist_trainer.json 3KB
trainer_config_node-1.json 3KB
trainer_config_node-3.json 3KB
trainer_config_assist_trainer.json 3KB
trainer_config_assist_trainer.json 3KB
trainer_config_assist_trainer.json 3KB
trainer_config_assist_trainer.json 3KB
trainer.json 3KB
assist_trainer.json 3KB
trainer_config_assist_trainer.json 3KB
trainer_config_assist_trainer.json 3KB
label_trainer.json 3KB
assist_trainer.json 2KB
assist_trainer.json 2KB
assist_trainer.json 2KB
assist_trainer.json 2KB
trainer.json 2KB
trainer.json 2KB
assist_trainer.json 2KB
trainer.json 2KB
trainer.json 2KB
trainer_config_node-1.json 2KB
trainer_config_node-2.json 2KB
trainer_config_node-2.json 2KB
trainer_config_node-1.json 2KB
trainer_config_node-3.json 2KB
trainer_config_node-2.json 2KB
trainer_config_node-2.json 2KB
trainer_config_node-3.json 2KB
trainer_config_node-1.json 2KB
trainer_config_node-1.json 2KB
trainer_config_assist_trainer.json 2KB
trainer_config_assist_trainer.json 2KB
trainer_config_assist_trainer.json 2KB
trainer_config_assist_trainer.json 2KB
trainer_config_node-2.json 2KB
trainer_config_assist_trainer.json 2KB
trainer_config_assist_trainer.json 2KB
trainer_config_node-1.json 2KB
trainer_config_node-1.json 2KB
trainer_config_node-2.json 2KB
trainer_config_node-2.json 2KB
trainer_config_node-3.json 2KB
trainer.json 2KB
trainer_config_assist_trainer.json 2KB
trainer_config_assist_trainer.json 2KB
label_trainer.json 2KB
trainer_config_node-2.json 2KB
trainer.json 2KB
trainer_config_node-1.json 2KB
label_trainer.json 2KB
共 661 条
- 1
- 2
- 3
- 4
- 5
- 6
- 7
资源评论
m0_72731342
- 粉丝: 2
- 资源: 1832
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功