@@ -116,7 +116,7 @@ class ASFileSystem : public FileSystem {
116
116
117
117
Status DownloadFolder (
118
118
const std::string& container, const std::string& path,
119
- const std::string& dest);
119
+ const std::string& dest, const bool recursive );
120
120
121
121
std::shared_ptr<asb::BlobServiceClient> client_;
122
122
re2::RE2 as_regex_;
@@ -392,7 +392,7 @@ ASFileSystem::FileExists(const std::string& path, bool* exists)
392
392
Status
393
393
ASFileSystem::DownloadFolder (
394
394
const std::string& container, const std::string& path,
395
- const std::string& dest)
395
+ const std::string& dest, const bool recursive )
396
396
{
397
397
auto container_client = client_->GetBlobContainerClient (container);
398
398
auto func = [&](const std::vector<asb::Models::BlobItem>& blobs,
@@ -408,17 +408,20 @@ ASFileSystem::DownloadFolder(
408
408
" Failed to download file at " + blob_item.Name + " :" + ex.what ());
409
409
}
410
410
}
411
- for (const auto & directory_item : blob_prefixes) {
412
- const auto & local_path = JoinPath ({dest, BaseName (directory_item)});
413
- int status = mkdir (
414
- const_cast <char *>(local_path.c_str ()), S_IRUSR | S_IWUSR | S_IXUSR);
415
- if (status == -1 ) {
416
- return Status (
417
- Status::Code::INTERNAL,
418
- " Failed to create local folder: " + local_path +
419
- " , errno:" + strerror (errno));
411
+ if (recursive) {
412
+ for (const auto & directory_item : blob_prefixes) {
413
+ const auto & local_path = JoinPath ({dest, BaseName (directory_item)});
414
+ int status = mkdir (
415
+ const_cast <char *>(local_path.c_str ()), S_IRUSR | S_IWUSR | S_IXUSR);
416
+ if (status == -1 && errno != EEXIST) {
417
+ return Status (
418
+ Status::Code::INTERNAL,
419
+ " Failed to create local folder: " + local_path +
420
+ " , errno:" + strerror (errno));
421
+ }
422
+ RETURN_IF_ERROR (
423
+ DownloadFolder (container, directory_item, local_path, recursive));
420
424
}
421
- RETURN_IF_ERROR (DownloadFolder (container, directory_item, local_path));
422
425
}
423
426
return Status::Success;
424
427
};
@@ -445,21 +448,30 @@ ASFileSystem::LocalizePath(
445
448
" AS file localization not yet implemented " + path);
446
449
}
447
450
448
- std::string folder_template = " /tmp/folderXXXXXX" ;
449
- char * tmp_folder = mkdtemp (const_cast <char *>(folder_template.c_str ()));
450
- if (tmp_folder == nullptr ) {
451
- return Status (
452
- Status::Code::INTERNAL,
453
- " Failed to create local temp folder: " + folder_template +
454
- " , errno:" + strerror (errno));
451
+ // Create a local directory for s3 model store.
452
+ // If `mount_dir` or ENV variable are not set,
453
+ // creates a temporary directory under `/tmp` with the format: "folderXXXXXX".
454
+ // Otherwise, will create a folder under specified directory with the name
455
+ // indicated in path (i.e. everything after the last encounter of `/`).
456
+ const char * env_mount_dir = std::getenv (" TRITON_AZURE_MOUNT_DIRECTORY" );
457
+ std::string tmp_folder;
458
+ if (mount_dir.empty () && env_mount_dir == nullptr ) {
459
+ RETURN_IF_ERROR (triton::core::MakeTemporaryDirectory (
460
+ FileSystemType::LOCAL, &tmp_folder));
461
+ } else {
462
+ tmp_folder = mount_dir.empty () ? std::string (env_mount_dir) : mount_dir;
463
+ tmp_folder =
464
+ JoinPath ({tmp_folder, path.substr (path.find_last_of (' /' ) + 1 )});
465
+ RETURN_IF_ERROR (triton::core::MakeDirectory (
466
+ tmp_folder, true /* recursive*/ , true /* allow_dir_exist*/ ));
455
467
}
456
- localized->reset (new LocalizedPath (path, tmp_folder));
457
468
458
- std::string dest (folder_template );
469
+ localized-> reset ( new LocalizedPath (path, tmp_folder) );
459
470
471
+ std::string dest (tmp_folder);
460
472
std::string container, blob;
461
473
RETURN_IF_ERROR (ParsePath (path, &container, &blob));
462
- return DownloadFolder (container, blob, dest);
474
+ return DownloadFolder (container, blob, dest, recursive );
463
475
}
464
476
465
477
Status
0 commit comments