Skip to content

Commit 6903e1c

Browse files
committed
Apply similar changes to AZURE
1 parent 2bd9a25 commit 6903e1c

File tree

1 file changed

+34
-22
lines changed
  • src/filesystem/implementations

1 file changed

+34
-22
lines changed

src/filesystem/implementations/as.h

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class ASFileSystem : public FileSystem {
116116

117117
Status DownloadFolder(
118118
const std::string& container, const std::string& path,
119-
const std::string& dest);
119+
const std::string& dest, const bool recursive);
120120

121121
std::shared_ptr<asb::BlobServiceClient> client_;
122122
re2::RE2 as_regex_;
@@ -392,7 +392,7 @@ ASFileSystem::FileExists(const std::string& path, bool* exists)
392392
Status
393393
ASFileSystem::DownloadFolder(
394394
const std::string& container, const std::string& path,
395-
const std::string& dest)
395+
const std::string& dest, const bool recursive)
396396
{
397397
auto container_client = client_->GetBlobContainerClient(container);
398398
auto func = [&](const std::vector<asb::Models::BlobItem>& blobs,
@@ -408,17 +408,20 @@ ASFileSystem::DownloadFolder(
408408
"Failed to download file at " + blob_item.Name + ":" + ex.what());
409409
}
410410
}
411-
for (const auto& directory_item : blob_prefixes) {
412-
const auto& local_path = JoinPath({dest, BaseName(directory_item)});
413-
int status = mkdir(
414-
const_cast<char*>(local_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR);
415-
if (status == -1) {
416-
return Status(
417-
Status::Code::INTERNAL,
418-
"Failed to create local folder: " + local_path +
419-
", errno:" + strerror(errno));
411+
if (recursive) {
412+
for (const auto& directory_item : blob_prefixes) {
413+
const auto& local_path = JoinPath({dest, BaseName(directory_item)});
414+
int status = mkdir(
415+
const_cast<char*>(local_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR);
416+
if (status == -1 && errno != EEXIST) {
417+
return Status(
418+
Status::Code::INTERNAL,
419+
"Failed to create local folder: " + local_path +
420+
", errno:" + strerror(errno));
421+
}
422+
RETURN_IF_ERROR(
423+
DownloadFolder(container, directory_item, local_path, recursive));
420424
}
421-
RETURN_IF_ERROR(DownloadFolder(container, directory_item, local_path));
422425
}
423426
return Status::Success;
424427
};
@@ -445,21 +448,30 @@ ASFileSystem::LocalizePath(
445448
"AS file localization not yet implemented " + path);
446449
}
447450

448-
std::string folder_template = "/tmp/folderXXXXXX";
449-
char* tmp_folder = mkdtemp(const_cast<char*>(folder_template.c_str()));
450-
if (tmp_folder == nullptr) {
451-
return Status(
452-
Status::Code::INTERNAL,
453-
"Failed to create local temp folder: " + folder_template +
454-
", errno:" + strerror(errno));
451+
// Create a local directory for s3 model store.
452+
// If `mount_dir` or ENV variable are not set,
453+
// creates a temporary directory under `/tmp` with the format: "folderXXXXXX".
454+
// Otherwise, will create a folder under specified directory with the name
455+
// indicated in path (i.e. everything after the last encounter of `/`).
456+
const char* env_mount_dir = std::getenv("TRITON_AZURE_MOUNT_DIRECTORY");
457+
std::string tmp_folder;
458+
if (mount_dir.empty() && env_mount_dir == nullptr) {
459+
RETURN_IF_ERROR(triton::core::MakeTemporaryDirectory(
460+
FileSystemType::LOCAL, &tmp_folder));
461+
} else {
462+
tmp_folder = mount_dir.empty() ? std::string(env_mount_dir) : mount_dir;
463+
tmp_folder =
464+
JoinPath({tmp_folder, path.substr(path.find_last_of('/') + 1)});
465+
RETURN_IF_ERROR(triton::core::MakeDirectory(
466+
tmp_folder, true /*recursive*/, true /*allow_dir_exist*/));
455467
}
456-
localized->reset(new LocalizedPath(path, tmp_folder));
457468

458-
std::string dest(folder_template);
469+
localized->reset(new LocalizedPath(path, tmp_folder));
459470

471+
std::string dest(tmp_folder);
460472
std::string container, blob;
461473
RETURN_IF_ERROR(ParsePath(path, &container, &blob));
462-
return DownloadFolder(container, blob, dest);
474+
return DownloadFolder(container, blob, dest, recursive);
463475
}
464476

465477
Status

0 commit comments

Comments
 (0)