// Copyright 2023 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <chrono>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "libspu/core/context.h"
#include "libspu/device/symbol_table.h"
#include "yacl/link/link.h"

#include "engine/datasource/datasource_adaptor_mgr.h"
#include "engine/datasource/router.h"
#include "engine/framework/party_info.h"
#include "engine/framework/tensor_table.h"
#include "engine/util/logging.h"
#include "engine/util/psi_detail_logger.h"

#include "api/common.pb.h"
#include "api/engine.pb.h"
namespace scql::engine {
// Lifecycle states of a Session.
//
//   regular completion : INITIALIZED -> RUNNING -> SUCCEEDED or FAILED
//   manual termination : INITIALIZED -> RUNNING -> ABORTING  -> FAILED
//
// ABORTING is only entered when the user explicitly stops the query.
// Enumerator values are pinned explicitly; NOTE(review): confirm whether any
// code relies on the integer values before renumbering.
enum class SessionState : int {
  INITIALIZED = 0,
  RUNNING = 1,
  ABORTING = 2,
  SUCCEEDED = 3,
  FAILED = 4,
};
// Translates a SessionState into a pb::JobState for reporting via the API.
// Declared here only; the actual mapping is defined out of line.
pb::JobState ConvertSessionStateToJobState(SessionState state);
/// Settings for the session's communication link: receive timeout,
/// throttle window, chunked-send parallelism, retry policy and the maximum
/// HTTP payload size.
struct LinkConfig {
  std::uint32_t link_recv_timeout_ms{30'000};  // 30 seconds
  std::size_t link_throttle_window_size{0};
  std::size_t link_chunked_send_parallel_size{8};
  yacl::link::RetryOptions link_retry_options;
  std::size_t http_max_payload_size{1024 * 1024};  // 1 MiB
};
/// PSI tuning knobs for a session.
struct PsiConfig {
  // A value of 0 in any field below means "not set here": the corresponding
  // gflags setting is used instead.
  std::int64_t unbalance_psi_ratio_threshold{0};
  std::int64_t unbalance_psi_larger_party_rows_count_threshold{0};
  std::int32_t psi_curve_type{0};
};
/// Per-session logging switches; the single flag is off by default.
struct LogConfig {
  bool enable_session_logger_separation{false};
};
// Bundle of all per-session configuration groups; a const copy is kept by
// Session and exposed through Session::GetSessionOptions().
struct SessionOptions {
// logging options (util::LogOptions, defined in engine/util/logging.h)
util::LogOptions log_options;
// communication link settings (timeouts, throttling, retries, payload cap)
LinkConfig link_config;
// PSI thresholds / curve type; zero values defer to gflags
PsiConfig psi_config;
// session logger separation switch
LogConfig log_config;
};
/// @brief Options controlling streaming (batched) execution of a session.
///
/// Every member carries a default initializer so that a default-constructed
/// StreamingOptions (e.g. the one held by Session before
/// SetStreamingOptions() is called) has well-defined values. Previously the
/// scalar members were left indeterminate and reading them was undefined
/// behavior. Aggregate initialization (`StreamingOptions{dir, b, t, n}`)
/// still works, since default member initializers are allowed in
/// aggregates from C++14 on.
struct StreamingOptions {
  // presumably the directory where streamed batches are dumped — confirm in
  // the code that consumes it
  std::filesystem::path dump_file_dir;
  // streaming/batched mode flag; off unless explicitly enabled
  bool batched{false};
  // if row num is less than this threshold, close streaming mode and keep all
  // data in memory
  size_t streaming_row_num_threshold{0};
  // if working in streaming mode, max row num in one batch
  size_t batch_row_num{0};
};
/// @brief Session holds everything needed to run the execution plan.
class Session {
public:
explicit Session(const SessionOptions& session_opt,
const pb::JobStartParams& params,
pb::DebugOptions debug_opts,
yacl::link::ILinkFactory* link_factory, Router* router,
DatasourceAdaptorMgr* ds_mgr,
const std::vector<spu::ProtocolKind>& allowed_spu_protocols);
~Session();
/// @return session id
std::string Id() const { return id_; }
std::string TimeZone() const { return time_zone_; }
const std::string& SelfPartyCode() const { return parties_.SelfPartyCode(); }
size_t SelfRank() const { return parties_.SelfRank(); }
// each session has its own logger, it may contain session id in each log
// message.
std::shared_ptr<spdlog::logger> GetLogger() const { return logger_; }
std::shared_ptr<util::PsiDetailLogger> GetPsiLogger() const {
return psi_logger_;
}
Router* GetRouter() const { return router_; }
DatasourceAdaptorMgr* GetDatasourceAdaptorMgr() const { return ds_mgr_; }
ssize_t GetPartyRank(const std::string& party_code) const {
return parties_.GetRank(party_code);
}
std::shared_ptr<yacl::link::Context> GetLink() const { return lctx_; }
TensorTable* GetTensorTable() const { return tensor_table_.get(); }
spu::SPUContext* GetSpuContext() const { return spu_ctx_.get(); }
spu::device::SymbolTable* GetDeviceSymbols() { return &device_symbols_; }
SessionState GetState() const { return state_.load(); }
void SetState(SessionState new_state) { state_.store(new_state); }
// compare and swap state_ to avoid race condition
bool CASState(SessionState old_state, SessionState new_state) {
return state_.compare_exchange_strong(old_state, new_state);
}
std::chrono::time_point<std::chrono::system_clock> GetStartTime() const {
return start_time_;
}
void SetAffectedRows(int64_t affected_rows) {
affected_rows_ = affected_rows;
}
int64_t GetAffectedRows() const { return affected_rows_; }
void SetNodesCount(int32_t nodes_count) { nodes_count_ = nodes_count; }
int32_t GetNodesCount() const { return nodes_count_; }
void IncExecutedNodes() { ++executed_nodes_; }
void SetExecutedNodes(int32_t executed_nodes) {
executed_nodes_ = executed_nodes;
}
int32_t GetExecutedNodes() const { return executed_nodes_; }
auto GetCurrentNodeInfo() {
std::lock_guard<std::mutex> guard(progress_mutex_);
return std::make_pair(node_start_time_, current_node_name_);
}
void SetCurrentNodeInfo(
std::chrono::time_point<std::chrono::system_clock> start_time,
const std::string& name) {
std::lock_guard<std::mutex> guard(progress_mutex_);
node_start_time_ = start_time;
current_node_name_ = name;
}
void StorePeerError(const std::string& party_code, const pb::Status& status) {
std::lock_guard<std::mutex> guard(peer_error_mutex_);
peer_errors_.emplace_back(party_code, status);
}
std::vector<std::pair<std::string, pb::Status>> GetPeerErrors() const {
std::lock_guard<std::mutex> guard(peer_error_mutex_);
return peer_errors_;
}
void AddPublishResult(std::shared_ptr<pb::Tensor> pb) {
publish_results_.emplace_back(std::move(pb));
}
std::vector<std::shared_ptr<pb::Tensor>> GetPublishResults() const {
return publish_results_;
}
std::shared_ptr<const yacl::link::Statistics> GetLinkStats() const {
return lctx_->GetStats();
}
void MergeDeviceSymbolsFrom(const spu::device::SymbolTable& other);
TensorPtr StringToHash(const Tensor& string_tensor);
TensorPtr HashToString(const Tensor& hash_tensor);
using RefNums = std::vector<std::tuple<std::string, int>>;
// set origin ref num
void UpdateRefName(const std::vector<std::string>& input_ref_names,
const RefNums& output_ref_nums);
void DelTensors(const std::vector<std::string>& tensor_names);
const SessionOptions& GetSessionOptions() const { return session_opt_; }
StreamingOptions GetStreamingOptions() { return streaming_options_; }
void SetStreamingOptions(const StreamingOptions& streaming_options) {
streaming_options_ = streaming_options;
}
void EnableStreamingBatched();
private:
void InitLink();
bool ValidateSPUContext();
private:
const std::string id_;
const SessionOptions session_opt_;
const std::string time_zone_;
PartyInfo parties_;
std::atomic<SessionState> state_;
std::chrono::time_point<std::chrono::system_clock> start_time_;
yacl::link::ILinkFactory* link_factory_;
std::shared_ptr<spdlog::logger> logger_; // session's own logger
Router* router_;
DatasourceAdaptorMgr* ds_mgr_;
// private (plaintext) tensors
std::unique_ptr<TensorTable> tensor_table_;
std::shared_ptr<yacl::link::Context> lctx_;
std::unique_ptr<spu::SPUContext> spu_ctx_; // SPUContext
spu::device::SymbolTable device_symbols_; // spu device symbols table
absl::flat_hash_map<size_t, std::string> hash_to_string_values_;
std::vector<std::shared_ptr<pb::Tensor>> publish_results_;
int64_t affected_rows_ = 0;
mutable std::mutex mutex_;
absl::flat_hash_map<std::string, int> tensor_ref_nums_;
mutable std::mutex peer_error_mutex_;
std::vector<std::pair<std::string, pb::Status>> peer_errors_;
std::shared_ptr<util::
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
安全协作查询语言 (SCQL) 是一个将 SQL 语句转换为安全多方计算 (SMC) 原语，并在由多个数据库系统组成的联邦上执行这些原语的系统（共 1014 个子文件）
make.bat 765B
BUILD.bazel 23KB
BUILD.bazel 8KB
BUILD.bazel 5KB
BUILD.bazel 4KB
BUILD.bazel 4KB
BUILD.bazel 3KB
BUILD.bazel 3KB
BUILD.bazel 3KB
BUILD.bazel 2KB
BUILD.bazel 2KB
BUILD.bazel 2KB
BUILD.bazel 1KB
BUILD.bazel 1KB
BUILD.bazel 1KB
BUILD.bazel 582B
BUILD.bazel 581B
BUILD.bazel 581B
.bazelrc 1KB
.bazelversion 5B
arrow.BUILD 8KB
aws_sdk_cpp.BUILD 8KB
poco.BUILD 2KB
thrift.BUILD 2KB
mysql.BUILD 2KB
curl.BUILD 2KB
duckdb.BUILD 2KB
aws_checksums.BUILD 2KB
aws_c_common.BUILD 1KB
xsimd.BUILD 1KB
ncurses.BUILD 1KB
brotli.BUILD 1KB
snappy.BUILD 1KB
postgres.BUILD 1KB
lz4.BUILD 1KB
bzip2.BUILD 1KB
aws_c_event_stream.BUILD 1015B
double-conversion.BUILD 945B
sqlite3.BUILD 921B
rapidjson.BUILD 917B
gperftools.BUILD 754B
engine_deps.bzl 14KB
curl.bzl 6KB
repositories.bzl 1KB
join_test.cc 35KB
compare_test.cc 29KB
engine_service_impl.cc 27KB
engine_service_impl_test.cc 26KB
psi_helper.cc 25KB
case_when_test.cc 20KB
join.cc 20KB
in.cc 19KB
in_test.cc 17KB
main.cc 16KB
group_agg_test.cc 16KB
if_test.cc 16KB
oblivious_group_agg_test.cc 14KB
reduce_test.cc 14KB
session.cc 14KB
arithmetic_test.cc 13KB
audit_log.cc 13KB
bucket.cc 12KB
logical_test.cc 11KB
oblivious_group_agg.cc 11KB
test_util.cc 11KB
session_manager.cc 11KB
group_he_sum.cc 11KB
trigonometric_test.cc 11KB
duckdb_wrapper.cc 10KB
dump_file.cc 10KB
window.cc 10KB
case_when.cc 9KB
window_test.cc 9KB
broadcast_to_test.cc 9KB
limit_test.cc 8KB
mux_link_factory.cc 8KB
group_agg.cc 8KB
csvdb_adaptor_test.cc 8KB
insert_table.cc 8KB
arithmetic.cc 8KB
read_write_bench.cc 8KB
filter_test.cc 8KB
dump_file_test.cc 7KB
spu_io.cc 7KB
csvdb_adaptor.cc 7KB
pipeline.cc 7KB
reduce.cc 7KB
cast_test.cc 7KB
duckdb_wrapper_test.cc 7KB
sort_test.cc 7KB
copy.cc 7KB
filter.cc 7KB
flags.cc 7KB
kuscia_datamesh_router.cc 7KB
logical.cc 7KB
bucket_test.cc 7KB
shape_test.cc 7KB
tensor_util.cc 6KB
odbc_adaptor.cc 6KB
ndarray_to_arrow.cc 6KB
共 1014 条
- 1
- 2
- 3
- 4
- 5
- 6
- 11
资源评论
Java程序员-张凯
- 粉丝: 1w+
- 资源: 7266
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功